diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 1dc267e005..473e8b29ea 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -34,7 +34,6 @@ #include "src/tensor.h" #include "src/param_value_lite.h" #include "src/common/utils.h" -#include "ops/partial.h" #include "tools/common/graph_util.h" #include "src/ops/ops_utils.h" @@ -287,7 +286,6 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr &m int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, const std::unique_ptr &meta_graphT, - const std::unique_ptr &sub_graphT, schema::CNodeT *return_node) { MS_ASSERT(nullptr != meta_graphT); MS_ASSERT(nullptr != return_node); @@ -319,9 +317,15 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap return RET_OK; } +bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) { + if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) { + return true; + } + return false; +} + int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, - const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, - const std::unique_ptr &sub_graphT) { + const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive) { int ret = RET_OK; auto cnodes = GetOrderedCNodes(func_graph); for (const auto &cnode : cnodes) { @@ -334,19 +338,18 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr>(partial_cnode->input(0)); primT = GetPrimitiveT(partial_cnode->input(0)); MS_ASSERT(primT != nullptr); - auto pos = fg_subgraph_map.find(fg); - if (pos != fg_subgraph_map.end()) { + auto pos = fg_subgraph_map_.find(fg); + if (pos != fg_subgraph_map_.end()) { MS_ASSERT(primT->value.AsPartialFusion() != nullptr); - primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map.at(fg); + primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map_.at(fg); } else { - size_t next_subgraph_index = fg_subgraph_map.size() + 1; - fg_subgraph_map.insert(std::pair{fg, next_subgraph_index}); + size_t next_subgraph_index = meta_graphT->subGraph.size(); MS_ASSERT(primT->value.AsPartialFusion() != nullptr); primT->value.AsPartialFusion()->sub_graph_index = next_subgraph_index; - ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode); + ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, cnode); if (ret != RET_OK) { MS_LOG(ERROR) << "ExportSubgraph failed"; - break; + return ret; } } } else { @@ -374,7 +377,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrname = mindspore::ops::kNameReturn; - ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get()); + ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get()); if (ret != RET_OK) { MS_LOG(ERROR) << "SetOpOutputN failed"; break; @@ -398,26 +401,28 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrnodes.push_back(std::move(node)); - meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); + meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++); } return ret; } int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, - const size_t &subgraph_index, bool keep_graph, bool copy_primitive, - const std::shared_ptr &partial_anode) { - int ret = RET_OK; + bool keep_graph, bool copy_primitive, const std::shared_ptr &partial_anode) { + if (HasExported(func_graph)) { + MS_LOG(INFO) << "Has been exported."; + return RET_OK; + } + meta_graphT->subGraph.emplace_back(std::make_unique()); - auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); + auto subgraph_index = meta_graphT->subGraph.size() - 1; + fg_subgraph_map_[func_graph] = subgraph_index; auto subgraph_name = func_graph->get_attr("graph_name"); MS_ASSERT(subgraph_name != nullptr); - sub_graphT->name = GetValue(subgraph_name); - auto fmk = func_graph->get_attr("fmk"); - MS_ASSERT(fmk != nullptr); - meta_graphT->fmkType = GetValue(fmk); + meta_graphT->subGraph.back()->name = GetValue(subgraph_name); - ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive, sub_graphT); + int ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); if (ret != RET_OK) { + MS_LOG(ERROR) << "Anf2Fb failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); return ret; } @@ -441,11 +446,14 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) { - static int subgraph_index = 0; - this->train_flag = train_flag; + this->train_flag_ = train_flag; auto meta_graphT = std::make_unique(); - int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); + auto fmk = func_graph->get_attr("fmk"); + MS_ASSERT(fmk != nullptr); + meta_graphT->fmkType = GetValue(fmk); + int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive); if (ret != RET_OK) { + MS_LOG(ERROR) << "ExportSubgraph failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); return nullptr; } @@ -455,7 +463,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) { MS_ASSERT(input_anode != nullptr && output_cnode != nullptr); auto input_name = input_anode->fullname_with_scope(); - if (this->train_flag) { + if (this->train_flag_) { bool found = false; if (node_id_map_.find(input_name) != node_id_map_.end()) { output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); @@ -618,7 +626,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr(value); }); (*paramTensor)->dims = dims; - if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; + if (train_flag_ && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; (*paramTensor)->nodeType = NodeType_ValueNode; auto data = value->cast(); (*paramTensor)->data.resize(data->Size()); @@ -742,7 +750,7 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu (*paramTensor)->dataType = valueLite->tensor_type(); (*paramTensor)->dims = valueLite->tensor_shape(); - if (train_flag && (*paramTensor)->dims.empty()) { + if (train_flag_ && (*paramTensor)->dims.empty()) { (*paramTensor)->dims = {1}; } @@ -765,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano auto paramTensor = std::make_unique(); auto value = valueNode->value(); int ret = RET_OK; - if (train_flag) { + if (train_flag_) { paramTensor->name = valueNode->fullname_with_scope(); } if (value->isa()) { @@ -861,7 +869,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrnodeType = NodeType_CNode; fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); - if (train_flag) { + if (train_flag_) { std::string name = cnode_name + "_o:" + std::to_string(i); node_id_map_[name] = meta_graphT->allTensors.size(); meta_graphT->allTensors.emplace_back(msTensor); diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index be0cfa8d5e..56087c7a6a 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -30,7 +30,6 @@ using mindspore::ops::PrimitiveC; namespace mindspore::lite { -constexpr const int kPartialMinSize = 3; constexpr const int kMainGraphIndex = 0; class AnfExporter { @@ -74,28 +73,27 @@ class AnfExporter { const std::unique_ptr &meta_graphT); int SetGraphInputIndex(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index, - const std::unique_ptr &meta_graphT, - const std::unique_ptr &sub_graphT, schema::CNodeT *return_node); + const std::unique_ptr &meta_graphT, schema::CNodeT *return_node); static int ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node); int Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, - const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, - const std::unique_ptr &sub_graphT); + const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive); int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, - const size_t &subgraph_index, bool keep_graph, bool copy_primitive, - const std::shared_ptr &partial_anode = nullptr); + bool keep_graph, bool copy_primitive, const std::shared_ptr &partial_anode = nullptr); static ValueNodePtr GetPartialAnfPrim(); static CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode); static std::vector GetSubgraphNodes(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); + bool HasExported(const FuncGraphPtr &func_graph); private: std::map node_id_map_; std::vector graph_input_nodes_; - std::map fg_subgraph_map; - uint32_t node_idx = 0; - bool train_flag = false; + // The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph + std::map fg_subgraph_map_; + uint32_t node_idx_ = 0; + bool train_flag_ = false; }; // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify