From d73ab480a1accbc69d50ed232d87296e70fb17aa Mon Sep 17 00:00:00 2001 From: zengxianglong Date: Wed, 25 Nov 2020 20:33:26 +0800 Subject: [PATCH] add new models to the entrance guard and fix some bugs --- mindspore/lite/src/ops/conv2d.cc | 3 + mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/test/models_mindspore.cfg | 1 + mindspore/lite/test/models_onnx_fp16.cfg | 1 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 4 +- .../graph/global_format_transform_pass.cc | 4 +- .../graph/update_conv2d_param_pass.cc | 64 +++++++++++++++++++ .../graph/update_conv2d_param_pass.h | 31 +++++++++ 9 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 2e673b1aad..7c8d28eb3a 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -199,6 +199,9 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT attr->channelIn = dims[kAnfPopulaterInputNumOne]; } } + } else if (input_node->isa()) { + // The weight of the convolution is the output of other operators, which could be folded by the const-folding pass. 
+ attr->channelIn = -1; } primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 23ee1bb461..b30cfaa49b 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -196,6 +196,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc + ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc diff --git a/mindspore/lite/test/models_mindspore.cfg b/mindspore/lite/test/models_mindspore.cfg index e2d0510c1e..cca5f36835 100644 --- a/mindspore/lite/test/models_mindspore.cfg +++ b/mindspore/lite/test/models_mindspore.cfg @@ -11,3 +11,4 @@ resnext50.mindir 1.5 ocr_mobilenetV2.mindir 1.5 mobilenet_quant.mindir 5 mindspore_ghostnet_ssd_13x.mindir 1.5 +mindspore_ghost-nose-pets-811.mindir 0.5 diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg index dec09e4812..fdcaf39c70 100644 --- a/mindspore/lite/test/models_onnx_fp16.cfg +++ b/mindspore/lite/test/models_onnx_fp16.cfg @@ -32,3 +32,4 @@ pointilism-9.onnx 3 rain-princess-9.onnx 5 udnie-9.onnx 3 adversarial_pruning.onnx 3 +residual_distill_res34_cifar10_bs_1_update.onnx 2 diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 4afe5d22db..00173c3fec 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -51,6 +51,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/clip_convert_activation_pass.cc 
../optimizer/graph/group_depthwise_op_convert_pass.cc ../optimizer/graph/tflite_inputs_order_exchange_pass.cc + ../optimizer/graph/update_conv2d_param_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc ../optimizer/graph/unused_transpose_node_remove_pass.cc ../optimizer/graph/identity_remove_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index c91e22eb44..28332eca12 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -35,6 +35,7 @@ #include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" +#include "tools/optimizer/graph/update_conv2d_param_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" #include "tools/optimizer/graph/infershape_pass.h" @@ -63,6 +64,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver // fusion const_fold auto cf_pm = std::make_shared("constant folding pass manager", false); cf_pm->AddPass(std::make_shared()); + cf_pm->AddPass(std::make_shared()); // for now - trainning is not supporting fuse operations if (!config->trainModel) { @@ -78,11 +80,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver schema::ActivationType_RELU)); pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6)); - pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared( true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU)); pm->AddPass(std::make_shared( true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6)); + pm->AddPass(std::make_shared()); } auto weight_format_hardcode_pass = std::make_shared(); 
weight_format_hardcode_pass->SetFmkType(config->fmk); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc index ddf69758e5..b78738f1c7 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc @@ -112,6 +112,7 @@ STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector &pad_dims) { delete[] new_nhwc_data; return RET_OK; } + STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set &pre_not_trans_nodes) { MS_ASSERT(graph != nullptr); if (pre_not_trans_nodes.empty()) { @@ -185,7 +186,8 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc } // todo multi output,other edge need insert nh2nc node auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); - if ((pre_node_output_indexs.size() != 1) && (node_type == schema::PrimitiveType_Activation)) { + if ((pre_node_output_indexs.size() != 1) && + (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat)) { pre_nh2nc_nodes->clear(); pre_not_trans_nodes->clear(); return RET_OK; diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc new file mode 100644 index 0000000000..95f247ea94 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/update_conv2d_param_pass.h" +#include +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { +bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + int status = RET_OK; + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto type = opt::GetCNodeType(node); + if (type == schema::PrimitiveType_DepthwiseConv2D) { + auto dwconv2d_cnode = node->cast(); + auto primitive_c = GetValueNode>(dwconv2d_cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveC."; + return RET_ERROR; + } + auto primT = primitive_c->primitiveT(); + if (primT == nullptr) { + MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveT."; + return RET_ERROR; + } + int channel_in = primT->value.AsDepthwiseConv2D()->channelIn; + if (channel_in == -1) { + auto input_node = node->cast()->input(lite::kAnfPopulaterInputNumTwo); + MS_ASSERT(input_node != nullptr); + if (input_node->isa()) { + auto param_node = input_node->cast(); + auto param = param_node->default_param(); + auto weight = std::dynamic_pointer_cast(param); + primT->value.AsDepthwiseConv2D()->channelIn = weight->tensor_shape().at(0); + } + } + } + if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { + MS_LOG(ERROR) << "remove identity pass is failed."; + return false; + } 
+ } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h new file mode 100644 index 0000000000..c6958d8062 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ +#include "schema/inner/model_generated.h" +#include "backend/optimizer/common/pass.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore::opt { +class UpdateConv2DParamPass : public Pass { + public: + UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {} + ~UpdateConv2DParamPass() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_