!8613 [MSLITE] add global format trans

From: @zhengjun10
Reviewed-by: @HilbertDavid,@hangangqiang
Signed-off-by: @HilbertDavid
pull/8613/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 576e6d1577

@ -601,6 +601,138 @@ STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
}
}
void TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
MS_LOG(INFO) << "Attr data is from other nodes.";
return;
}
auto axis_map = GetNc2NhAxisMap();
std::vector<int> cur_attr;
for (int dim = 0; dim < 4; ++dim) {
for (int index = 0; index < element_size; ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
cur_attr.push_back(origin_attr[index]);
}
}
}
for (int index = 0; index < element_size; ++index) {
origin_attr[index] = cur_attr[index];
}
}
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
auto type = node->primitive->value.type;
if (type == schema::PrimitiveType_StridedSlice) {
// onnx input size is equal to 5 always.
if (node->inputIndex.size() == 5) {
for (int index = 1; index < 5; ++index) {
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
MS_LOG(INFO) << "Here don't consider input is from other nodes.";
return RET_NOT_SUPPORT;
}
}
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
auto axes = graph->allTensors[node->inputIndex[3]]->data;
for (int index = 1; index < 5; ++index) {
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
reinterpret_cast<int *>(axes.data()), element_num);
}
}
}
if (type == schema::PrimitiveType_Slice) {
auto attr = node->primitive->value.AsSlice();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
return RET_NULL_PTR;
}
// transform attr
attr->format = schema::Format_NHWC;
if (attr->begin.empty() || attr->size.empty()) {
MS_LOG(INFO) << "Here don't consider these attr are from other nodes.";
return RET_NOT_SUPPORT;
}
int element_num = attr->begin.size();
if (attr->axes.empty()) {
for (int index = 0; index < element_num; ++index) {
attr->axes.push_back(index);
}
}
TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
}
return RET_OK;
}
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
MS_ASSERT(node->primitive->value != nullptr);
auto type = node->primitive->value.type;
auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
if (input1_ndim != 4 && input1_ndim != 0) {
if (node->inputIndex.size() > 1) {
auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
if (input2_ndim != 4 && input2_ndim != 0) {
MS_LOG(ERROR) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
} else {
MS_LOG(ERROR) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
}
if (type == schema::PrimitiveType_Concat) {
MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
auto origin_axis = node->primitive->value.AsConcat()->axis;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsConcat() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
}
if (type == schema::PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsSplit() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
}
if (type == schema::PrimitiveType_Crop) {
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
auto origin_axis = node->primitive->value.AsCrop()->axis;
auto offsets = node->primitive->value.AsCrop()->offsets;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsCrop() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
// nchw->nhwc,offsets need pad 0;
if (axis_map[origin_axis] == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
} else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
// orgin_axis = 2 or orgin_axis = 3
offsets.push_back(0);
} else if (axis_map[origin_axis] == -1) {
// origin_axis = 1
offsets = {offsets[1], offsets[2], offsets[0]};
} else {
// axis error
MS_LOG(ERROR) << "Crop error";
return RET_ERROR;
}
node->primitive->value.AsCrop()->offsets = offsets;
}
if (type == schema::PrimitiveType_Slice || type == schema::PrimitiveType_StridedSlice) {
return ChangeOpAttrForSlice(graph, node);
}
return RET_OK;
}
std::string GetModelName(const std::string &modelFile) {
std::string modelName = modelFile;
modelName = modelName.substr(modelName.find_last_of('/') + 1);

@ -86,6 +86,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer);
STATUS ValidateFileStr(const std::string &modelFile, std::string fileType);
void TransformAttrByAxes(int *origin_attr, int *axes, int element_size);
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
std::string GetModelName(const std::string &modelFile);
} // namespace lite
} // namespace mindspore

@ -139,7 +139,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop};
schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop,
schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum};
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};

@ -28,6 +28,7 @@
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
@ -114,6 +115,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
if (ctx.trainModel == false && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";

@ -12,6 +12,7 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)

