diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index a66cd7be38..f3d7903333 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -553,7 +553,6 @@ table Slice { axes: [int]; begin: [int]; size: [int]; - step: [int]; } table Floor { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index 4c48f24330..aaf9324392 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -137,6 +137,13 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; return RET_ERROR; } + void *attr = node->primitive->value.value; + if (node->primitive->value.type == schema::PrimitiveType_SpaceToDepth) { + reinterpret_cast(attr)->format = schema::Format_NHWC; + } + if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { + reinterpret_cast(attr)->format = schema::Format_NHWC; + } STATUS status = RET_OK; #ifdef SUPPORT_TRAIN if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index d0829f2916..fb8ae53fcb 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -125,6 +125,74 @@ STATUS TransOpInsertPass::FindOutTransType() { return RET_OK; } +void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { + if (origin_attr == nullptr || axes == nullptr || element_size == 0) { + MS_LOG(INFO) << "Attr data is from other nodes."; + return; + } + auto axis_map = GetNc2NhAxisMap(); + std::vector cur_attr; + for (int dim = 0; dim < 4; ++dim) { + for (int index = 0; index < element_size; ++index) { + int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; + if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { + cur_attr.push_back(origin_attr[index]); + } + } + } + for (int index = 0; index < element_size; ++index) { + origin_attr[index] = cur_attr[index]; + } +} + +STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr &node) { + if (node == nullptr && node->primitive == nullptr) { + MS_LOG(ERROR) << "node or primitive null"; + return RET_NULL_PTR; + } + auto type = node->primitive->value.type; + if (type == PrimitiveType_StridedSlice) { + // onnx input size is equal to 5 always. + if (node->inputIndex.size() == 5) { + for (int index = 1; index < 5; ++index) { + if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { + MS_LOG(INFO) << "Here don't consider input is from other nodes."; + return RET_NOT_SUPPORT; + } + } + int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; + auto axes = graph->allTensors[node->inputIndex[3]]->data; + for (int index = 1; index < 5; ++index) { + TransformAttrByAxes(reinterpret_cast(graph->allTensors[node->inputIndex[index]]->data.data()), + reinterpret_cast(axes.data()), element_num); + } + } + } + if (type == PrimitiveType_Slice) { + auto attr = node->primitive->value.AsSlice(); + if (attr == nullptr) { + MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr."; + return RET_NULL_PTR; + } + // transform attr + attr->format = schema::Format_NHWC; + if (attr->begin.empty() || attr->size.empty()) { + MS_LOG(INFO) << "Here don't consider these attr are from other nodes."; + return RET_NOT_SUPPORT; + } + int element_num = attr->begin.size(); + if (attr->axes.empty()) { + for (int index = 0; index < element_num; ++index) { + attr->axes.push_back(index); + } + } + TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num); + TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num); + TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num); + } + return RET_OK; +} + STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node) { if (node == nullptr && node->primitive == nullptr) { MS_LOG(ERROR) << "node or primitive null"; @@ -153,19 +221,6 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni } node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; } - if (type == PrimitiveType_StridedSlice) { - auto attr = node->primitive->value.AsStridedSlice(); - if (attr == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsStridedSlice() is nullptr"; - return RET_NULL_PTR; - } - auto origin_begin = attr->begin; - attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; - auto origin_end = attr->end; - attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; - auto origin_stride = attr->stride; - attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; - } if (type == PrimitiveType_Split) { auto origin_axis = node->primitive->value.AsSplit()->splitDim; auto axis_map = GetNc2NhAxisMap(); @@ -200,20 +255,8 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni } node->primitive->value.AsCrop()->offsets = offsets; } - if (type == PrimitiveType_Slice) { - auto attr = node->primitive->value.AsSlice(); - if (attr == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr"; - return RET_NULL_PTR; - } - auto origin_begin = attr->begin; - attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; - auto origin_end = attr->axes; - if (origin_end.size() >= 4) { - attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; - } - auto origin_stride = attr->size; - attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; + if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) { + return ChangeOpAttrForSlice(graph, node); } return RET_OK; } @@ -246,7 +289,7 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { } ret = ChangeOpAxis(graph, node); if (ret != RET_OK) { - MS_LOG(ERROR) << "ChangeOpAxis error"; + MS_LOG(INFO) << "no need to ChangeOpAxis"; return ret; } has_insert_nodes.push_back(node.get()); @@ -258,6 +301,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; return status; } + if ((*iter)->primitive->value.type == schema::PrimitiveType_StridedSlice || + (*iter)->primitive->value.type == schema::PrimitiveType_Slice) { + break; + } } auto output_tensor_size = (*iter)->outputIndex.size(); for (size_t i = 0; i < output_tensor_size; i++) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h index 5304632162..f918904df9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h @@ -37,6 +37,10 @@ class TransOpInsertPass : public FormatTransPass { STATUS FindOutTransType(); + void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); + + STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr &node); + STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node); private: