From f0ec27b1069724eee96cbb63b062ae41f4eadb0e Mon Sep 17 00:00:00 2001 From: wangzhe Date: Fri, 13 Nov 2020 11:37:48 +0800 Subject: [PATCH] add infershape pass & add slice_prepose pass (Softmax) --- mindspore/lite/test/CMakeLists.txt | 3 + mindspore/lite/tools/converter/CMakeLists.txt | 3 + .../lite/tools/converter/anf_transform.cc | 10 + .../lite/tools/optimizer/common/gllo_utils.cc | 49 +++ .../lite/tools/optimizer/common/gllo_utils.h | 2 + .../fusion/conv_tuplegetitem_fusion.cc | 79 +++++ .../fusion/conv_tuplegetitem_fusion.h | 31 ++ .../optimizer/fusion/layer_norm_fusion.cc | 1 + .../tools/optimizer/graph/infershape_pass.cc | 306 ++++++++++++++++++ .../tools/optimizer/graph/infershape_pass.h | 48 +++ .../optimizer/graph/slice_prepose_pass.cc | 247 ++++++++++++++ .../optimizer/graph/slice_prepose_pass.h | 54 ++++ 12 files changed, 833 insertions(+) create mode 100644 mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.h create mode 100644 mindspore/lite/tools/optimizer/graph/infershape_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/infershape_pass.h create mode 100644 mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index c88a02e633..723341fc1a 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -191,12 +191,15 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc + ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc + ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc ) endif() ### train diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 6b881a27ea..a39fb35b9c 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -39,6 +39,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/conv_transform_fusion.cc ../optimizer/fusion/conv_scale_fusion.cc ../optimizer/fusion/conv_bn_fusion.cc + ../optimizer/fusion/conv_tuplegetitem_fusion.cc ../optimizer/fusion/constant_folding_fusion.cc ../optimizer/fusion/quant_dtype_cast_fusion.cc ../optimizer/fusion/layer_norm_fusion.cc @@ -51,6 +52,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/unused_cast_node_remove_pass.cc ../optimizer/graph/unused_transpose_node_remove_pass.cc ../optimizer/graph/identity_remove_pass.cc + ../optimizer/graph/infershape_pass.cc + ../optimizer/graph/slice_prepose_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 226dd0a017..d0be80a815 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -23,6 +23,7 @@ #include "tools/optimizer/fusion/conv_tuple_activation_fusion.h" #include "tools/optimizer/fusion/conv_scale_fusion.h" #include "tools/optimizer/fusion/conv_bn_fusion.h" +#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/fusion/layer_norm_fusion.h" @@ -35,6 +36,8 @@ #include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" +#include "tools/optimizer/graph/infershape_pass.h" +#include "tools/optimizer/graph/slice_prepose_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -76,6 +79,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver schema::ActivationType_RELU)); pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6)); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared( true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU)); pm->AddPass(std::make_shared( @@ -89,6 +93,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver weight_format_transform_pass->SetFmkType(config->fmk); weight_format_transform_pass->SetQuantType(config->quantType); graph_pm->AddPass(weight_format_transform_pass); + auto infershape_pass = std::make_shared(); + infershape_pass->SetFmkType(config->fmk); + graph_pm->AddPass(infershape_pass); + auto slice_prepose_pass = std::make_shared(); + slice_prepose_pass->SetFmkType(config->fmk); + graph_pm->AddPass(slice_prepose_pass); if (config->fmk == lite::converter::FmkType_MS) { auto remove_unused_cast_pass = std::make_shared(); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 93b5fc0583..2872f18f36 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -406,6 +406,55 @@ ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) { auto param_value = std::dynamic_pointer_cast(param->default_param()); return param_value; } + +AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) { + if (cnode == nullptr) { + MS_LOG(ERROR) << "CNodePtr is nullptr"; + return nullptr; + } + auto inputs = cnode->inputs(); + if (!(0 < index && index < inputs.size())) { + return nullptr; + } + auto input = inputs[index]; + if (input == nullptr) { + MS_LOG(ERROR) << "CNode input is nullptr"; + return nullptr; + } + + AbstractBasePtr abstract = nullptr; + if (utils::isa(input)) { + auto parameter = input->cast(); + abstract = parameter->abstract(); + } else if (utils::isa(input)) { + auto input_cnode = input->cast(); + if (GetCNodeType(input_cnode) == schema::PrimitiveType_TupleGetItem) { + auto tuple_inputs = input_cnode->inputs(); + MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize); + auto get_item_input_cnode = tuple_inputs.at(1); + MS_ASSERT(get_item_input_cnode != nullptr); + auto idx = GetTupleGetItemOutIndex(input_cnode); + if (!utils::isa(get_item_input_cnode->abstract())) { + MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple"; + return nullptr; + } + auto abstract_tuple = utils::cast(get_item_input_cnode->abstract()); + auto abstract_list = abstract_tuple->elements(); + if (abstract_list.size() <= idx) { + MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect"; + return nullptr; + } + abstract = abstract_list[idx]; + } else { + abstract = input_cnode->abstract(); + } + } else { + MS_LOG(ERROR) << "unsupported input node type"; + return nullptr; + } + return abstract; +} + bool IsParamNode(const BaseRef &n) { if (!utils::isa(n)) { return false; diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 31194c42cb..893a5b48fa 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -75,6 +75,8 @@ size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); +AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index); + enum kTransFilterType { kKCHW2HWCK, // 0 kKCHW2KHWC, diff --git a/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc new file mode 100644 index 0000000000..0f81b00946 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" +#include +#include "src/ops/primitive_c.h" +#include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "securec/include/securec.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kTupleGetItemLen = 3; +bool IsTupleGetItemNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_TupleGetItem; + } + return false; +} +} // namespace + +const BaseRef ConvTupleGetItemFusion::DefinePattern() const { + auto tuple_var = std::make_shared(IsTupleGetItemNode); + auto tuple_index = std::make_shared(); + auto conv_var = std::make_shared(IsConvNode); + return VectorRef({tuple_var, conv_var, tuple_index}); +} + +const AnfNodePtr ConvTupleGetItemFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_LOG(DEBUG) << "conv_tuplegetitem_fusion pass"; + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + auto tuple_cnode = node->cast(); + if (CheckIfCNodeIsNull(tuple_cnode) != lite::RET_OK || + CheckInputSize(tuple_cnode, kTupleGetItemLen) != lite::RET_OK) { + return nullptr; + } + auto idx = GetTupleGetItemOutIndex(tuple_cnode); + if (idx != 0) { + MS_LOG(DEBUG) << "TupleGetItem's idx is not 0"; + return nullptr; + } + auto conv_node = tuple_cnode->input(1); + if (CheckIfAnfNodeIsNull(conv_node) != lite::RET_OK) { + return nullptr; + } + auto conv_cnode = conv_node->cast(); + if (CheckIfCNodeIsNull(conv_cnode) != lite::RET_OK) { + return nullptr; + } + auto abstr = conv_cnode->abstract(); + if (utils::isa(abstr)) { + auto elements = utils::cast(abstr)->elements(); + if (elements.empty()) { + MS_LOG(ERROR) << "AbstractTuple is empty"; + return nullptr; + } + conv_node->set_abstract(elements[0]); + } + return conv_node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.h new file mode 100644 index 0000000000..2d04aa90f2 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_ +#define LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_ +#include +#include "backend/optimizer/common/optimizer.h" +namespace mindspore::opt { +class ConvTupleGetItemFusion : public PatternProcessPass { + public: + explicit ConvTupleGetItemFusion(const std::string &name = "conv_tuplegetitem_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph) {} + ~ConvTupleGetItemFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace mindspore::opt + +#endif // LITE_MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc index 8ddd737303..8db4c793f6 100644 --- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc @@ -324,6 +324,7 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const } auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon); + layer_norm_cnode->set_abstract(add2_cnode->abstract()); layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; return layer_norm_cnode; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc new file mode 100644 index 0000000000..c794e3a364 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -0,0 +1,306 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/infershape_pass.h" +#include +#include +#include +#include "mindspore/lite/include/errorcode.h" +#include "mindspore/lite/src/ops/primitive_c.h" +#include "tools/anf_importer/import_from_meta_graphT.h" + +namespace mindspore::opt { +abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { + MS_ASSERT(nullptr != tensor); + std::vector shape(tensor->shape()); + auto type_id = static_cast(tensor->data_type()); + auto type_ptr = TypeIdToType(type_id); + std::vector shape_vector; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto new_abstract = std::make_shared(type_ptr, shape_vector); + if (new_abstract == nullptr) { + MS_LOG(ERROR) << "new AbstractTensor failed"; + return nullptr; + } + auto new_value = std::make_shared(); + if (new_value == nullptr) { + MS_LOG(ERROR) << "new ParamValueLite failed"; + return nullptr; + } + new_value->set_tensor_shape(tensor->shape()); + new_value->set_tensor_type(tensor->data_type()); + new_value->set_format(tensor->GetFormat()); + new_abstract->set_value(new_value); + return new_abstract; +} + +STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { + MS_ASSERT(parameter != nullptr); + auto old_abstract = parameter->abstract(); + if (old_abstract == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name(); + return RET_ERROR; + } + if (!utils::isa(old_abstract)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name(); + return RET_ERROR; + } + auto abstract_tensor = utils::cast(old_abstract); + + auto typePtr = abstract_tensor->element()->GetTypeTrack(); + if (typePtr == nullptr) { + MS_LOG(ERROR) << "typePtr is nullptr"; + return RET_ERROR; + } + + if (!utils::isa(abstract_tensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << parameter->name(); + return RET_ERROR; + } + auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); + std::vector shape; + (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), + [](const int64_t &value) { return static_cast(value); }); + + auto new_abstract = std::make_shared(typePtr, shape_vector); + auto new_value = std::make_shared(); + new_value->set_tensor_shape(shape); // scalar's shape is {} + new_value->set_tensor_type(typePtr->type_id()); + new_value->set_format(schema::Format_NHWC); // default format is NHWC + if (parameter->has_default()) { + auto param_value = std::dynamic_pointer_cast(parameter->default_param()); + new_value->set_format(param_value->format()); + new_value->set_tensor_size(param_value->tensor_size()); + + char *tensor_data = new (std::nothrow) char[new_value->tensor_size()]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_ERROR; + } + auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + new_value->set_tensor_addr(tensor_data); + } + new_abstract->set_value(new_value); + parameter->set_abstract(new_abstract); + return RET_OK; +} + +void InferShapePass::FreeTensors(std::vector *tensors) { + for (auto tensor : *tensors) { + delete tensor; + } + tensors->clear(); + tensors->shrink_to_fit(); +} + +STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector *input_tensors) { + MS_ASSERT(cnode != nullptr); + MS_ASSERT(input_tensors != nullptr); + auto inputs = cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto input = inputs[i]; + if (input == nullptr) { + MS_LOG(ERROR) << "input is nullptr"; + return RET_ERROR; + } + auto tensor = std::make_unique(); + if (tensor == nullptr) { + MS_LOG(ERROR) << "new input tensor failed"; + return RET_ERROR; + } + + if (utils::isa(cnode->input(i))) { + MS_LOG(ERROR) << "input is value node"; + continue; + } + + AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Abstract of CNode is nullptr"; + return RET_ERROR; + } + if (!utils::isa(abstract)) { + MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor"; + return RET_ERROR; + } + auto abstract_tensor = utils::cast(abstract); + if (!utils::isa(abstract_tensor->GetValueTrack())) { // input node not complete infershape + MS_LOG(DEBUG) << "Value of abstract is not ParamValueLite, indicate that infershape has failed"; + return RET_ERROR; + } + auto param_value_lite = utils::cast(abstract_tensor->GetValueTrack()); + if (param_value_lite == nullptr) { + MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr"; + return RET_ERROR; + } + tensor->set_shape(param_value_lite->tensor_shape()); + tensor->set_data_type(param_value_lite->tensor_type()); + tensor->SetFormat(schema::Format(param_value_lite->format())); + + if (utils::isa(input)) { + auto parameter = input->cast(); + if (parameter->has_default()) { + auto param_value = std::dynamic_pointer_cast(parameter->default_param()); + auto ret = tensor->MallocData(); + if (ret != 0) { + MS_LOG(ERROR) << "Malloc tensor data failed"; + return RET_ERROR; + } + ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + } + } + input_tensors->push_back(tensor.release()); + } + return RET_OK; +} + +STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors) { + MS_ASSERT(output_tensors != nullptr); + auto abstract = cnode->abstract(); + if (abstract == nullptr) { + MS_LOG(ERROR) << "abstract is nullptr"; + return RET_ERROR; + } + size_t num_outputs = 1; + if (utils::isa(abstract)) { + auto abstract_tuple = abstract->cast(); + num_outputs = abstract_tuple->size(); + } + for (size_t i = 0; i < num_outputs; ++i) { + auto output_tensor = std::make_unique(); + if (output_tensor == nullptr) { + MS_LOG(ERROR) << "new output tensor failed"; + return RET_ERROR; + } + output_tensors->push_back(output_tensor.release()); + } + return RET_OK; +} + +STATUS InferShapePass::SetCNodeAbstract(const std::vector &output_tensors, + const std::shared_ptr &cnode) { + MS_ASSERT(cnode != nullptr); + if (output_tensors.size() == 0) { + MS_LOG(ERROR) << "empty output_tensors"; + return RET_ERROR; + } + if (output_tensors.size() == 1) { + auto tensor = output_tensors.front(); + auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); + if (new_abstract == nullptr) { + return RET_ERROR; + } + cnode->set_abstract(new_abstract); + } else { + AbstractBasePtrList abstract_list; + for (size_t i = 0; i < output_tensors.size(); i++) { + auto tensor = output_tensors.front(); + auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); + if (new_abstract == nullptr) { + return RET_ERROR; + } + abstract_list.emplace_back(new_abstract); + } + cnode->set_abstract(std::make_shared(abstract_list)); + } + return RET_OK; +} + +bool InferShapePass::Run(const FuncGraphPtr &func_graph) { + if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { + MS_LOG(INFO) << "The framework type of model should be tf/tflite."; + return false; + } + MS_ASSERT(func_graph != nullptr); + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (utils::isa(node)) { + int status = SetParameterAbstract(node->cast()); + if (status != RET_OK) { + return false; + } + continue; + } + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + auto origin_primc = GetValueNode>(cnode->input(0)); + if (origin_primc == nullptr) { + MS_LOG(ERROR) << "origin_primc is nullptr"; + return false; + } + auto origin_primt = origin_primc->GetPrimitiveT(); + if (origin_primt == nullptr) { + MS_LOG(ERROR) << "origin_primt is nullptr"; + return false; + } + auto type = GetCNodeType(cnode); + if ((type == schema::PrimitiveType_TupleGetItem) || +#ifdef SUPPORT_TRAIN + (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || +#endif + (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { + continue; + } + std::vector input_tensors; + std::vector output_tensors; + auto status = GetCNodeInputTensors(cnode, &input_tensors); + if (status != RET_OK) { + MS_LOG(DEBUG) << "input shape unknown, infershape can't process cnode " << cnode->fullname_with_scope(); + FreeTensors(&input_tensors); + continue; + } + status = GetCNodeOutputTensors(cnode, &output_tensors); + if (status != RET_OK) { + FreeTensors(&input_tensors); + FreeTensors(&output_tensors); + continue; + } + auto primt = std::make_unique(); + if (primt == nullptr) { + MS_LOG(ERROR) << "primt is nullptr"; + return false; + } + *primt = *origin_primt; + auto primc = std::shared_ptr(lite::PrimitiveC::Create(primt.release())); + if (primc == nullptr) { + MS_LOG(ERROR) << "primc is nullptr"; + return false; + } + status = primc->InferShape(input_tensors, output_tensors); + if (status == RET_OK) { + status = SetCNodeAbstract(output_tensors, cnode); + if (status != RET_OK) { + MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope(); + } + } + FreeTensors(&input_tensors); + FreeTensors(&output_tensors); + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h new file mode 100644 index 0000000000..bc64a583fe --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_ +#include +#include +#include +#include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "backend/optimizer/common/pass.h" +#include "mindspore/lite/src/tensor.h" +#include "mindspore/lite/include/errorcode.h" +using mindspore::lite::STATUS; +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class InferShapePass : public Pass { + public: + InferShapePass() : Pass("infershape_pass") {} + ~InferShapePass() override = default; + bool Run(const FuncGraphPtr &graph) override; + void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } + + private: + void FreeTensors(std::vector *tensors); + abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); + STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector *input_tensors); + STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors); + STATUS SetParameterAbstract(const ParameterPtr ¶meter); + STATUS SetCNodeAbstract(const std::vector &output_tensors, const std::shared_ptr &cnode); + + private: + FmkType fmk_type = lite::converter::FmkType_ONNX; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INFERSHAPE_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc new file mode 100644 index 0000000000..296e6ab946 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/slice_prepose_pass.h" +#include +#include +#include "mindspore/lite/include/errorcode.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "backend/optimizer/common/helper.h" +#include "src/ops/primitive_c.h" +#include "schema/inner/model_generated.h" +#include "src/common/log_adapter.h" + +using mindspore::lite::PrimitiveC; +namespace mindspore::opt { +namespace { +std::vector GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) { + MS_ASSERT(cnode != nullptr); + std::vector empty_shape; + if (index < 1 || cnode->inputs().size() <= index) { + MS_LOG(ERROR) << "out of index"; + return empty_shape; + } + auto abstract = GetCNodeInputAbstract(cnode, index); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Abstract of CNode is nullptr"; + return empty_shape; + } + if (!utils::isa(abstract)) { + MS_LOG(DEBUG) << "abstract is not AbstractTensor"; + return empty_shape; + } + auto abstract_tensor = utils::cast(abstract); + if (!utils::isa(abstract_tensor->GetValueTrack())) { + MS_LOG(DEBUG) << "Value of abstract is not ParamValueLite, indicate that infershape has failed"; + return empty_shape; + } + auto param_value_lite = utils::cast(abstract_tensor->GetValueTrack()); + if (param_value_lite == nullptr) { + MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr"; + return empty_shape; + } + return param_value_lite->tensor_shape(); +} +} // namespace + +schema::SliceT *SlicePreposePass::GetSliceT(const CNodePtr &cnode) { + if (cnode == nullptr) { + return nullptr; + } + auto primc = GetValueNode>(cnode->input(0)); + if (primc == nullptr) { + return nullptr; + } + auto primt = primc->GetPrimitiveT(); + if (primt == nullptr || primt->value.AsSlice() == nullptr) { + return nullptr; + } + return primt->value.AsSlice(); +} + +STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &preceed_cnode, const int index, + const TransactionPtr &tr) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(preceed_cnode != nullptr); + if (slice_cnode->input(1) != preceed_cnode) { + MS_LOG(ERROR) << "preceed node must be slice node's direct parent"; + return RET_ERROR; + } + if (IsMultiOutputTensors(graph, preceed_cnode)) { + MS_LOG(ERROR) << "preceed node referenced by multi nodes not support swap"; + return RET_ERROR; + } + auto manager = graph->manager(); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr"; + return RET_ERROR; + } + auto node_users = manager->node_users()[slice_cnode]; + if (tr != nullptr) { // do swap with transaction + for (auto &node_user : node_users) { + tr->SetEdge(node_user.first, node_user.second, preceed_cnode); + } + tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index)); + tr->SetEdge(preceed_cnode, index, slice_cnode); + } else { + for (auto &node_user : node_users) { + manager->SetEdge(node_user.first, node_user.second, preceed_cnode); + } + manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index)); + manager->SetEdge(preceed_cnode, index, slice_cnode); + } + return RET_OK; +} + +/* + * Prepose condition: + * the softmax axis is not sliced + */ +bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &softmax_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(softmax_cnode != nullptr); + auto softmax_primc = GetValueNode>(softmax_cnode->input(0)); + if (softmax_primc == nullptr) { + MS_LOG(ERROR) << "softmax_primc is nullptr"; + return false; + } + auto softmax_primt = softmax_primc->GetPrimitiveT(); + if (softmax_primt == nullptr || softmax_primt->value.AsSoftMax() == nullptr) { + MS_LOG(ERROR) << "softmax_primt is nullptr"; + return false; + } + auto softmax_attr = softmax_primt->value.AsSoftMax(); + auto softmax_axis = softmax_attr->axis; + auto shape = GetCNodeInputShape(softmax_cnode, 1); + if (softmax_axis == -1) { + if (shape.empty()) { // when softmax axis == -1, shape info is needed to determine whether slice can be preposed + return false; + } + softmax_axis += shape.size(); + } + + auto slice_t = GetSliceT(slice_cnode); + MS_ASSERT(slice_t != nullptr); + auto slice_axes = slice_t->axes; + auto slice_begin = slice_t->begin; + auto slice_size = slice_t->size; + + for (size_t i = 0; i < slice_axes.size(); ++i) { + if (slice_axes[i] == softmax_axis) { + if (slice_begin[i] != 0) { + return false; + } + if (slice_size[i] != -1) { + if (shape.empty() || slice_axes[i] >= static_cast(shape.size())) { + return false; + } + if (slice_size[i] < shape[slice_axes[i]]) { + return false; + } + } + } + } + auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1); + return status == RET_OK; +} + +bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &preceed_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(preceed_cnode != nullptr); + auto preceed_node_type = GetCNodeType(preceed_cnode); + switch (preceed_node_type) { + case schema::PrimitiveType_SoftMax: { + return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode); + } + default: { + MS_LOG(DEBUG) << "Node type " << preceed_node_type << " currently not support SlicePrepose"; + } + } + return false; +} + +bool SlicePreposePass::Run(const FuncGraphPtr &graph) { + if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { + MS_LOG(INFO) << "The framework type of model should be tf/tflite."; + return false; + } + MS_ASSERT(graph != nullptr); + bool changed = false; + while (true) { + bool this_time_changed = false; + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (node->func_graph() != graph) { + continue; + } + if (!utils::isa(node) || GetCNodeType(node) != schema::PrimitiveType_Slice) { + continue; + } + auto slice_cnode = node->cast(); + if (slice_cnode->inputs().size() != lite::kDoubleNum) { // only support params from attrs now + MS_LOG(INFO) << "SlicePrepose not support more than two inputs now"; + continue; + } + auto primt = GetSliceT(slice_cnode); + if (primt == nullptr) { + MS_LOG(ERROR) << "primitive_t of slice is nullptr"; + continue; + } + auto preceed_node = slice_cnode->input(1); + if (preceed_node == nullptr) { + MS_LOG(ERROR) << "preceed node is nullptr"; + continue; + } + auto output_tensor_num = GetOutputTensorNum(preceed_node); + if (output_tensor_num > 1) { + continue; + } + auto output_node_list = GetRealNodeUsedList(graph, utils::cast(preceed_node)); + if (output_node_list->size() > 1) { // referenced by multi nodes + continue; + } else { + if (utils::isa(preceed_node)) { + /* + * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose + * if preceed_node is parameter with default param, constant_folding will process it + */ + continue; + } + auto preceed_cnode = preceed_node->cast(); + if (preceed_cnode == nullptr) { + MS_LOG(ERROR) << "preceed_cnode is nullptr"; + continue; + } + if (DoPrepose(graph, slice_cnode, preceed_cnode)) { + this_time_changed = true; + break; + } + } + } + if (this_time_changed) { + changed = true; + } else { + break; + } + } + return changed; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h new file mode 100644 index 0000000000..99ee5b2a9c --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_ + +#include +#include +#include +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "include/errorcode.h" +#include "mindspore/core/ir/manager.h" +#include "schema/inner/model_generated.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +using lite::RET_ERROR; +using lite::RET_OK; +using lite::STATUS; +using TransactionPtr = std::shared_ptr; +using NodeUsedListPtr = std::shared_ptr>>; +class SlicePreposePass : public Pass { + public: + SlicePreposePass() : Pass("slice_prepose_pass") {} + ~SlicePreposePass() override = default; + bool Run(const FuncGraphPtr &graph) override; + void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } + + private: + schema::SliceT *GetSliceT(const CNodePtr &cnode); + bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); + STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode, + const int index, const TransactionPtr &tr = nullptr); + bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode); + + private: + FmkType fmk_type = lite::converter::FmkType_ONNX; +}; +} // namespace mindspore::opt + +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_