From 920cbb1e22e16e0e2a788580019ff825aa2ebe67 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 21 Oct 2020 09:16:45 +0800 Subject: [PATCH] add identity pass and adjust log --- mindspore/lite/include/errorcode.h | 5 ++ mindspore/lite/src/CMakeLists.txt | 1 + mindspore/lite/src/errorcode.cc | 48 +++++++++++++++++ mindspore/lite/src/ops/identity.h | 32 ++++++++++++ mindspore/lite/src/ops/primitive_c.cc | 3 ++ mindspore/lite/test/CMakeLists.txt | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 2 + .../lite/tools/converter/anf_transform.cc | 6 +++ mindspore/lite/tools/converter/converter.cc | 8 ++- .../parser/onnx/onnx_model_parser.cc | 4 +- .../converter/parser/onnx/onnx_model_parser.h | 2 +- .../parser/onnx/onnx_reshape_parser.cc | 2 +- .../converter/quantizer/aware_quantizer.cc | 20 +++---- .../converter/quantizer/calc_quant_param.cc | 52 +++++++++---------- .../optimizer/graph/identity_remove_pass.cc | 48 +++++++++++++++++ .../optimizer/graph/identity_remove_pass.h | 36 +++++++++++++ 16 files changed, 230 insertions(+), 41 deletions(-) create mode 100644 mindspore/lite/src/errorcode.cc create mode 100644 mindspore/lite/src/ops/identity.h create mode 100644 mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/identity_remove_pass.h diff --git a/mindspore/lite/include/errorcode.h b/mindspore/lite/include/errorcode.h index 8d74c98385..72a51d7118 100644 --- a/mindspore/lite/include/errorcode.h +++ b/mindspore/lite/include/errorcode.h @@ -57,6 +57,11 @@ constexpr int RET_INFER_INVALID = -501; /**< Invalid infer shape before runtime. /* User input param error code, range: [-600, 700)*/ constexpr int RET_INPUT_PARAM_INVALID = -600; /**< Invalid input param by user. */ + +/// \brief Print description of errorcode. +/// +/// \param[in] error_code define return status of procedure. +void PrintErrorInfo(STATUS error_code); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index f418468b54..f6994dd21d 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -34,6 +34,7 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/model.cc + ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc ) if (SUPPORT_GPU) diff --git a/mindspore/lite/src/errorcode.cc b/mindspore/lite/src/errorcode.cc new file mode 100644 index 0000000000..e052af782d --- /dev/null +++ b/mindspore/lite/src/errorcode.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/errorcode.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +void PrintErrorInfo(STATUS status) { + std::map info_map = {{RET_OK, "No error occurs."}, + {RET_ERROR, "Common error code."}, + {RET_NULL_PTR, "NULL pointer returned."}, + {RET_PARAM_INVALID, "Invalid parameter."}, + {RET_NO_CHANGE, "No change."}, + {RET_SUCCESS_EXIT, "No error but exit."}, + {RET_MEMORY_FAILED, "Fail to create memory."}, + {RET_NOT_SUPPORT, "Fail to support."}, + {RET_OUT_OF_TENSOR_RANGE, "Failed to check range."}, + {RET_INPUT_TENSOR_ERROR, "Failed to check input tensor."}, + {RET_REENTRANT_ERROR, "Exist executor running."}, + {RET_GRAPH_FILE_ERR, "Failed to verify graph file."}, + {RET_NOT_FIND_OP, "Failed to find operator."}, + {RET_INVALID_OP_NAME, "Invalid operator name."}, + {RET_INVALID_OP_ATTR, "Invalid operator attr."}, + {RET_OP_EXECUTE_FAILURE, "Failed to execution operator."}, + {RET_FORMAT_ERR, "Failed to checking tensor format."}, + {RET_INFER_ERR, "Failed to infer shape."}, + {RET_INFER_INVALID, "Invalid infer shape before runtime."}, + {RET_INPUT_PARAM_INVALID, "Invalid input param by user."}}; + std::cout << info_map[status] << std::endl; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/identity.h b/mindspore/lite/src/ops/identity.h new file mode 100644 index 0000000000..b58083edbd --- /dev/null +++ b/mindspore/lite/src/ops/identity.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/primitive_c.h" + +#ifndef LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_ +#define LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_ + +namespace mindspore { +namespace lite { +class Identity : public PrimitiveC { + public: + MS_DECLARE_PARENT(Identity, PrimitiveC); + Identity() = default; + explicit Identity(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index de430d0bac..60187883c3 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -137,6 +137,7 @@ #include "src/ops/upsample.h" #include "src/ops/layer_norm.h" #include "src/ops/non_max_suppression.h" +#include "src/ops/identity.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -729,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new LayerNorm(primitive); case schema::PrimitiveType_NonMaxSuppression: return new NonMaxSuppression(primitive); + case schema::PrimitiveType_Identity: + return new Identity(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 1e4dda9b75..f424bea6d6 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -134,6 +134,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/tools/common/storage.cc ${LITE_DIR}/tools/benchmark/benchmark.cc ${LITE_DIR}/test/st/benchmark_test.cc + ${LITE_DIR}/src/errorcode.cc ) ### gpu runtime if (SUPPORT_GPU) @@ -184,6 +185,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc + ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc ) endif() ### train diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index a94e90ab8b..6e8597ba08 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -46,6 +46,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/weight_format_hardcode_pass.cc ../optimizer/graph/clip_convert_activation_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc + ../optimizer/graph/identity_remove_pass.cc ) add_subdirectory(../anf_importer anf_importer) @@ -75,6 +76,7 @@ set(LITE_SRC ${SRC_DIR}/executor.cc ${SRC_DIR}/model.cc ${SRC_DIR}/model_common.cc + ${SRC_DIR}/errorcode.cc ) if (SUPPORT_TRAIN) set(LITE_SRC diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index e2c352273b..93cdaaf47c 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -27,6 +27,7 @@ #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/fusion/layer_norm_fusion.h" #include "tools/optimizer/fusion/batchmatmul_fusion.h" +#include "tools/optimizer/graph/identity_remove_pass.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" @@ -53,6 +54,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver // for now - trainning is not supporting fuse operations if (config != nullptr && config->trainModel == false) { // remove quantdtype when awaretraining + if (config->fmk == lite::converter::FmkType_ONNX) { + auto remove_identity_pass = std::make_shared(); + remove_identity_pass->SetFmkType(config->fmk); + pm->AddPass(remove_identity_pass); + } if (config->quantType == QuantType_AwareTraining) { pm->AddPass(std::make_shared()); } diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 9e4b2a8ae6..2ab80994e1 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -109,6 +109,7 @@ int RunConverter(int argc, const char **argv) { if (flags == nullptr) { MS_LOG(ERROR) << "new flags error "; std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl; + PrintErrorInfo(RET_MEMORY_FAILED); return RET_MEMORY_FAILED; } auto status = flags->Init(argc, argv); @@ -117,6 +118,7 @@ int RunConverter(int argc, const char **argv) { MS_LOG(ERROR) << "converter::Flags Init failed: " << status; std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl; } + PrintErrorInfo(status); return status; } // Load graph @@ -148,6 +150,7 @@ int RunConverter(int argc, const char **argv) { default: { MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl; + PrintErrorInfo(RET_INPUT_PARAM_INVALID); return RET_INPUT_PARAM_INVALID; } } @@ -156,6 +159,7 @@ int RunConverter(int argc, const char **argv) { if (fb_graph == nullptr) { MS_LOG(ERROR) << "Convert model return nullptr"; std::cout << "CONVERT RESULT FAILED:" << status << std::endl; + PrintErrorInfo(status); return status; } @@ -163,15 +167,17 @@ int RunConverter(int argc, const char **argv) { Storage storage; fb_graph->version = Version(); status = storage.Save(*fb_graph, flags->outputFile); - if (status != 0) { + if (status != RET_OK) { MS_LOG(ERROR) << "Save graph to file failed"; std::cout << "SAVE GRAPH FAILED:" << status << std::endl; + PrintErrorInfo(status); return status; } delete fb_graph; MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; + PrintErrorInfo(status); return status; } } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 259fff072b..affba69b3c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -270,7 +270,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); interrupt = true; return RET_NOT_FIND_OP; - int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph); + int status = ParseSubgraph(dst_op, onnx_node, quantType, dst_graph); if (status != RET_OK || interrupt) { interrupt = true; return status; @@ -496,7 +496,7 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) } } -STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, +STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, schema::MetaGraphT *dst_graph) { MS_LOG(DEBUG) << "onnx LoopParser"; if (dst_op == nullptr) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index d5f3b95b97..1fa17fc544 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -88,7 +88,7 @@ class OnnxModelParser : public ModelParser { void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); - STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, + STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, schema::MetaGraphT *dst_graph); private: diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index f9fe53782e..81e9142732 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -61,7 +61,7 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: } } if (input_shape.int64_data_size() == 0) { - MS_LOG(WARNING) << "shape maybe from another op other than const initializer"; + MS_LOG(INFO) << "shape maybe from another op other than const initializer"; } else { for (int i = 0; i < input_shape.int64_data_size(); ++i) { shape.push_back(input_shape.int64_data(i)); diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 9f3cd323bc..208d84232a 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -70,13 +70,13 @@ STATUS AwareQuantizer::GenerateQuantParam() { } auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); if (quantParamCalcer == nullptr) { - MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() - << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; + MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str() + << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; node->quantType = static_cast(QuantType_QUANT_NONE); } else { auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { - MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { node->quantType = schema::QuantType_AwareTraining; @@ -103,7 +103,7 @@ STATUS AwareQuantizer::DoQuantize() { GetCNodeTType(*node) == schema::PrimitiveType_MatMul) { auto inputIndexes = node->inputIndex; if (inputIndexes.size() < 2) { - MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count"; + MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; return RET_ERROR; } // quant weight @@ -111,7 +111,7 @@ STATUS AwareQuantizer::DoQuantize() { if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { status = QuantConvWeight(graph, node.get()); if (status != RET_OK) { - MS_LOG(WARNING) << "QuantConvWeight failed!"; + MS_LOG(ERROR) << "QuantConvWeight failed!"; return RET_ERROR; } } @@ -121,7 +121,7 @@ STATUS AwareQuantizer::DoQuantize() { if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { status = QuantConvBias(graph, node.get()); if (status != RET_OK) { - MS_LOG(WARNING) << "QuantConvBias failed!"; + MS_LOG(ERROR) << "QuantConvBias failed!"; return RET_ERROR; } } @@ -129,7 +129,7 @@ STATUS AwareQuantizer::DoQuantize() { } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { status = QuantDetectionPostProcessConstTensor(graph, node.get()); if (status != RET_OK) { - MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!"; + MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; return RET_ERROR; } } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add || @@ -137,7 +137,7 @@ STATUS AwareQuantizer::DoQuantize() { GetCNodeTType(*node) == schema::PrimitiveType_Mul) { status = QuantArithmeticConstTensor(graph, node.get()); if (status != RET_OK) { - MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!"; + MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!"; return RET_ERROR; } } @@ -168,7 +168,7 @@ STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *grap } if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat && inTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8"; + MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8"; return RET_ERROR; } @@ -303,7 +303,7 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem } if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; + MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; return RET_ERROR; } size_t wShapeSize = GetShapeSize(*(weightTensor.get())); diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index 44c7f2acc2..79afd5f5d2 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -33,7 +33,7 @@ STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, Q return RET_OK; } if (tensor.dataType != TypeId::kNumberTypeFloat) { - MS_LOG(WARNING) << "Const Tensor without quantParam should has float dataType, in fact: " << tensor.dataType; + MS_LOG(ERROR) << "Const Tensor without quantParam should has float dataType, in fact: " << tensor.dataType; return RET_ERROR; } const auto *constData = reinterpret_cast(tensor.data.data()); @@ -83,7 +83,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) { auto status = ComputeConstQuantParam((*tensor), quantParam.get()); if (status != RET_OK) { - MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status; + MS_LOG(INFO) << "ComputeConstQuantParam failed: " << status; return status; } tensor->quantParams.front() = std::move(quantParam); @@ -112,15 +112,15 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { auto status = QuantParamCalcer::Calc(subGraph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status; return status; } if (inputParamDone != node.inputIndex.size()) { - MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name; return RET_ERROR; } if (outputParamDone != node.outputIndex.size()) { - MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name; return RET_ERROR; } return RET_OK; @@ -129,7 +129,7 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status; return status; } if (inputParamDone != node.inputIndex.size()) { @@ -139,7 +139,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto outputQuantParam = GetTensorQuantParam(outTensor); MS_ASSERT(outputQuantParam != nullptr); if (outputQuantParam == nullptr || !outputQuantParam->inited) { - MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; + MS_LOG(ERROR) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; return RET_ERROR; } for (unsigned int i : node.inputIndex) { @@ -159,7 +159,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { MS_ASSERT(inTensor != nullptr); auto inQuantParam = GetTensorQuantParam(inTensor); if (inQuantParam == nullptr || !inQuantParam->inited) { - MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name; + MS_LOG(ERROR) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name; return RET_ERROR; } for (size_t i = 0; i < node.outputIndex.size(); i++) { @@ -188,12 +188,12 @@ class CalcConcat : public QuantParamCalcer { MS_ASSERT(node.outputIndex.size() == 1); auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status; return status; } if (inputParamDone != node.inputIndex.size()) { - MS_LOG(WARNING) << "Can not determine concat inputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine concat inputTensor quantParam, node " << node.name; return RET_ERROR; } @@ -233,7 +233,7 @@ class CalcConcat : public QuantParamCalcer { status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits); if (status != RET_OK) { - MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; + MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } outTensor->quantParams.emplace_back(std::move(outQuantParam)); @@ -253,12 +253,12 @@ class CalcAdd : public QuantParamCalcer { MS_ASSERT(node.outputIndex.size() == 1); auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status; return status; } if (inputParamDone != 2) { - MS_LOG(WARNING) << "Can not determine add inputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine add inputTensor quantParam, node " << node.name; return RET_ERROR; } if (outputParamDone != 1) { @@ -283,7 +283,7 @@ class CalcAdd : public QuantParamCalcer { biasTensor = &tensor1; paramTensor = &tensor0; } else { - MS_LOG(WARNING) << "Can not determine add outputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine add outputTensor quantParam, node " << node.name; return RET_ERROR; } auto quantParam = GetTensorQuantParam(*paramTensor); @@ -298,7 +298,7 @@ class CalcAdd : public QuantParamCalcer { auto *bias = static_cast(oriTensorData); status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias)); if (status != RET_OK) { - MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; + MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } } else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) { @@ -307,11 +307,11 @@ class CalcAdd : public QuantParamCalcer { auto *bias = static_cast(oriTensorData); status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias)); if (status != RET_OK) { - MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; + MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } } else { - MS_LOG(WARNING) << "Unsupported tensor dataType: " << (*biasTensor)->dataType; + MS_LOG(ERROR) << "Unsupported tensor dataType: " << (*biasTensor)->dataType; return RET_ERROR; } } @@ -330,12 +330,12 @@ class CalcRealDiv : public QuantParamCalcer { MS_ASSERT(node.outputIndex.size() == 1); auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status; return status; } if (inputParamDone != 2) { - MS_LOG(WARNING) << "Can not determine realdiv inputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine realdiv inputTensor quantParam, node " << node.name; return RET_ERROR; } if (outputParamDone != 1) { @@ -361,7 +361,7 @@ class CalcRealDiv : public QuantParamCalcer { MS_ASSERT(*div != 0); status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div)); if (status != RET_OK) { - MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; + MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } } else if (tensor1->dataType == TypeId::kNumberTypeUInt8) { @@ -370,17 +370,17 @@ class CalcRealDiv : public QuantParamCalcer { auto *div = static_cast(oriTensorData); status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max + (*div)); if (status != RET_OK) { - MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; + MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } } else { - MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType; + MS_LOG(ERROR) << "Unsupported tensor dataType: " << tensor1->dataType; return RET_ERROR; } outTensor->quantParams.front() = std::move(outQuantParam); } } else { - MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine realDiv outputTensor quantParam, node " << node.name; return RET_ERROR; } } @@ -397,19 +397,19 @@ class CalcToSet : public QuantParamCalcer { MS_ASSERT(node.outputIndex.size() == 1); auto status = QuantParamCalcer::Calc(graph, node); if (status != RET_OK) { - MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: %d" << status; + MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: %d" << status; return status; } // input if (inputParamDone != node.inputIndex.size()) { - MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name; + MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name; return RET_ERROR; } // output if (outputParamDone != node.outputIndex.size()) { std::unique_ptr quantParam = std::make_unique(); if (quantParam == nullptr) { - MS_LOG(WARNING) << "new QuantParamT failed"; + MS_LOG(ERROR) << "new QuantParamT failed"; return RET_ERROR; } quantParam->scale = (max - min) / 256; diff --git a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc new file mode 100644 index 0000000000..43be64df65 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/identity_remove_pass.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { +bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { + if (this->fmk_type != lite::converter::FmkType_ONNX) { + MS_LOG(INFO) << "The framework type of model should be onnx."; + return RET_OK; + } + MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto type = opt::GetCNodeType(node); + if (type != schema::PrimitiveType_Identity) { + continue; + } + auto identity_cnode = node->cast(); + if (identity_cnode->inputs().size() != lite::kDoubleNum) { + MS_LOG(ERROR) << "The `node input is a single tensor"; + return RET_ERROR; + } + manager->Replace(node, identity_cnode->input(1)); + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.h b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.h new file mode 100644 index 0000000000..4a0ac14a65 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/identity_remove_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class RemoveIdentityOpPass : public Pass { + public: + RemoveIdentityOpPass() : Pass("remove_identity_pass") {} + ~RemoveIdentityOpPass() override = default; + void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } + bool Run(const FuncGraphPtr &graph) override; + + private: + FmkType fmk_type = lite::converter::FmkType_ONNX; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_