@ -0,0 +1,197 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include <algorithm>
#include "third_party/securec/include/securec.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/node_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::set<size_t> need_del_nodes;
std::set<size_t> need_trans_format_nodes;
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto type = node->primitive->value.type;
if (type != schema::PrimitiveType_Nchw2Nhwc) {
continue;
}
std::vector<size_t> pre_nh2nc_nodes;
std::vector<size_t> pre_not_trans_nodes;
auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
return status;
}
std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end()));
std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(),
std::inserter(need_trans_format_nodes, need_trans_format_nodes.end()));
if (!pre_nh2nc_nodes.empty()) {
need_del_nodes.insert(iter - graph->nodes.begin());
}
}
if (need_del_nodes.empty()) {
return RET_OK;
}
for (auto del_node_index : need_del_nodes) {
auto node_name = graph->nodes.at(del_node_index)->name;
auto status = IsolateOneWayNode(graph, del_node_index);
if (status != RET_OK) {
MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status;
return status;
}
}
auto status = TransWeightToNhwc(graph, need_trans_format_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "trans weight to nhwc failed";
return status;
}
return RET_OK;
}
STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) {
if (pad_dims.size() != 4) {
MS_LOG(ERROR) << "pad dims error";
return RET_ERROR;
}
auto batch = pad_dims[NCHW_N];
auto channel = pad_dims[NCHW_C];
auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W];
auto size = batch * channel * area;
auto new_nhwc_data = new (std::nothrow) float[size];
if (new_nhwc_data == nullptr) {
MS_LOG(ERROR) << "create new nhwc data failed";
delete[] new_nhwc_data;
return RET_ERROR;
}
memset(new_nhwc_data, 0, sizeof(float) * size);
auto nchw_data = reinterpret_cast<float *>(tensor->data.data());
// nchw to nhwc
for (auto i = 0; i < batch; i++) {
float *src_batch = nchw_data + i * channel * area;
float *dst_batch = new_nhwc_data + i * channel * area;
for (int j = 0; j < area; ++j) {
float *src_area = src_batch + i;
float *dst_area = dst_batch + i * channel;
for (int k = 0; k < channel; ++k) {
dst_area[k] = src_area[k * area];
}
}
}
memcpy(nchw_data, new_nhwc_data, sizeof(float) * size);
delete[] new_nhwc_data;
return RET_OK;
}
STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) {
if (pre_not_trans_nodes.empty()) {
return RET_OK;
}
for (auto index : pre_not_trans_nodes) {
auto &cur_node = graph->nodes.at(index);
// need change axis from nchw to nhwc like concat,slice
auto ret = ChangeOpAxis(graph, cur_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ChangeOpAxis error";
return ret;
}
auto node_input_indexs = cur_node->inputIndex;
for (auto input_index : node_input_indexs) {
// weight data need trans nhwc layerout
if (!IsContain(graph->inputIndex, input_index) &&
graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) {
auto &weight_tensor = graph->allTensors.at(input_index);
auto origin_dims = weight_tensor->dims;
weight_tensor->format = Format_NHWC;
if (origin_dims.size() > 4) {
MS_LOG(ERROR) << "tensor origin tensor size error";
return RET_ERROR;
}
if (origin_dims.size() == 0) {
continue;
}
auto pad_dims = origin_dims;
if (origin_dims.size() == 1) {
pad_dims = {1, 1, 1, origin_dims[0]};
} else if (origin_dims.size() == 2) {
pad_dims = {1, 1, origin_dims[0], origin_dims[1]};
} else if (origin_dims.size() == 3) {
pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]};
}
if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) {
MS_LOG(ERROR) << "Convert nchw to nhwc failed";
return RET_ERROR;
}
weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]};
}
}
}
return RET_OK;
}
STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index,
std::vector<size_t> *pre_nh2nc_nodes,
std::vector<size_t> *pre_not_trans_nodes) {
MS_ASSERT(graph != nullptr);
std::vector<size_t> bfs_queue = {nc2nh_index};
// find pre node nh2nc start nodes
while (!bfs_queue.empty()) {
auto cur_node_index = bfs_queue.back();
auto &cur_node = graph->nodes.at(cur_node_index);
bfs_queue.pop_back();
auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node);
for (auto input_node_index : input_node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
auto node_type = pre_node->primitive->value.type;
if (node_type == schema::PrimitiveType_Nhwc2Nchw) {
if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
pre_nh2nc_nodes->emplace_back(input_node_index);
}
} else if (IsContain(GetInsertOpList(), node_type)) {
if (!IsContain(bfs_queue, input_node_index)) {
bfs_queue.emplace_back(input_node_index);
}
// todo multi output,other edge need insert nh2nc node
auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
if ((pre_node_output_indexs.size() != 1) && (node_type == schema::PrimitiveType_Activation)) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
} else {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) {
pre_not_trans_nodes->emplace_back(cur_node_index);
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H
#define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H
#include <unordered_map>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"
using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
class GlobalFormatTransformPass : public GraphPass {
public:
GlobalFormatTransformPass() = default;
~GlobalFormatTransformPass() = default;
STATUS Run(MetaGraphT *graph) override;
protected:
STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes);
STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector<size_t> *to_do_insert_nodes,
std::vector<size_t> *pre_not_trans_nodes);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H

@ -127,146 +127,6 @@ STATUS TransOpInsertPass::FindOutTransType() {
return RET_OK;
}
void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
MS_LOG(INFO) << "Attr data is from other nodes.";
return;
}
auto axis_map = GetNc2NhAxisMap();
std::vector<int> cur_attr;
for (int dim = 0; dim < 4; ++dim) {
for (int index = 0; index < element_size; ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
cur_attr.push_back(origin_attr[index]);
}
}
}
for (int index = 0; index < element_size; ++index) {
origin_attr[index] = cur_attr[index];
}
}
STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
if (node == nullptr && node->primitive == nullptr) {
MS_LOG(ERROR) << "node or primitive null";
return RET_NULL_PTR;
}
auto type = node->primitive->value.type;
if (type == PrimitiveType_StridedSlice) {
// onnx input size is equal to 5 always.
if (node->inputIndex.size() == 5) {
for (int index = 1; index < 5; ++index) {
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
MS_LOG(INFO) << "Here don't consider input is from other nodes.";
return RET_NOT_SUPPORT;
}
}
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
auto axes = graph->allTensors[node->inputIndex[3]]->data;
for (int index = 1; index < 5; ++index) {
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
reinterpret_cast<int *>(axes.data()), element_num);
}
}
}
if (type == PrimitiveType_Slice) {
auto attr = node->primitive->value.AsSlice();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
return RET_NULL_PTR;
}
// transform attr
attr->format = schema::Format_NHWC;
if (attr->begin.empty() || attr->size.empty()) {
MS_LOG(INFO) << "Here don't consider these attr are from other nodes.";
return RET_NOT_SUPPORT;
}
int element_num = attr->begin.size();
if (attr->axes.empty()) {
for (int index = 0; index < element_num; ++index) {
attr->axes.push_back(index);
}
}
TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
}
return RET_OK;
}
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
if (node == nullptr && node->primitive == nullptr) {
MS_LOG(ERROR) << "node or primitive null";
return RET_NULL_PTR;
}
MS_ASSERT(node->primitive->value != nullptr);
auto type = node->primitive->value.type;
auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
if (input1_ndim != 4) {
if (node->inputIndex.size() > 1) {
auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
if (input2_ndim != 4 && input2_ndim != 0) {
MS_LOG(ERROR) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
} else {
MS_LOG(ERROR) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
}
if (type == PrimitiveType_Concat) {
MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
auto origin_axis = node->primitive->value.AsConcat()->axis;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsConcat() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
}
if (type == PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsSplit() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
}
if (type == PrimitiveType_Crop) {
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
auto origin_axis = node->primitive->value.AsCrop()->axis;
auto offsets = node->primitive->value.AsCrop()->offsets;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsCrop() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
// nchw->nhwc,offsets need pad 0;
if (axis_map[origin_axis] == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
} else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
// orgin_axis = 2 or orgin_axis = 3
offsets.push_back(0);
} else if (axis_map[origin_axis] == -1) {
// origin_axis = 1
offsets = {offsets[1], offsets[2], offsets[0]};
} else {
// axis error
MS_LOG(ERROR) << "Crop error";
return RET_ERROR;
}
node->primitive->value.AsCrop()->offsets = offsets;
}
if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) {
return ChangeOpAttrForSlice(graph, node);
}
return RET_OK;
}
STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
bool changed = true;

@ -41,8 +41,6 @@ class TransOpInsertPass : public FormatTransPass {
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
private:
FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;

Loading…
Cancel
Save