diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 38688848..c7314efb 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -23,6 +23,139 @@ #include "register/op_registry.h" namespace ge { +namespace { +Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function &input) { + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() != DATA) { + continue; + } + + int index = -1; + if (!AttrUtils::GetInt(node->GetOpDesc(), "index", index)) { + GELOGE(FAILED, "Failed to get index from data[%s]", node->GetName().c_str()); + return FAILED; + } + + int parent_index = input(index); + GELOGI("Generate subgraph input map for subgraph %s, data index %d, parent index %d", + graph->GetName().c_str(), index, parent_index); + if (!AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +Status MappingSubgraphOutput(const ComputeGraphPtr &graph, const std::function &output) { + const auto &output_node = graph->FindFirstNodeMatchType(NETOUTPUT); + if (output_node == nullptr) { + return SUCCESS; + } + + const auto &op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (auto index = 0; index < op_desc->GetInputsSize(); ++index) { + int parent_index = output(index); + GELOGI("Generate subgraph output map for subgraph %s, index %d, parent index %d", + graph->GetName().c_str(), index, parent_index); + if (parent_index == -1) { + continue; + } + + GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(tensor); + if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to set parent index for graph %s", graph->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +Status MappingSubgraphIndex(const ComputeGraphPtr &graph, + const std::function &input, + const std::function &output) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(input); + GE_CHECK_NOTNULL(output); + if (MappingSubgraphInput(graph, input) != SUCCESS) { + GELOGE(FAILED, "Failed to mapping subgraph input for graph: %s", graph->GetName().c_str()); + return FAILED; + } + + if (MappingSubgraphOutput(graph, output) != SUCCESS) { + GELOGE(FAILED, "Failed to mapping subgraph output for graph: %s", graph->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +Status ParseSubgraphPostFnCase(const string &subgraph_name, const ComputeGraphPtr &graph) { + return MappingSubgraphIndex(graph, + [](int data_index) { return data_index + 1; }, + [](int retval_index) { return retval_index; }); +} + +Status ParseSubgraphPostFnIf(const string &subgraph_name, const ComputeGraphPtr &graph) { + return MappingSubgraphIndex(graph, + [](int data_index) { return data_index + 1; }, + [](int retval_index) { return retval_index; }); +} + +Status ParseSubgraphPostFnWhile(const string &subgraph_name, const ComputeGraphPtr &graph) { + return MappingSubgraphIndex(graph, + [](int data_index) { return data_index; }, + [&](int retval_index) { return (subgraph_name == "cond") ? -1 : retval_index; }); +} + +Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { + return MappingSubgraphIndex(graph, + [](int data_index) { return (data_index == 0) ? 0 : data_index + 2; }, + [](int retval_index) { return retval_index; }); +} + +Status ParseSubgraphPostFnPartitionedCall(const string &subgraph_name, const ComputeGraphPtr &graph) { + return MappingSubgraphIndex(graph, + [](int data_index) { return data_index; }, + [](int retval_index) { return retval_index; }); +} +} + +Status DataPass::PostParseSubgraph(const ComputeGraphPtr &graph, const string &ir_name, const NodePtr &parent_node) { + using ParseSubgraphFunc = std::function; + const static std::map subgraph_handle = { + {FOR, ParseSubgraphPostFnFor}, + {CASE, ParseSubgraphPostFnCase}, + {IF, ParseSubgraphPostFnIf}, + {_IF, ParseSubgraphPostFnIf}, + {STATELESSIF, ParseSubgraphPostFnIf}, + {WHILE, ParseSubgraphPostFnWhile}, + {_WHILE, ParseSubgraphPostFnWhile}, + {STATELESSWHILE, ParseSubgraphPostFnWhile}, + {PARTITIONEDCALL, ParseSubgraphPostFnPartitionedCall}, + {STATEFULPARTITIONEDCALL, ParseSubgraphPostFnPartitionedCall} + }; + + auto post_func_it = subgraph_handle.find(parent_node->GetType()); + if (post_func_it == subgraph_handle.end()) { + GELOGE(FAILED, "The subgraph post func for node %s type %s is null.", + parent_node->GetName().c_str(), parent_node->GetType().c_str()); + return FAILED; + } + + if (post_func_it->second(ir_name, graph) != SUCCESS) { + GELOGE(FAILED, "Failed to post process subgraph %s on node %s type %s", + graph->GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); + return FAILED; + } + + return SUCCESS; +} + Status DataPass::Run(ComputeGraphPtr compute_graph) { GE_CHECK_NOTNULL(compute_graph); if (compute_graph->GetParentNode() == nullptr) { // for subgraph post process. @@ -63,21 +196,6 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); } - auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType()); - if (post_func == nullptr) { - GELOGW("The subgraph post func for node %s type %s is null.", - parent_node->GetName().c_str(), parent_node->GetType().c_str()); - return SUCCESS; - } - - auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); - auto ret = post_func(subgraph_name, graph); - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", - graph.GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); - return FAILED; - } - - return SUCCESS; + return PostParseSubgraph(compute_graph, subgraph_name, parent_node); } } // namespace ge diff --git a/ge/graph/passes/data_pass.h b/ge/graph/passes/data_pass.h index bce2fd5a..519ae046 100644 --- a/ge/graph/passes/data_pass.h +++ b/ge/graph/passes/data_pass.h @@ -24,6 +24,9 @@ namespace ge { class DataPass : public GraphPass { public: Status Run(ge::ComputeGraphPtr graph); + + private: + Status PostParseSubgraph(const ComputeGraphPtr &graph, const string &ir_name, const NodePtr &parent_node); }; } // namespace ge diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 0043dc8e..ccf3b24e 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -17,7 +17,6 @@ #include #include "common/auth/file_saver.h" -#include "common/ge/tbe_plugin_manager.h" #include "external/register/register_types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" @@ -163,8 +162,6 @@ graphStatus aclgrphBuildInitialize(std::map global_opt GELOGE(ret, "GE initialize failed!"); return GRAPH_FAILED; } - // for functional subgraph assign _parent_index. - TBEPluginManager::Instance().InitPreparation(global_options); } GELOGW("gelib has been initialized!"); return GRAPH_SUCCESS; @@ -172,7 +169,6 @@ graphStatus aclgrphBuildInitialize(std::map global_opt void aclgrphBuildFinalize() { if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { - (void)TBEPluginManager::Instance().Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return; }