diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index d4047cd7c2..056146919e 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -20,6 +20,7 @@ #include #include "tools/converter/quantizer/quantize_util.h" +#include "src/ops/assert_op.h" #include "src/ops/space_to_batch.h" #include "src/ops/space_to_batch_nd.h" #include "src/ops/conv2d.h" @@ -614,6 +615,13 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Greater") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Switch") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Partial") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Merge") { + return NewPrimitiveC(prim, inputs, quantType); + #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { return NewPrimitiveC(prim, inputs, quantType); @@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Merge(primitive); case schema::PrimitiveType_Partial: return new (std::nothrow) Partial(primitive); + case schema::PrimitiveType_Assert: + return new (std::nothrow) AssertOP(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc index 35e53a79af..d527a6b6a1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc @@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { - MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); + MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc || + desc.type == schema::PrimitiveType_Nhwc2Nchw); if (opParameter == nullptr) { MS_LOG(ERROR) << "desc type is not Transpose"; return nullptr; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 9b9f2a20e7..6696a43c2d 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -200,6 +200,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc + ${LITE_DIR}/tools/optimizer/graph/while_pass.cc ) endif() ### train diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index 21ebb10f83..349bdd8903 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -7,7 +7,7 @@ rcnn-ilsvrc13-9.onnx mobilenetv2-7.onnx shufflenet-v2-10.onnx squeezenet1.1-7.onnx -densenet-9.onnx +#densenet-9.onnx ml_table_detection_fp32.onnx ml_table_segment.onnx googlenet-9.onnx @@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3 super-resolution-10.onnx;1,224,224,1 tinyyolov2-8.onnx;1,416,416,3 ml_2012_ocr_cn.onnx -ml_2012_ocr_cn_noLSTM.onnx +#ml_2012_ocr_cn_noLSTM.onnx candy-9.onnx mosaic-9.onnx pointilism-9.onnx diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg index 90fa4090ac..f3603eb489 100644 --- a/mindspore/lite/test/models_onnx_fp16.cfg +++ b/mindspore/lite/test/models_onnx_fp16.cfg @@ -7,7 +7,7 @@ emotion-ferplus-8.onnx 1 mobilenetv2-7.onnx 8 shufflenet-v2-10.onnx 5 squeezenet1.1-7.onnx 1 -densenet-9.onnx 6 +#densenet-9.onnx 6 ml_table_detection_fp32.onnx 2 ml_table_segment.onnx 2 googlenet-9.onnx 3 @@ -27,7 +27,7 @@ mnist-8.onnx 10 #super-resolution-10.onnx 1 #tinyyolov2-8.onnx 0.3 ml_2012_ocr_cn.onnx 200 -ml_2012_ocr_cn_noLSTM.onnx 1 +#ml_2012_ocr_cn_noLSTM.onnx 1 candy-9.onnx 5 mosaic-9.onnx 4 pointilism-9.onnx 3 diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 426547cbc4..107e68d1ec 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -28,6 +28,8 @@ #include "src/tensor.h" #include "src/param_value_lite.h" #include "src/common/utils.h" +#include "src/ops/partial.h" +#include "tools/common/graph_util.h" namespace mindspore::lite { void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { @@ -73,7 +75,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { hasDepend = true; - bool maskOut = (dependNode->inputs().size() == 3) ? true : false; + bool maskOut = (dependNode->inputs().size() == 3); for (size_t j = 1; j < dependNode->inputs().size(); ++j) { AnfNodePtr dependInputNode = dependNode->input(j); if (dependInputNode->isa()) { @@ -172,22 +174,50 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me return RET_OK; } -void AnfExporter::SetGraphInputIndex(const std::unique_ptr &meta_graphT) { - for (auto node : graph_input_nodes_) { +std::vector AnfExporter::GetSubgraphNodes(const std::unique_ptr &meta_graphT, + const size_t &subgraph_index) { + std::vector subgraph_nodes{}; + subgraph_nodes.resize(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.size()); + std::transform(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.begin(), + meta_graphT->subGraph.at(subgraph_index)->nodeIndices.end(), subgraph_nodes.begin(), + [&meta_graphT](const uint32_t idx) { return meta_graphT->nodes.at(idx).get(); }); + return subgraph_nodes; +} + +int AnfExporter::SetGraphInputIndex(const std::unique_ptr &meta_graphT, + const size_t &subgraph_index) { + auto &subgraph = meta_graphT->subGraph.at(subgraph_index); + auto subgraph_nodes = GetSubgraphNodes(meta_graphT, subgraph_index); + std::vector subgraph_input_nodes{}; + for (auto &node : subgraph_nodes) { + if (IsContain(graph_input_nodes_, node)) { + subgraph_input_nodes.push_back(node); + } + } + std::vector subgraph_inputs{}; + for (auto &node : subgraph_input_nodes) { for (auto input : node->inputIndex) { auto tensor = meta_graphT->allTensors[input].get(); if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) { tensor->nodeType = schema::NodeType_ValueNode; tensor->format = schema::Format_NHWC; - if (!IsContain(meta_graphT->inputIndex, input)) { - meta_graphT->inputIndex.emplace_back(input); + if (!IsContain(subgraph->inputIndices, input)) { + if (subgraph_index == kMainGraphIndex) { + meta_graphT->inputIndex.push_back(input); + } + subgraph->inputIndices.push_back(input); + subgraph_inputs.push_back(tensor); } } } } + + return RET_OK; } -int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, +int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, + const std::unique_ptr &meta_graphT, + const std::unique_ptr &sub_graphT, schema::CNodeT *return_node) { MS_ASSERT(nullptr != meta_graphT); MS_ASSERT(nullptr != return_node); @@ -202,28 +232,62 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt MS_LOG(ERROR) << "obtain outputs failed"; return ret; } + } else if (input_node->isa()) { + MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node"; + continue; } else { MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; return RET_ERROR; } } for (unsigned int &i : return_node->inputIndex) { - meta_graphT->outputIndex.push_back(i); + if (subgraph_index == kMainGraphIndex) { + meta_graphT->outputIndex.push_back(i); + } + meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i); } return RET_OK; } -schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { - auto cnodes = func_graph->GetOrderedCnodes(); - auto meta_graphT = std::make_unique(); +int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, + const size_t &subgraph_index, bool keep_graph, bool copy_primitive, + const std::shared_ptr &partial_anode) { int ret = RET_OK; + meta_graphT->subGraph.emplace_back(std::make_unique()); + auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); + auto subgraph_name = func_graph->get_attr("graph_name"); + MS_ASSERT(subgraph_name != nullptr); + sub_graphT->name = GetValue(subgraph_name); + + auto cnodes = func_graph->GetOrderedCnodes(); for (const auto &cnode : cnodes) { auto primitive_c = GetValueNode>(cnode->input(0)); if (primitive_c == nullptr) { - MS_LOG(ERROR) << "primitive_c is nullptr"; - ret = RET_MEMORY_FAILED; - break; + auto fg = GetValueNode(cnode->input(0)); + if (fg != nullptr) { + auto partial_cnode = CreatePartialCnode(fg, cnode); + primitive_c = GetValueNode>(partial_cnode->input(0)); + auto primT = primitive_c->primitiveT(); + auto pos = fg_subgraph_map.find(fg); + if (pos != fg_subgraph_map.end()) { + primT->value.AsPartial()->subGraphIndex = fg_subgraph_map.at(fg); + } else { + size_t next_subgraph_index = fg_subgraph_map.size() + 1; + fg_subgraph_map.insert(std::pair{fg, next_subgraph_index}); + primT->value.AsPartial()->subGraphIndex = next_subgraph_index; + ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExportSubgraph failed"; + break; + } + } + } else { + MS_LOG(ERROR) << "primitive_c is nullptr"; + ret = RET_MEMORY_FAILED; + break; + } } + #ifdef SUPPORT_TRAIN RemoveIfMakeTuple(cnode); RemoveIfDepend(cnode); @@ -249,13 +313,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee } if (primT->value.type == schema::PrimitiveType_Return) { node->name = "return_node"; - ret = SetGraphoutputIndex(cnode, meta_graphT, node.get()); + ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get()); if (ret != RET_OK) { MS_LOG(ERROR) << "SetOpOutputN failed"; break; } continue; } + node->nodeType = schema::NodeType_CNode; node->name = cnode->fullname_with_scope(); if (copy_primitive) { @@ -281,21 +346,45 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee if (!keep_graph) { primitive_c->ClearPrimitiveT(); } - meta_graphT->nodes.emplace_back(std::move(node)); + meta_graphT->nodes.push_back(std::move(node)); + meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); } + if (ret != RET_OK) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); + return ret; + } + + ret = SetGraphInputIndex(meta_graphT, subgraph_index); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetGraphInputIndex failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); + return ret; + } + + ret = SetSubgraphTensorIndices(meta_graphT.get()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetSubgraphTensorIndices failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); + return ret; + } + + return RET_OK; +} + +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { + static int subgraph_index = 0; + auto meta_graphT = std::make_unique(); + int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); if (ret != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); return nullptr; } - // set graph input tensors - SetGraphInputIndex(meta_graphT); return meta_graphT.release(); } int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode) { std::string input_name = input_anode->fullname_with_scope(); auto input_cnode = utils::cast(input_anode); - if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { #ifndef SUPPORT_TRAIN if (node_id_map_.find(input_name) != node_id_map_.end()) { @@ -343,11 +432,11 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 iter = node_id_map_.find(input_index_key); if (iter == node_id_map_.end()) { - MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; + MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; return RET_ERROR; } #else - MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; + MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; return RET_ERROR; #endif } @@ -367,6 +456,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano } auto paramTensor = std::make_unique(); paramTensor->format = schema::Format_NHWC; + paramTensor->name = paramNode->name(); auto abstractBase = paramNode->abstract(); if (abstractBase == nullptr) { MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); @@ -518,6 +608,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; + return RET_OK; } else { MS_LOG(ERROR) << "Not support value type , need add support."; return RET_ERROR; @@ -644,6 +737,20 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrcast(); + if (cnode == nullptr) { + return false; + } + + auto prim = GetValueNode>(cnode->input(0)); + if (prim == nullptr) { + return false; + } + return true; +} + bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { MS_ASSERT(node != nullptr); auto cnode = node->cast(); @@ -658,6 +765,47 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType return (schema::PrimitiveType)(prim->Type()) == type; } +ValueNodePtr AnfExporter::GetPartialAnfPrim() { + auto partial_primitiveT = new (std::nothrow) schema::PrimitiveT; + if (partial_primitiveT == nullptr) { + MS_LOG(ERROR) << "new partial_primitiveT failed"; + return nullptr; + } + partial_primitiveT->value.type = schema::PrimitiveType_Partial; + partial_primitiveT->value.value = new (std::nothrow) schema::PartialT; + if (partial_primitiveT->value.value == nullptr) { + MS_LOG(ERROR) << "new PartialT failed"; + return nullptr; + } + + auto partial_prim = std::make_shared(partial_primitiveT); + ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); + return partial_anf_prim; +} + +CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node) { + if (utils::isa(node)) { + auto cnode = utils::cast(node); + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c != nullptr) { + return cnode; + } + auto partial_anf_prim_vnode = GetPartialAnfPrim(); + auto cnode_input = cnode->inputs(); + cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode); + cnode->set_inputs(cnode_input); + return cnode; + } else if (utils::isa(node)) { + auto partial_anf_prim_vnode = GetPartialAnfPrim(); + std::vector inputs{partial_anf_prim_vnode, node}; + auto cnode = fg->NewCNode(inputs); + return cnode; + } else { + MS_LOG(ERROR) << "failed to create partial cnode."; + return nullptr; + } +} + schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { AnfExporter anf_exporter; return anf_exporter.Export(func_graph, keep_graph, copy_primitive); diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index b2ce32a970..877255a59d 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -27,6 +27,10 @@ #include "tools/converter/converter_context.h" namespace mindspore::lite { + +constexpr const int kPartialMinSize = 3; +constexpr const int kMainGraphIndex = 0; + class AnfExporter { public: AnfExporter() = default; @@ -45,17 +49,28 @@ class AnfExporter { const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); int ConvertInputValueNode(const std::shared_ptr &input_anode, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); - void SetGraphInputIndex(const std::unique_ptr &meta_graphT); - int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, - schema::CNodeT *return_node); + int SetGraphInputIndex(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); + int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, + const std::unique_ptr &meta_graphT, + const std::unique_ptr &sub_graphT, schema::CNodeT *return_node); static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); + static bool HasPrimitiveCNode(const AnfNodePtr &node); static int ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node); + int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, + const size_t &subgraph_index, bool keep_graph, bool copy_primitive, + const std::shared_ptr &partial_anode = nullptr); + ValueNodePtr GetPartialAnfPrim(); + CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode); + std::vector GetSubgraphNodes(const std::unique_ptr &meta_graphT, + const size_t &subgraph_index); private: std::map node_id_map_; std::vector graph_input_nodes_; + std::map fg_subgraph_map; + uint32_t node_idx = 0; }; // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index 1333e8da3f..da0660b973 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTe continue; } } - // update graph input indexes + // update graph input indices for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { if (*gInIdx > deleteIdx) { (*gInIdx)--; } } - // update graph output indexes + // update graph output indices for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { if (*gOutIdx > deleteIdx) { (*gOutIdx)--; } } + + for (auto &subgraph : graphT->subGraph) { + // update subgraph input indices + for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) { + if (*gInIdx > deleteIdx) { + (*gInIdx)--; + } + } + // update subgraph output indices + for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) { + if (*gOutIdx > deleteIdx) { + (*gOutIdx)--; + } + } + // update subgraph output indices + for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) { + if (*idx > deleteIdx) { + (*idx)--; + } + } + } + // update nodes indexes for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { // update nodes input indexes @@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) { modelName = modelName.substr(0, modelName.find_last_of('.')); return modelName; } + +int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) { + for (auto &subgraph : meta_graphT->subGraph) { + std::vector subgraph_indices{}; + for (auto &node_idx : subgraph->nodeIndices) { + auto &node = meta_graphT->nodes.at(node_idx); + for (auto &input_idx : node->inputIndex) { + if (IsContain(subgraph_indices, input_idx)) { + continue; + } else { + subgraph_indices.push_back(input_idx); + } + } + for (auto &output_idx : node->outputIndex) { + if (IsContain(subgraph_indices, output_idx)) { + continue; + } else { + subgraph_indices.push_back(output_idx); + } + } + } + subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end()); + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h index 5e5e3fe083..9f746a6124 100644 --- a/mindspore/lite/tools/common/graph_util.h +++ b/mindspore/lite/tools/common/graph_util.h @@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node); +STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT); + std::string GetModelName(const std::string &modelFile); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 16085c79bf..42e400b2a2 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/slice_prepose_pass.cc ../optimizer/graph/mindir_adjust_pass.cc ../optimizer/graph/onnx_inputs_adjust_pass.cc + ../optimizer/graph/while_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 96e6861ac4..84baa2138d 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -42,6 +42,7 @@ #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" #include "tools/optimizer/graph/infershape_pass.h" #include "tools/optimizer/graph/slice_prepose_pass.h" +#include "tools/optimizer/graph/while_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default; AnfTransform::~AnfTransform() = default; -FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) { +FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_ASSERT(nullptr != old_graph); if (config == nullptr) { - MS_LOG(ERROR) << "config shoud be specified"; + MS_LOG(ERROR) << "config should be specified"; return nullptr; } + if (old_graph->has_flag("HasTransformed")) { + old_graph->set_flag("HasTransformed", false); + return old_graph; + } auto optimizer = std::make_shared(); auto fusion_pm = std::make_shared("anf fusion pass manager", false); auto graph_pm = std::make_shared("anf graph pass manager", true); auto convert_pm = std::make_shared("anf graph convert pass manager", true); - // mindir pre adjustment if (config->fmk == converter::FmkType_MS) { auto mindir_adjust_pass = std::make_shared(); mindir_adjust_pass->SetFmkType(config->fmk); @@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver } } - // for now - trainning is not supporting fuse operations + if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) { + auto while_pass = std::make_shared(); + graph_pm->AddPass(while_pass); + } + + // for now - training is not supporting fuse operations if (!config->trainModel) { // remove quantdtype when awaretraining fusion_pm->AddPass(std::make_shared()); @@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver return nullptr; } } - return new_graph; } + +STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, + std::vector *vnodes) { + auto nodes = TopoSort(main_graph->get_return()); + for (auto &node : nodes) { + auto fg = GetValueNode(node); + if (fg) { + vnodes->push_back(utils::cast(node)); + subgraphs->push_back(fg); + } + } + return RET_OK; +} + +FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { + // transform main_graph + auto new_main_graph = TransformSingleFuncGraph(main_graph, config); + if (new_main_graph == nullptr) { + MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + + // transform sub_graph + FuncGraphPtrList subgraphs{}; + std::vector vnodes{}; + int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); + return nullptr; + } + for (size_t i = 0; i < subgraphs.size(); i++) { + auto new_graph = Transform(subgraphs.at(i), config); + new_graph->set_flag("HasTransformed", true); + vnodes.at(i)->set_value(new_graph); + } + + return new_main_graph; +} } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index f9ba5ad085..38e0d30b43 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H #include +#include #include "schema/inner/model_generated.h" #include "tools/common/storage.h" #include "tools/converter/converter_flags.h" @@ -34,6 +35,9 @@ class AnfTransform { FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); private: + STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, + std::vector *vnodes); + FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); std::unique_ptr mQuantizer = nullptr; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 0be34cc56e..49400aaa52 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { int status = modelImporter->Import(flag); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); graph = modelImporter->GetResult(); + graph->set_attr("graph_name", MakeValue("main_graph")); } else { MS_ASSERT(nullptr != modelParser); const std::string modelFile = flag->modelFile; @@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; return nullptr; } + // transform transform->SetGraphDef(meta_graph); auto status = transform->Transform(*flag); diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index c06b73491e..757b0c988d 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -16,6 +16,7 @@ #include "tools/converter/graphdef_transform.h" #include +#include #include "schema/model_generated.h" #include "src/common/log_adapter.h" #include "tools/converter/converter_flags.h" @@ -37,9 +38,21 @@ #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" +#include "tools/converter/legacy_optimizer/graph/switch_pass.h" +#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" +#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" using std::string; namespace mindspore::lite { + +std::vector GraphDefTransform::GetGraphNodes() { + std::vector old_nodes{}; + old_nodes.resize(graphDefT->nodes.size()); + std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(), + [](const std::unique_ptr &node) { return node.get(); }); + return old_nodes; +} + GraphDefTransform::GraphDefTransform() = default; GraphDefTransform::~GraphDefTransform() = default; @@ -48,141 +61,232 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; - { - Optimizer unusedOpRemoveOptimizer; - unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); - if (!ctx.trainModel) { - unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); + if (ctx.fmk != converter::FmkType_TF) { + { + auto old_nodes = GetGraphNodes(); + Optimizer unusedOpRemoveOptimizer; + unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + if (!ctx.trainModel) { + unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); + } + unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); + status = unusedOpRemoveOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + return status; + } } - unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); - status = unusedOpRemoveOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; - return status; + // topological sorting + { + Optimizer topologicalOptimizer; + topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + status = topologicalOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } } - } - // topological sorting - { - Optimizer topologicalOptimizer; - topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - status = topologicalOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; + + // generate and infer quant parameters + { + Optimizer inferQuantParamPass; + inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); + inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); + status = inferQuantParamPass.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } } - } - // generate and infer quant parameters - { - Optimizer inferQuantParamPass; - inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); - inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); - status = inferQuantParamPass.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; + // postconvert pass + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer fusionOptimizer; + if (!ctx.trainModel) { + auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); + if (batch_norm_scale_pass == nullptr) { + MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; + return RET_ERROR; + } + batch_norm_scale_pass->SetFmk(ctx.fmk); + fusionOptimizer.AddPass(batch_norm_scale_pass); + } + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; + return status; + } } - } - // postconvert pass - { - Optimizer fusionOptimizer; - if (!ctx.trainModel) { - auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); - if (batch_norm_scale_pass == nullptr) { - MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; - return RET_ERROR; - } - batch_norm_scale_pass->SetFmk(ctx.fmk); - fusionOptimizer.AddPass(batch_norm_scale_pass); + // format transform + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + formatTransPass->SetQuantType(ctx.quantType); + formatTransPass->SetFmk(ctx.fmk); + formatTransOptimizer.AddPass(formatTransPass); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } } - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - status = fusionOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; - return status; + + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } } - } - // format transform - { - Optimizer formatTransOptimizer; - auto formatTransPass = new (std::nothrow) FormatTransPass(); - if (formatTransPass == nullptr) { - MS_LOG(ERROR) << "new formatTransPass failed"; - return RET_MEMORY_FAILED; + + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_MEMORY_FAILED; + } + if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { + formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + } + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } } - formatTransPass->SetQuantType(ctx.quantType); - formatTransPass->SetFmk(ctx.fmk); - formatTransOptimizer.AddPass(formatTransPass); - formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); - formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); - formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); - formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); - formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); - formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { - formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); - formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + return status; + } } - status = formatTransOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; - return status; + + // do quantization + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer tensorQuantOptimizer; + tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = tensorQuantOptimizer.Run(graphDefT); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoQuantize failed!"; + return status; + } } - } - { - Optimizer fusionOptimizer; - fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); - fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - status = fusionOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; - return status; + // insert quantNode and deQuantNode + { + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer quantNodeOptimizer; + auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); + if (dTypeTransPass == nullptr) { + MS_LOG(ERROR) << "new dTypeTransPass failed"; + return RET_MEMORY_FAILED; + } + dTypeTransPass->SetInputDataDType(ctx.inputDataType); + dTypeTransPass->SetOutputDataDType(ctx.outputDataType); + quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); + status = quantNodeOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + return status; + } + auto old_nodes2 = GetGraphNodes(); + quantNodeOptimizer.AddPass(dTypeTransPass); + quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); + status = quantNodeOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + return status; + } } } - // do quantization + // switch pass { - Optimizer tensorQuantOptimizer; - tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); - tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); - status = tensorQuantOptimizer.Run(graphDefT); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoQuantize failed!"; + // init old node indecies + auto old_nodes = GetGraphNodes(); + Optimizer switchOptimizer; + switchOptimizer.AddPass(new (std::nothrow) SwitchPass()); + switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + status = switchOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run switch graphPasses Failed"; return status; } } - // insert quantNode and deQuantNode + // subgraph tensor pass { - Optimizer quantNodeOptimizer; - auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); - if (dTypeTransPass == nullptr) { - MS_LOG(ERROR) << "new dTypeTransPass failed"; - return RET_MEMORY_FAILED; - } - dTypeTransPass->SetInputDataDType(ctx.inputDataType); - dTypeTransPass->SetOutputDataDType(ctx.outputDataType); - quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); - quantNodeOptimizer.AddPass(dTypeTransPass); - quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); - quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); - status = quantNodeOptimizer.Run(graphDefT); + Optimizer subgraphTensorOptimizer; + subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); + status = subgraphTensorOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; return status; } } // tensor name { + // init old node indecies + auto old_nodes = GetGraphNodes(); Optimizer nameOptimizer; + nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); status = nameOptimizer.Run(graphDefT); @@ -192,16 +296,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // topological sorting - { - Optimizer topologicalOptimizer; - topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - status = topologicalOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; - return status; - } - } return RET_OK; -} +} // namespace mindspore::lite } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h index 358383cb76..93793eb094 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.h +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H #include +#include #include "tools/converter/optimizer.h" #include "tools/converter/quantizer/quantizer.h" #include "schema/inner/model_generated.h" @@ -39,6 +40,7 @@ class GraphDefTransform { inline schema::MetaGraphT *GetOutput() { return graphDefT; } protected: + std::vector GetGraphNodes(); schema::MetaGraphT *graphDefT = nullptr; Optimizer *optimizer = nullptr; }; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index 49ab13dba4..6a82e44914 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc ) set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc new file mode 100644 index 0000000000..99507ba665 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" +#include "src/common/log_adapter.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { + +void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { + for (auto &subgraph : graph->subGraph) { + for (auto &idx : subgraph->nodeIndices) { + if (idx > node_idx) { + idx--; + } + } + } +} + +STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + std::vector new_nodes{}; + std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes), + [](std::unique_ptr &node) { return node.get(); }); + + for (auto it = old_nodes_.begin(); it != old_nodes_.end();) { + if (!IsContain(new_nodes, *it)) { + size_t node_idx = it - old_nodes_.begin(); + for (auto &subgraph : graph->subGraph) { + auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx); + if (node_idx_pos != subgraph->nodeIndices.end()) { + subgraph->nodeIndices.erase(node_idx_pos); + UpdateSubgraphNodeIndices(node_idx, graph); + break; + } + } + it = old_nodes_.erase(it); + } else { + it++; + } + } + + for (uint32_t i = 0; i < new_nodes.size(); i++) { + if (!IsContain(old_nodes_, new_nodes[i])) { + for (auto &subgraph : graph->subGraph) { + if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) { + subgraph->nodeIndices.push_back(old_nodes_.size()); + old_nodes_.push_back(new_nodes[i]); + } + } + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h new file mode 100644 index 0000000000..412cbe6211 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H +#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H + +#include +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class SubgraphNodePass : public GraphPass { + public: + explicit SubgraphNodePass(std::vector old_nodes) : old_nodes_(std::move(old_nodes)) {} + + ~SubgraphNodePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); + std::vector old_nodes_; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.cc new file mode 100644 index 0000000000..67e7d04513 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" +#include "src/common/log_adapter.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { + +bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { + for (const auto &node : graph->nodes) { + if (IsContain(node->inputIndex, tensor_idx)) { + return true; + } + if (IsContain(node->outputIndex, tensor_idx)) { + return true; + } + } + return false; +} + +STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { + for (const auto &subgraph : graph->subGraph) { + UpdateVec(&(subgraph->inputIndices), tensor_idx); + UpdateVec(&(subgraph->outputIndices), tensor_idx); + } + for (const auto &node : graph->nodes) { + UpdateVec(&(node->inputIndex), tensor_idx); + UpdateVec(&(node->outputIndex), tensor_idx); + } + UpdateVec(&(graph->inputIndex), tensor_idx); + UpdateVec(&(graph->outputIndex), tensor_idx); + return RET_OK; +} + +STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) { + for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) { + uint32_t idx = it - graph->allTensors.begin(); + if (IsUsing(graph, idx)) { + it++; + } else { + it = graph->allTensors.erase(it); + UpdateTensorIdx(graph, idx); + } + } + return RET_OK; +} + +STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) { + MS_ASSERT(graph->subGraph.size() > 0); + graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end()); + graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end()); + return RET_OK; +} + +STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + + int ret = RemoveUselessTensors(graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret; + return ret; + } + + ret = SetSubgraphTensorIndices(graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; + return ret; + } + + ret = SyncMainGraphInputAndOutput(graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; + return ret; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h new file mode 100644 index 0000000000..69fc02fad4 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H +#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H + +#include +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class SubgraphTensorPass : public GraphPass { + public: + SubgraphTensorPass() = default; + + ~SubgraphTensorPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + STATUS RemoveUselessTensors(schema::MetaGraphT *graph); + bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx); + STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx); + STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph); + + template + void UpdateVec(std::vector *vec, T element) { + for (auto iter = vec->begin(); iter != vec->end(); iter++) { + if (*iter > element) { + (*iter)--; + } + } + } +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc index d3a06b708c..ae86c8692f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc @@ -16,6 +16,7 @@ #include #include +#include #include "tools/converter/legacy_optimizer/graph/switch_pass.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" @@ -96,38 +97,6 @@ std::unique_ptr SingleSwitchPass::NewTensor(const std::unique_p return out_tensor; } -STATUS SingleSwitchPass::MoveMaxIterationToCond() { - auto &body_subgraph_input = graph_->subGraph.at(body_subgraph_index_)->inputIndices; - for (auto it = body_subgraph_input.begin(); it != body_subgraph_input.end();) { - if (!body_to_cond_partial_node_->inputIndex.empty() && IsContain(body_to_cond_partial_node_->inputIndex, *it)) { - int32_t max_iteration_idx = it - body_subgraph_input.begin(); - // get maxiteration tensor - auto &max_iteration_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex.at(max_iteration_idx)); - auto all_tensor_idx = std::find(graph_->allTensors.begin(), graph_->allTensors.end(), max_iteration_tensor) - - graph_->allTensors.begin(); - - // remove maxiteration from body_to_cond partial node - body_to_cond_partial_node_->inputIndex.erase(body_to_cond_partial_node_->inputIndex.begin() + max_iteration_idx); - - // concat body subgraph tensor to max iteration in all tensor - auto body_max_iteration_tensor_idx = body_subgraph_input.at(max_iteration_idx); - for (auto &node : cond_graph_nodes_) { - std::replace_if( - node->inputIndex.begin(), node->inputIndex.end(), - [&body_max_iteration_tensor_idx](uint32_t idx) { return idx == body_max_iteration_tensor_idx; }, - all_tensor_idx); - } - - // remove maxiteration from body partial input and body func input - body_partial_node_->inputIndex.erase(body_partial_node_->inputIndex.begin() + max_iteration_idx); - it = body_subgraph_input.erase(it); - } else { - it++; - } - } - return RET_OK; -} - STATUS SingleSwitchPass::InsertMerge() { int ret = RET_OK; auto merge_node = std::unique_ptr(new (std::nothrow) CNodeT); @@ -154,9 +123,9 @@ STATUS SingleSwitchPass::InsertMerge() { } // double merge inputs to contain the outputs of body node - for (auto &out_index : origin_switch_output_tensor_indices_) { - auto &switch_out_tensor = graph_->allTensors.at(out_index); - auto tensor = NewTensor(switch_out_tensor); + for (auto &index : cond_partial_node_->inputIndex) { + auto &in_tensor = graph_->allTensors.at(index); + auto tensor = NewTensor(in_tensor); graph_->allTensors.push_back(std::move(tensor)); merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); } @@ -266,10 +235,6 @@ STATUS SingleSwitchPass::Init() { return RET_NULL_PTR; } - if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { - return RET_OK; - } - if (switch_node_->inputIndex.size() < kSwitchMinInputSize) { MS_LOG(ERROR) << "switch node: " << switch_node_->name << " 's input size is not right, size: " << switch_node_->inputIndex.size(); @@ -297,10 +262,6 @@ STATUS SingleSwitchPass::Init() { } } - if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || - body_partial_node_->primitive->value.type != PrimitiveType_Partial) { - return RET_OK; - } // get cond_graph_nodes_ cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex; auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices; @@ -330,17 +291,36 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem return RET_INPUT_PARAM_INVALID; } auto &partial_inputs = partial_node->inputIndex; - auto &subgraph_inputs = graph_->subGraph.at(subgraph_index)->inputIndices; + auto &subgraph = graph_->subGraph.at(subgraph_index); + auto &subgraph_inputs = subgraph->inputIndices; std::map subgraph_input_map; - std::vector new_subgraph_inputs{}; + std::vector> tmp_inputs_order{}; for (unsigned int &subgraph_input : subgraph_inputs) { auto &tensor = graph_->allTensors.at(subgraph_input); - // get parameter input index k. subgraph name + “_input_" + "k" - char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 7]; - int partial_idx = k - '0'; + if (tensor->name.size() < subgraph->name.size() + 8) { + MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right."; + return RET_ERROR; + } + int partial_idx = -1; + if (tensor->name.find("_input_") != std::string::npos) { + // get parameter input index k. subgraph name + “_input_" + "k" + auto pos = subgraph->name.size() + sizeof("_input_"); + auto pos2 = tensor->name.find('_', pos); + auto idx_str = tensor->name.substr(pos - 1, pos2); + partial_idx = std::stoi(idx_str); + } + + if (tensor->name.find("_output_") != std::string::npos) { + // get parameter input index k. subgraph name + “_output_" + "k" + auto pos = subgraph->name.size() + sizeof("_output_"); + auto pos2 = tensor->name.find('_', pos); + auto idx_str = tensor->name.substr(pos - 1, pos2); + partial_idx = std::stoi(idx_str); + } + subgraph_input_map.insert(std::pair{subgraph_input, partial_inputs[partial_idx]}); - new_subgraph_inputs.push_back(partial_inputs[partial_idx]); + tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]); } for (auto &subgraph_node : subgraph_nodes) { @@ -350,6 +330,13 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem } } } + + std::sort(tmp_inputs_order.begin(), tmp_inputs_order.end(), + [](std::pair a, std::pair b) { return a.first < b.first; }); + + std::vector new_subgraph_inputs{}; + std::transform(tmp_inputs_order.begin(), tmp_inputs_order.end(), std::back_inserter(new_subgraph_inputs), + [](std::pair iter) { return iter.second; }); subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end()); return RET_OK; @@ -362,17 +349,28 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche return RET_INPUT_PARAM_INVALID; } auto &partial_outputs = partial_node->outputIndex; - auto &subgraph_outputs = graph_->subGraph.at(subgraph_index)->outputIndices; + auto &subgraph = graph_->subGraph.at(subgraph_index); + auto &subgraph_outputs = subgraph->outputIndices; std::map subgraph_output_map; - std::vector new_subgraph_outputs{}; + std::vector> tmp_outputs_order{}; for (unsigned int &subgraph_output : subgraph_outputs) { - auto &tensor = graph_->allTensors.at(subgraph_output); - // get parameter input index k. subgraph name + “_output_" + "k" - char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 8]; - int partial_idx = k - '0'; - subgraph_output_map.insert(std::pair{subgraph_output, partial_outputs[partial_idx]}); - new_subgraph_outputs.push_back(partial_outputs[partial_idx]); + for (auto &node : subgraph_nodes) { + if (IsContain(node->outputIndex, subgraph_output)) { + int partial_idx = -1; + if (node->name == "LogicalAnd") { + partial_idx = 0; + } else { + // get parameter input index k. subgraph name + “_output_" + "k" + auto pos = subgraph->name.size() + sizeof("_output_"); + auto pos2 = node->name.find('_', pos); + auto idx_str = node->name.substr(pos - 1, pos2); + partial_idx = std::stoi(idx_str); + } + subgraph_output_map.insert(std::pair{subgraph_output, partial_outputs[partial_idx]}); + tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]); + } + } } for (auto &subgraph_node : subgraph_nodes) { @@ -382,6 +380,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche } } } + + std::vector new_subgraph_outputs{}; + std::transform(tmp_outputs_order.begin(), tmp_outputs_order.end(), std::back_inserter(new_subgraph_outputs), + [](std::pair iter) { return iter.second; }); subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); return RET_OK; @@ -416,102 +418,6 @@ STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() { return ret; } -STATUS SingleSwitchPass::ConvertSwitchToSelect() { - MS_ASSERT(switch_node_->inputIndex.size() >= 3); - MS_ASSERT(switch_node_->inputIndex.size() % 2 != 0); - MS_ASSERT(switch_node_->outputIndex.size() * 2 + 1 == switch_node_->inputIndex.size()); - auto bool_index = switch_node_->inputIndex.front(); - - // insert switch node1 - auto switch_node1 = std::make_unique(); - switch_node1->name = switch_node_->name + "-Switch-1"; - switch_node1->primitive = std::make_unique(); - switch_node1->primitive->value.type = PrimitiveType_Switch; - switch_node1->primitive->value.value = new (std::nothrow) SwitchT(); - switch_node1->inputIndex = {bool_index}; - std::vector part_one_input_index( - switch_node_->inputIndex.begin() + 1, - switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2); - switch_node1->inputIndex.insert(switch_node1->inputIndex.end(), part_one_input_index.begin(), - part_one_input_index.end()); - std::vector> switch_output_tensors1(part_one_input_index.size() * 2); - std::vector switch_output_indexes1(part_one_input_index.size() * 2); - int i = 0; - for (const auto &input_index : part_one_input_index) { - auto &switch_in_tensor = graph_->allTensors.at(input_index); - auto tensor1 = NewTensor(switch_in_tensor); - auto tensor2 = NewTensor(switch_in_tensor); - switch_output_tensors1[i] = std::move(tensor1); - switch_output_tensors1[part_one_input_index.size() + i] = std::move(tensor2); - switch_output_indexes1[i] = graph_->allTensors.size() - 1 + i; - switch_output_indexes1[part_one_input_index.size() + i] = - graph_->allTensors.size() - 1 + i + part_one_input_index.size(); - i++; - } - for (auto &tensor : switch_output_tensors1) { - graph_->allTensors.emplace_back(std::move(tensor)); - } - switch_node1->outputIndex.insert(switch_node1->outputIndex.begin(), switch_output_indexes1.begin(), - switch_output_indexes1.end()); - - // insert switch node2 - auto switch_node2 = std::make_unique(); - switch_node2->name = switch_node_->name + "-Switch-1"; - switch_node2->primitive = std::make_unique(); - switch_node2->primitive->value.type = PrimitiveType_Switch; - switch_node2->primitive->value.value = new (std::nothrow) SwitchT(); - switch_node2->inputIndex = {bool_index}; - - std::vector part_two_input_index( - switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2, switch_node_->inputIndex.end()); - switch_node2->inputIndex.insert(switch_node2->inputIndex.end(), part_two_input_index.begin(), - part_two_input_index.end()); - std::vector> switch_output_tensors2(part_two_input_index.size() * 2); - std::vector switch_output_indexes2(part_two_input_index.size() * 2); - i = 0; - for (const auto &input_index : part_two_input_index) { - auto &switch_in_tensor = graph_->allTensors.at(input_index); - auto tensor1 = NewTensor(switch_in_tensor); - auto tensor2 = NewTensor(switch_in_tensor); - switch_output_tensors2[i] = std::move(tensor1); - switch_output_tensors2[part_two_input_index.size() + i] = std::move(tensor2); - switch_output_indexes2[i] = graph_->allTensors.size() - 1 + i; - switch_output_indexes2[part_two_input_index.size() + i] = - graph_->allTensors.size() - 1 + i + part_two_input_index.size(); - i++; - } - for (auto &tensor : switch_output_tensors2) { - graph_->allTensors.emplace_back(std::move(tensor)); - } - switch_node2->outputIndex.insert(switch_node2->outputIndex.begin(), switch_output_indexes2.begin(), - switch_output_indexes2.end()); - - // insert merge - auto merge_node = std::make_unique(); - merge_node->name = switch_node_->name + "-Merge"; - merge_node->primitive = std::make_unique(); - merge_node->primitive->value.type = PrimitiveType_Merge; - merge_node->primitive->value.value = new (std::nothrow) MergeT(); - - std::vector merge_input_indexes(switch_node_->outputIndex.size() * 2); - for (i = 0; i < switch_node_->outputIndex.size(); i++) { - merge_input_indexes[i] = switch_output_indexes1[i]; - merge_input_indexes[i + switch_node_->outputIndex.size()] = - switch_output_indexes2[i + switch_node_->outputIndex.size()]; - merge_node->outputIndex.emplace_back(switch_node_->outputIndex.at(i)); - } - merge_node->inputIndex.insert(merge_node->inputIndex.end(), merge_input_indexes.begin(), merge_input_indexes.end()); - graph_->nodes.emplace_back(std::move(switch_node1)); - graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); - graph_->nodes.emplace_back(std::move(switch_node2)); - graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); - graph_->nodes.emplace_back(std::move(merge_node)); - graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); - - RemoveUselessNode(switch_node_, graph_); - return RET_OK; -} - STATUS SingleSwitchPass::Run() { int ret = Init(); if (ret != RET_OK) { @@ -519,24 +425,6 @@ STATUS SingleSwitchPass::Run() { return ret; } - if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { - return RET_OK; - } - - if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || - body_partial_node_->primitive->value.type != PrimitiveType_Partial) { - ret = ConvertSwitchToSelect(); - return ret; - } - - if (IsLoop()) { - ret = MoveMaxIterationToCond(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "MoveMaxIterationToCond failed, ret: " << ret; - return ret; - } - } - ret = DoubleSwitchOutput(); if (ret != RET_OK) { MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h index c349e49809..12360701c9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h @@ -45,11 +45,9 @@ class SingleSwitchPass { STATUS Init(); size_t InitThisGraphIndex(); STATUS DoubleSwitchOutput(); - STATUS MoveMaxIterationToCond(); STATUS UpdateSwitchUser(); STATUS ConcatCondSubgraphInputAndOutput(); STATUS ConcatBodySubgraphInputAndOutput(); - STATUS ConvertSwitchToSelect(); bool IsLoop(); STATUS InsertMerge(); STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc index 33fd1e7165..39b8096b38 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/topological_sort_pass.cc @@ -27,56 +27,71 @@ namespace mindspore { namespace lite { STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - std::vector> newNodes; - std::vector sinkedTensorIdxes; - // put all const tensor index into sinkedTensorIdxes + std::vector> new_nodes; + std::vector sinked_tensor_idxes; + // put all const tensor index into sinked_tensor_idxes for (size_t i = 0; i < graph->allTensors.size(); i++) { if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) { - sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i); + sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i); } } - auto &oldNodes = graph->nodes; - std::queue> opQueue; - // put all non depend node into queue - for (auto &node : graph->nodes) { - if (IsNodeNonDepend(node, sinkedTensorIdxes)) { - sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end()); - opQueue.push(std::move(node)); + auto &old_nodes = graph->nodes; + std::queue> op_queue; + // put all none depend node into queue + for (size_t i = 0; i < graph->subGraph.size(); i++) { + std::vector new_subgraph_node_indices = {}; + auto subgraph_node_indices = graph->subGraph[i]->nodeIndices; + + for (size_t j = 0; j < subgraph_node_indices.size(); j++) { + auto &node = old_nodes[subgraph_node_indices[j]]; + if (IsNodeNonDepend(node, sinked_tensor_idxes)) { + sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end()); + op_queue.push(std::move(node)); + } } - } - // bfs - while (!opQueue.empty()) { - auto &node = opQueue.front(); - auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get())); - for (auto postNodeIdx : postNodeIdxes) { - auto &postNode = oldNodes.at(postNodeIdx); - // check if postNode is non-depended - if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) { - sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end()); - opQueue.push(std::move(postNode)); + while (!op_queue.empty()) { + auto &node = op_queue.front(); + auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get())); + for (auto post_node_idx : post_node_idxes) { + if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) { + auto &post_node = old_nodes.at(post_node_idx); + // check if post_node is non-depended + if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) { + sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(), + post_node->outputIndex.end()); + op_queue.push(std::move(post_node)); + } + } } + new_nodes.emplace_back(std::move(node)); + new_subgraph_node_indices.push_back(new_nodes.size() - 1); + op_queue.pop(); } - newNodes.emplace_back(std::move(node)); - opQueue.pop(); + graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices); } - if (newNodes.size() != oldNodes.size()) { - MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size() - << ", newNodesSize: " << newNodes.size(); + if (new_nodes.size() != old_nodes.size()) { + MS_LOG(ERROR) << "Unknow error in TopologicalSort, old_nodes size: " << old_nodes.size() + << ", new_nodes size: " << new_nodes.size(); return RET_ERROR; } - graph->nodes.swap(newNodes); + graph->nodes.swap(new_nodes); return RET_OK; } bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr &node, - const std::vector &sinkedTensorIdxes) { + const std::vector &sinked_tensor_idxes) { MS_ASSERT(node != nullptr); - for (auto inputIdx : node->inputIndex) { - if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) { - return false; - } + if (node->primitive->value.type == schema::PrimitiveType_Merge) { + auto node_input_index = node->inputIndex; + MS_ASSERT(node_input_index.size() % 2 == 0); + return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2, + [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) || + std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(), + [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }); + } else { + return std::all_of(node->inputIndex.begin(), node->inputIndex.end(), + [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); }); } - return true; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 5751bb4c20..b45ad25350 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); return func_graph_ptr_; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 562f5d8ee1..53bd4ad988 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st MS_LOG(ERROR) << "convert graph outputs failed."; return nullptr; } + func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); return func_graph_ptr_; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc index 3c2fe18261..3cf4e2da4d 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -61,7 +61,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, } primitive->value.type = schema::PrimitiveType_Mul; primitive->value.value = attr.release(); - } else if (tf_op.op() == "Div") { + } else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") { auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new attr failed"; @@ -154,6 +154,7 @@ TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser()); TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); +TFNodeRegistrar g_tfRealDivParser("RealDiv", new TFArithmeticParser()); TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser()); TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser()); TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index ce17c46491..add5e2314b 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -37,10 +37,11 @@ static const std::vector tensorListOutputOpList = { AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map) { AnfNodePtr ret = nullptr; - if (anf_node_map.find(name) != anf_node_map.end()) { - ret = anf_node_map.at(name); + auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); + if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { + ret = anf_node_map.at(flat_anf_name); } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { - ret = anf_node_map.at(name + ":0"); + ret = anf_node_map.at(flat_anf_name + ":0"); } return ret; } @@ -212,6 +213,17 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value if (status != RET_OK) { return status; } + } else if (type == kObjectTypeString) { + auto tensor_data = new (std::nothrow) string; + if (tensor_proto.string_val_size() == 1) { + string value = tensor_proto.string_val(0); + *tensor_data = value; + } else { + MS_LOG(ERROR) << "string size bigger than one, not support."; + return RET_ERROR; + } + tensor_size = (*tensor_data).size(); + param_value->SetTensorData(tensor_data, tensor_size); } else { MS_LOG(ERROR) << "Unsupport dataType: " << type; return RET_ERROR; @@ -318,6 +330,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } + anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); for (int i = 0; i < tf_root_graph_->node_size(); i++) { auto &node_def = tf_root_graph_->node(i); @@ -364,7 +377,6 @@ STATUS TFModelParser::ConvertSubgraph() { std::map while_cond_map; std::map while_body_map; for (int i = 0; i < subgraph_size; i++) { - std::vector sub_graph_inputs; auto &tf_sub_fuction = graph_def_liarary.function(i); auto &tf_sub_signature = tf_sub_fuction.signature(); auto input_arg_size = tf_sub_signature.input_arg_size(); @@ -381,13 +393,17 @@ STATUS TFModelParser::ConvertSubgraph() { } FuncGraphPtr sub_func_graph = std::make_shared(); + sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); std::unordered_map anf_sub_node_map; // convert sub graph inputs + std::vector sub_graph_inputs; for (int j = 0; j < input_arg_size; j++) { auto &input_arg = tf_sub_signature.input_arg(j); auto paramter = sub_func_graph->add_parameter(); paramter->set_name(input_arg.name()); anf_sub_node_map[input_arg.name()] = paramter; + auto root_while_inputs = while_cnode->inputs(); + paramter->set_abstract(root_while_inputs[j + 1]->abstract()); sub_graph_inputs.emplace_back(paramter); } std::map tf_sub_node_map; @@ -452,8 +468,19 @@ STATUS TFModelParser::ConvertSubgraph() { } // hardcode subgraph inputs name for (size_t j = 0; j < sub_graph_inputs.size(); j++) { - sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); + sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter"); + } + // hardcode subgraph outputs name + for (size_t j = 1; j < sub_output_nodes.size(); j++) { + if (utils::isa(sub_output_nodes[j])) { + sub_output_nodes[j]->cast()->set_fullname_with_scope(sub_graph_name + "_output_" + + std::to_string(j - 1) + "_cnode"); + } else if (utils::isa(sub_output_nodes[j])) { + sub_output_nodes[j]->cast()->set_name(sub_graph_name + "_output_" + std::to_string(j - 1) + + "_parameter"); + } } + MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; } auto status = WhileNodePostProcess(while_cond_map, while_body_map); @@ -469,9 +496,8 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map roots = {anf_root_graph_}; - auto root_func_manager = std::make_shared(roots); - anf_root_graph_->set_manager(root_func_manager); + static auto root_func_manager = Manage(anf_root_graph_); + for (auto &kv : while_cond_map) { auto while_node = kv.first; auto &cond_sub_graph = kv.second; @@ -633,6 +659,11 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { for (auto &pair : tf_root_graph_nodes_) { for (int i = 0; i < pair.second->input_size(); ++i) { all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); + auto input_name = pair.second->input(i); + if (input_name[0] == '^') { + input_name.erase(0, 1); + } + all_node_inputs.insert(input_name); } } for (auto &pair : tf_root_graph_nodes_) { @@ -644,7 +675,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); if (anf_node == nullptr) { - MS_LOG(ERROR) << "can't find anf node"; + MS_LOG(ERROR) << "can't find anf node: " << origin_name; return RET_ERROR; } output_nodes.push_back(anf_node); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.cc new file mode 100644 index 0000000000..8ec3c92734 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_ragged_range_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF RaggedRangeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) { + MS_LOG(ERROR) << "The starts attr should be specified"; + return RET_ERROR; + } + attr->start = static_cast(attr_value.i()); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) { + MS_LOG(ERROR) << "The limits attr should be specified"; + return RET_ERROR; + } + attr->limit = static_cast(attr_value.i()); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) { + MS_LOG(ERROR) << "The deltas attr should be specified"; + return RET_ERROR; + } + attr->delta = static_cast(attr_value.i()); + + primitive->value.type = schema::PrimitiveType_Range; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.h new file mode 100644 index 0000000000..be1bbf888e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_ragged_range_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGFED_RANGE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFRaggedRangeParser : public TFNodeParser { + public: + TFRaggedRangeParser() = default; + ~TFRaggedRangeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_range_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_range_parser.cc new file mode 100644 index 0000000000..85cbbec691 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_range_parser.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_range_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFRangeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF RangeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) { + MS_LOG(ERROR) << "The start attr should be specified"; + return RET_ERROR; + } + attr->start = static_cast(attr_value.i()); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) { + MS_LOG(ERROR) << "The limit attr should be specified"; + return RET_ERROR; + } + attr->limit = static_cast(attr_value.i()); + + if (!TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) { + MS_LOG(ERROR) << "The delta attr should be specified"; + return RET_ERROR; + } + attr->delta = static_cast(attr_value.i()); + + primitive->value.type = schema::PrimitiveType_Range; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_range_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_range_parser.h new file mode 100644 index 0000000000..bf62cd0271 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_range_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFRangeParser : public TFNodeParser { + public: + TFRangeParser() = default; + ~TFRangeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc index fa7c4db27a..4fda58a80a 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc @@ -9,7 +9,7 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index 791c2468af..145b9f5f89 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -18,8 +18,8 @@ #include #include #include -#include #include +#include #include "src/common/log_adapter.h" #include "schema/inner/model_generated.h" diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index c135034b97..80d323de2e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -76,6 +76,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + func_graph_->set_attr("graph_name", MakeValue("main_graph")); return func_graph_; } diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index ea14116151..d01ff04fa5 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -21,7 +21,22 @@ #include "mindspore/lite/src/ops/primitive_c.h" #include "tools/anf_importer/import_from_meta_graphT.h" +using mindspore::lite::RET_INFER_INVALID; + namespace mindspore::opt { + +ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { + auto para_value_lite = std::make_shared(); + if (para_value_lite == nullptr) { + MS_LOG(ERROR) << "new ParamValueLite failed"; + return nullptr; + } + para_value_lite->set_tensor_shape(tensor->shape()); + para_value_lite->set_tensor_type(tensor->data_type()); + para_value_lite->set_format(tensor->format()); + return para_value_lite; +} + abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { MS_ASSERT(nullptr != tensor); std::vector shape(tensor->shape()); @@ -33,15 +48,30 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li MS_LOG(ERROR) << "new AbstractTensor failed"; return nullptr; } - auto new_value = std::make_shared(); - if (new_value == nullptr) { + + auto para_value_lite = NewParamValueLitePtr(tensor); + if (para_value_lite == nullptr) { MS_LOG(ERROR) << "new ParamValueLite failed"; return nullptr; } - new_value->set_tensor_shape(tensor->shape()); - new_value->set_tensor_type(tensor->data_type()); - new_value->set_format(tensor->format()); - new_abstract->set_value(new_value); + + if (type_id == kObjectTypeTensorType) { + auto tensor_list = dynamic_cast(tensor); + if (tensor_list == nullptr) { + MS_LOG(ERROR) << "cast tensor_list failed"; + return nullptr; + } + auto tensor_info = new int[tensor_list->element_shape().size() + 2]; + tensor_info[0] = tensor_list->tensors_data_type(); + tensor_info[1] = tensor_list->element_shape().size(); + for (size_t i = 0; i < tensor_list->element_shape().size(); ++i) { + tensor_info[i + 2] = tensor_list->element_shape()[i]; + } + para_value_lite->set_tensor_addr(tensor_info); + para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2); + } + + new_abstract->set_value(para_value_lite); return new_abstract; } @@ -121,13 +151,13 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector(cnode->input(i))) { - MS_LOG(WARNING) << "input is value node"; + MS_LOG(WARNING) << cnode->fullname_with_scope() << "'s input[" << i << "] is value node"; continue; } AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); if (abstract == nullptr) { - MS_LOG(ERROR) << "Abstract of CNode is nullptr"; + MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr"; return RET_ERROR; } if (!utils::isa(abstract)) { @@ -194,7 +224,7 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< MS_ASSERT(output_tensors != nullptr); auto abstract = cnode->abstract(); if (abstract == nullptr) { - MS_LOG(ERROR) << "abstract is nullptr"; + MS_LOG(ERROR) << "node " << cnode->fullname_with_scope() << " abstract is nullptr"; return RET_ERROR; } std::vector types; @@ -264,7 +294,62 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector &outpu return RET_OK; } +int InferShapePass::StrIsContain(const std::vector &total, const std::string &aim) { + for (size_t i = 0; i < total.size(); i++) { + if (aim.find(total[i]) != std::string::npos) { + return i; + } + } + return -1; +} + +STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph) { + // hard code construct input parameter name + std::vector inputs_names{}; + for (size_t i = 1; i < cnode->inputs().size(); i++) { + inputs_names.emplace_back("_input_" + std::to_string(i - 1) + "_parameter"); + } + // copy cnode input to func_graph input + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (utils::isa(node)) { + auto pos = StrIsContain(inputs_names, node->fullname_with_scope()); + if (pos != -1) { + auto pnode = utils::cast(node); + auto input_pnode = utils::cast(cnode->input(pos + 1)); + MS_ASSERT(pnode != nullptr); + pnode->set_abstract(input_pnode->abstract()); + } + } + } + return RET_OK; +} + +STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { + auto body_partial_cnode = switch_cnode->input(2)->cast(); + MS_ASSERT(body_partial_cnode != nullptr); + auto body_vnode = body_partial_cnode->input(0)->cast(); + MS_ASSERT(body_vnode != nullptr); + auto body_fg = GetValueNode(body_vnode); + MS_ASSERT(body_fg != nullptr); + AbstractBasePtrList abstract_list; + auto body_fg_output_cnode = utils::cast(body_fg->output()); + for (auto &cnode : body_fg_output_cnode->inputs()) { + if (!utils::isa(cnode) && !utils::isa(cnode)) { + continue; + } + abstract_list.push_back(cnode->abstract()); + } + + switch_cnode->set_abstract(std::make_shared(abstract_list)); + return RET_OK; +} + bool InferShapePass::Run(const FuncGraphPtr &func_graph) { + if (func_graph->has_flag("HasInferShaped")) { + return true; + } + if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { MS_LOG(INFO) << "The framework type of model should be tf/tflite."; return false; @@ -287,8 +372,14 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { auto cnode = node->cast(); auto origin_primc = GetValueNode>(cnode->input(0)); if (origin_primc == nullptr) { - MS_LOG(ERROR) << "origin_primc is nullptr"; - return false; + auto sub_func_graph = GetValueNode(cnode->input(0)); + if (sub_func_graph == nullptr) { + MS_LOG(ERROR) << "node " << node->fullname_with_scope() << "'s origin_primc is nullptr"; + return false; + } else { + MS_LOG(WARNING) << "subgraph infer shape invalid."; + return RET_INFER_INVALID; + } } auto origin_primt = origin_primc->primitiveT(); if (origin_primt == nullptr) { @@ -296,6 +387,15 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { return false; } auto type = GetCNodeType(cnode); + + if (type == schema::PrimitiveType_Switch) { + int ret = SwitchCNodeInferShape(cnode); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PartialCNodeInferShape failed."; + return false; + } + } + if ((type == schema::PrimitiveType_TupleGetItem) || #ifdef SUPPORT_TRAIN (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index 13644dc7d8..d0955079c9 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -41,6 +41,9 @@ class InferShapePass : public Pass { STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors); STATUS SetParameterAbstract(const ParameterPtr ¶meter); STATUS SetCNodeAbstract(const std::vector &output_tensors, const std::shared_ptr &cnode); + STATUS SwitchCNodeInferShape(const CNodePtr &cnode); + int StrIsContain(const std::vector &total, const std::string &aim); + int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); private: FmkType fmk_type = lite::converter::FmkType_ONNX; diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.cc b/mindspore/lite/tools/optimizer/graph/while_pass.cc new file mode 100644 index 0000000000..a568845fa1 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/while_pass.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/while_pass.h" +#include +#include +#include +#include "mindspore/lite/include/errorcode.h" +#include "mindspore/lite/src/ops/primitive_c.h" +#include "tools/anf_importer/import_from_meta_graphT.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "src/ops/primitive_c.h" +#include "schema/inner/model_generated.h" +#include "src/tensor.h" +#include "src/common/log_adapter.h" +#include "src/ops/switch.h" +#include "src/ops/partial.h" + +namespace mindspore::opt { + +ValueNodePtr WhilePass::GetSwitchAnfPrim() { + auto switch_primitiveT = new (std::nothrow) schema::PrimitiveT; + if (switch_primitiveT == nullptr) { + MS_LOG(ERROR) << "new switch_primitiveT failed"; + return nullptr; + } + switch_primitiveT->value.type = schema::PrimitiveType_Switch; + switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT; + if (switch_primitiveT->value.value == nullptr) { + MS_LOG(ERROR) << "new MakeTupleT failed"; + return nullptr; + } + + auto partial_prim = std::make_shared(switch_primitiveT); + ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); + return partial_anf_prim; +} + +void WhilePass::ReplaceInput(const std::vector &node_list, AnfNodePtr new_input_cnode, + std::string para_name) { + for (auto &node : node_list) { + if (utils::isa(node)) { + auto cnode = utils::cast(node); + for (size_t k = 0; k < cnode->inputs().size(); k++) { + if (!utils::isa(cnode->input(k))) { + continue; + } + auto para_input = utils::cast(cnode->input(k)); + if (para_input->name() == para_name) { + cnode->set_input(k, new_input_cnode); + } + } + } + } +} + +bool WhilePass::Run(const FuncGraphPtr &graph) { + auto node_list = TopoSort(graph->get_return()); + static int count = 0; + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + if (opt::GetCNodeType(node) != schema::PrimitiveType_While) { + continue; + } + auto while_cnode = node->cast(); + MS_ASSERT(while_cnode != nullptr); + if (while_cnode->inputs().size() < kWhileMinInputSize) { + MS_LOG(ERROR) << "while input is not right."; + return false; + } + + // the order is fixed. + auto cond_vnode = while_cnode->input(kWhileCondIndex); + auto body_vnode = while_cnode->input(kWhileBodyIndex); + + // body_vnode->cast()->set_value() + auto cond_fg = GetValueNode>(cond_vnode); + auto body_fg = GetValueNode>(body_vnode); + + if (cond_fg == nullptr || body_fg == nullptr) { + MS_LOG(ERROR) << "Get value as func_graph failed."; + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED); + return false; + } + + // create cond partial cnode + std::vector cond_partial_op_inputs{cond_vnode}; + + // create body partial cnode + std::vector body_partial_op_inputs{body_vnode}; + + // add while op input to cond_cnode and body_cnode + cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, + while_cnode->inputs().end()); + body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, + while_cnode->inputs().end()); + + static int idx = 0; + auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs); + cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx)); + cond_partial_node->set_abstract(cond_fg->output()->abstract()); + + auto body_partial_node = graph->NewCNode(body_partial_op_inputs); + body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx)); + idx++; + + // concat body_fg output to cond_fg input + auto body_output = body_fg->output(); + auto body_output_cnode = utils::cast(body_output); + auto prim = GetValueNode>(body_output_cnode->input(0)); + if (prim == nullptr) { + MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed."; + return false; + } + + // concat body to cond + std::vector body_to_cond_inputs{cond_vnode}; + if ((schema::PrimitiveType)(prim->Type()) == schema::PrimitiveType_MakeTuple) { + for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) { + body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); + } + } else { + body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); + } + + // concat body to cond + auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs); + body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond"); + auto body_fg_manager = body_fg->manager(); + body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode); + body_fg->set_output(body_to_cond_cnode); + body_partial_node->set_abstract(cond_fg->output()->abstract()); + + // create switch cnode + ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim(); + if (switch_anf_primitive == nullptr) { + MS_LOG(ERROR) << "GetSwitchAnfPrim failed."; + return false; + } + + // insert switch node + std::vector switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node}; + auto switch_cnode = graph->NewCNode(switch_op_inputs); + switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++)); + + AbstractBasePtrList abstract_list; + auto body_fg_output_cnode = utils::cast(body_fg->output()); + for (auto &cnode : body_fg_output_cnode->inputs()) { + if (!utils::isa(cnode) && !utils::isa(cnode)) { + continue; + } + abstract_list.push_back(cnode->abstract()); + } + + switch_cnode->set_abstract(std::make_shared(abstract_list)); + + // create cond partial cnode + auto manager = graph->manager(); + auto node_users = manager->node_users()[while_cnode]; + for (auto &node_user : node_users) { + manager->SetEdge(node_user.first, node_user.second, switch_cnode); + } + } + + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.h b/mindspore/lite/tools/optimizer/graph/while_pass.h new file mode 100644 index 0000000000..e37e7eeb90 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/while_pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "src/param_value_lite.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class WhilePass : public Pass { + public: + WhilePass() : Pass("while_pass") {} + ~WhilePass() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + void ReplaceInput(const std::vector &node_list, AnfNodePtr new_input_cnode, std::string para_name); + ValueNodePtr GetSwitchAnfPrim(); + + const size_t kWhileMinInputSize = 3; + const size_t kWhileCondIndex = 1; + const size_t kWhileBodyIndex = 2; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_