From 12eabf9e16da5476f75c9b9ed86f7f20666faada Mon Sep 17 00:00:00 2001 From: yankai Date: Fri, 30 Oct 2020 09:14:03 +0800 Subject: [PATCH] fix int8transpose parser --- mindspore/lite/src/ops/primitive_c.cc | 4 + mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/tools/common/flag_parser.cc | 2 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 14 +++ .../parser/onnx/onnx_transpose_parser.cc | 1 + .../unused_transpose_node_remove_pass.cc | 90 +++++++++++++++++++ .../graph/unused_transpose_node_remove_pass.h | 36 ++++++++ 8 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 5a1048e0bd..71f3235e48 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -824,6 +824,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new InstanceNorm(primitive); case schema::PrimitiveType_While: return new While(primitive); + case schema::PrimitiveType_OnnxInt8Quantize: + return new Quant(primitive); + case schema::PrimitiveType_OnnxInt8Dequantize: + return new Dequant(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 5f1772fe6e..4e692ae1d8 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -196,6 +196,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc + ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc ) endif() diff --git a/mindspore/lite/tools/common/flag_parser.cc b/mindspore/lite/tools/common/flag_parser.cc index 76bf598975..053261106a 100644 --- a/mindspore/lite/tools/common/flag_parser.cc +++ b/mindspore/lite/tools/common/flag_parser.cc @@ -152,7 +152,7 @@ std::string FlagParser::Usage(const Option &usgMsg) const { // first line, brief of the usage std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; // usage of bin name - usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; + usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; // help line of help message, usageLine:message of parametors std::string helpLine = ""; std::string usageLine = ""; diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index b33bfd6f5b..624fcb8b57 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -47,6 +47,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/weight_format_hardcode_pass.cc ../optimizer/graph/clip_convert_activation_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc + ../optimizer/graph/unused_transpose_node_remove_pass.cc ../optimizer/graph/identity_remove_pass.cc ) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 5c0cf91b0c..0ce7aff7fd 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -33,6 +33,7 @@ #include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" +#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -90,9 +91,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver if (config->fmk == lite::converter::FmkType_MS) { auto remove_unused_cast_pass = std::make_shared(); + if (remove_unused_cast_pass == nullptr) { + MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified"; + return nullptr; + } remove_unused_cast_pass->SetFmkType(config->fmk); pm->AddPass(remove_unused_cast_pass); } + if (config->fmk == lite::converter::FmkType_ONNX) { + auto remove_unused_transpose_pass = std::make_shared(); + if (remove_unused_transpose_pass == nullptr) { + MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified"; + return nullptr; + } + remove_unused_transpose_pass->SetFmkType(config->fmk); + pm->AddPass(remove_unused_transpose_pass); + } pm->AddPass(std::make_shared()); convert_pm->AddPass(std::make_shared()); optimizer->AddPassManager(convert_pm); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc index 55c3600c22..2430d5a9cb 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -61,5 +61,6 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx } OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); +OnnxNodeRegistrar g_onnxInt8TransposeParser("Int8Transpose", new OnnxTransposeParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc new file mode 100644 index 0000000000..72baed7f75 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" +#include +#include +#include "tools/optimizer/common/gllo_utils.h" +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { +static constexpr size_t kTransposeInput = 1; +const std::vector kPermNCHW{0, 3, 1, 2}; +const std::vector kPermNHWC{0, 2, 3, 1}; +void RemoveUnusedTransposeOpPass::SetFmkType(FmkType type) { this->fmk_type = type; } + +bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { + if (this->fmk_type != lite::converter::FmkType_ONNX) { + MS_LOG(ERROR) << "The framework type of model should be onnx."; + return RET_ERROR; + } + MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto type = opt::GetCNodeType(node); + if (type == schema::PrimitiveType_Transpose) { + auto transpose_cnode = node->cast(); + auto typeInput = opt::GetCNodeType(transpose_cnode->input(kTransposeInput)); + if (typeInput != schema::PrimitiveType_Conv2D) { + continue; + } + auto primPtr = GetValueNode>(transpose_cnode->input(0)); + if (primPtr == nullptr) { + MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; + return RET_ERROR; + } + auto primT = primPtr->GetPrimitiveT(); + if (primT == nullptr) { + MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; + return RET_ERROR; + } + std::vector perm = primT->value.AsTranspose()->perm; + if (perm == kPermNCHW) { + manager->Replace(transpose_cnode, transpose_cnode->input(1)); + } + } else if (type == schema::PrimitiveType_Conv2D) { + auto conv_node = node->cast(); + auto typeInput = opt::GetCNodeType(conv_node->input(kTransposeInput)); + if (typeInput != schema::PrimitiveType_Transpose) { + continue; + } + auto transpose_cnode = conv_node->input(kTransposeInput)->cast(); + auto primPtr = GetValueNode>(transpose_cnode->input(0)); + if (primPtr == nullptr) { + MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; + return RET_ERROR; + } + auto primT = primPtr->GetPrimitiveT(); + if (primT == nullptr) { + MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT"; + return RET_ERROR; + } + std::vector perm = primT->value.AsTranspose()->perm; + if (perm == kPermNHWC) { + manager->Replace(transpose_cnode, transpose_cnode->input(1)); + } + } else { + continue; + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h new file mode 100644 index 0000000000..7353b78a70 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_ +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class RemoveUnusedTransposeOpPass : public Pass { + public: + RemoveUnusedTransposeOpPass() : Pass("remove_unused_cast_pass") {} + ~RemoveUnusedTransposeOpPass() override = default; + void SetFmkType(FmkType fmkType); + bool Run(const FuncGraphPtr &graph) override; + + private: + FmkType fmk_type = lite::converter::FmkType_TF; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_