/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/for_pass.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"

namespace {
// Input layout of the generated While node (see ForPass::BuildWhileInfo):
//   [0] i, [1] |delta|, [2] |limit - start|, [3] start, [4] delta, [5...] data inputs of the For node.
const uint32_t kWhileIInputIndex = 0;
const uint32_t kWhileAbsDeltaInputIndex = 1;
const uint32_t kWhileRangeInputIndex = 2;
const uint32_t kWhileStartInputIndex = 3;
const uint32_t kWhileDeltaInputIndex = 4;
const uint32_t kWhileDataInputIndex = 5;
// Input layout of the PartitionedCall node wrapping the original for-body:
//   [0] loop variable, [1...] data inputs.
const uint32_t kSubgraphLoopVarInputIndex = 0;
const uint32_t kSubgraphInputIndex = 1;
// First While output that carries a data output of the original For node.
const uint32_t kWhileOutputIndex = 5;
// Index shift applied to for-body data inputs in ForPass::UpdateForBodyInputMapping.
const size_t kIDiffValue = 2;
const std::string kAbs = "Abs";
}  // namespace

namespace ge {
///
/// @brief Replace a For node by an equivalent While node (with generated cond / body subgraphs).
///        Nodes whose type is not FOR are left untouched.
/// @param [in&out] node
/// @return Status
///
Status ForPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  if (node->GetType() != FOR) {
    GELOGD("no need for_pass for node %s.", node->GetName().c_str());
    return SUCCESS;
  }

  GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());

  ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(graph);
  ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  GE_CHECK_NOTNULL(root_graph);

  ForInfo for_info;
  GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
                    "Build ForInfo failed, node:%s.", node->GetName().c_str());

  WhileInfo while_info;
  GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
                    "Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());

  ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
    GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
    return FAILED;
  }

  ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
    GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
    return FAILED;
  }

  GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
                    "Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());

  // for node has and only has one subgraph
  GE_CHECK_NOTNULL(node->GetOpDesc());
  node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));

  GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  return IsolateAndDeleteNode(node, std::vector<int>());
}

///
/// @brief Build for_info: collect the peers of the For node, validate its body subgraph,
///        then unlink the For node from the graph.
/// @param [in] root_graph
/// @param [in] node
/// @param [out] for_info
/// @return Status
///
Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  GE_CHECK_NOTNULL(root_graph);
  GE_CHECK_NOTNULL(node);
  GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());

  OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
    GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
    return FAILED;
  }

  std::vector<OutDataAnchorPtr> data_inputs;
  std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  std::vector<OutControlAnchorPtr> ctrl_inputs;
  std::vector<InControlAnchorPtr> ctrl_outputs;
  if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
    GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
    return FAILED;
  }

  OpDescPtr op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  // For node has and only has one sub_graph
  std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  if (for_body_name.empty()) {
    GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
    return FAILED;
  }
  ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  if (for_body == nullptr) {
    GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
    return FAILED;
  }

  // Unlink only after every validation has passed, so a failure above leaves the For node connected.
  // All peer anchors needed later were captured before this point.
  NodeUtils::UnlinkAll(*node);

  for_info.for_node = node;
  for_info.start = start;
  for_info.limit = limit;
  for_info.delta = delta;
  for_info.body_name = for_body_name;
  for_info.for_body = for_body;
  for_info.data_inputs = std::move(data_inputs);
  for_info.data_outputs = std::move(data_outputs);
  for_info.ctrl_inputs = std::move(ctrl_inputs);
  for_info.ctrl_outputs = std::move(ctrl_outputs);

  GELOGI("Build for_info for node %s success.", node->GetName().c_str());
  return SUCCESS;
}

///
/// @brief Find input with index for For node
/// @param [in] node
/// @param [in] index
/// @return OutDataAnchorPtr  peer out-anchor feeding input `index`; nullptr if node or the in-anchor is missing
///
OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  if (node == nullptr) {
    GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
    return nullptr;
  }

  InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  if (in_data_anchor == nullptr) {
    GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
    return nullptr;
  }

  return in_data_anchor->GetPeerOutAnchor();
}

///
/// @brief Find inputs / outputs for for node
/// @param [in] node
/// @param [out] data_inputs   peer out-anchor of every data input from FOR_DATA_INPUT on (entries may be nullptr)
/// @param [out] data_outputs  for each data output, the list of peer in-anchors
/// @param [out] ctrl_inputs
/// @param [out] ctrl_outputs
/// @return Status
///
Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
                                     std::vector<std::vector<InDataAnchorPtr>> &data_outputs,
                                     std::vector<OutControlAnchorPtr> &ctrl_inputs,
                                     std::vector<InControlAnchorPtr> &ctrl_outputs) {
  GE_CHECK_NOTNULL(node);

  uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
    InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
    GE_CHECK_NOTNULL(in_data_anchor);
    data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  }

  for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
    GE_CHECK_NOTNULL(out_data_anchor);
    std::vector<InDataAnchorPtr> peer_in_data_anchors;
    for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
      peer_in_data_anchors.emplace_back(peer_in_data_anchor);
    }
    data_outputs.emplace_back(peer_in_data_anchors);
  }

  InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  GE_CHECK_NOTNULL(in_ctrl_anchor);
  for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
    ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  }

  OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(out_ctrl_anchor);
  for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
    ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  }

  return SUCCESS;
}

///
/// @brief Transfer while_info from for_info: creates the loop counter (Const 0 -> Identity),
///        the range / |delta| inputs and the While node itself.
/// @param [in] graph
/// @param [in] for_info
/// @param [out] while_info
/// @return Status
///
Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  GE_CHECK_NOTNULL(graph);
  GE_CHECK_NOTNULL(for_info.for_node);
  std::string for_name = for_info.for_node->GetName();
  GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());

  std::string i_name = for_name + "_i";
  NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  if (i_node == nullptr) {
    GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
    return FAILED;
  }
  AddRePassNode(i_node);

  std::string identity_name = i_name + "_Identity";
  NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  // Const node has and only has one output, Identity node has and only has one input
  if ((identity_node == nullptr) ||
      (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
    GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
    return FAILED;
  }
  AddRePassNode(identity_node);

  // Identity node has and only has one output
  OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  if (i_input == nullptr) {
    GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
    return FAILED;
  }

  OutDataAnchorPtr range_input = nullptr;
  OutDataAnchorPtr abs_delta_input = nullptr;
  if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
    GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
    return FAILED;
  }

  BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);

  if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
    GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
    return FAILED;
  }

  GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.", for_name.c_str(),
         while_info.while_node->GetName().c_str());
  return SUCCESS;
}

///
/// @brief Create const op_desc holding a scalar DT_INT32 weight
/// @param [in] name
/// @param [in] value
/// @return OpDescPtr  nullptr on failure
///
OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  if (const_op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
    return nullptr;
  }

  GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  if (const_value == nullptr) {
    GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
    return nullptr;
  }

  if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
    GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
    return nullptr;
  }

  if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
    return nullptr;
  }

  return const_op_desc;
}

///
/// @brief Create loop node: builds Abs(delta) and Abs(Sub(limit, start)) in the owner graph
/// @param [in] graph
/// @param [in] for_info
/// @param [out] range_input      output anchor of Abs(limit - start)
/// @param [out] abs_delta_input  output anchor of Abs(delta)
/// @return Status
///
Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
                                OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  GE_CHECK_NOTNULL(for_info.for_node);
  std::string for_name = for_info.for_node->GetName();
  GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());

  OutDataAnchorPtr start = for_info.start;
  OutDataAnchorPtr limit = for_info.limit;
  OutDataAnchorPtr delta = for_info.delta;
  GE_CHECK_NOTNULL(start);
  GE_CHECK_NOTNULL(limit);
  GE_CHECK_NOTNULL(delta);

  std::string sub_name_0 = for_name + "_Sub_0";
  std::string abs_name_0 = for_name + "_Abs_0";
  std::string abs_name_1 = for_name + "_Abs_1";

  // i * |delta| < |limit-start|
  PartialGraphBuilder graph_builder;
  graph_builder.SetOwnerGraph(graph)
      .AddExistNode(for_info.start->GetOwnerNode())
      .AddExistNode(for_info.limit->GetOwnerNode())
      .AddExistNode(for_info.delta->GetOwnerNode())
      .AddNode(CreateOpDesc(sub_name_0, SUB, false))
      .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
      .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
      .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
      .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
      .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
      .AddDataLink(sub_name_0, 0, abs_name_1, 0);

  graphStatus error_code = GRAPH_SUCCESS;
  std::string error_msg;
  if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
    GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
    return FAILED;
  }

  // Add repass_nodes
  for (auto &node : graph_builder.GetAllNodes()) {
    AddRePassNode(node);
  }

  NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
    GELOGE(FAILED, "Create loop node failed: node is NULL.");
    return FAILED;
  }

  GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  // abs_node has and only has one output
  abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  range_input = loop_count_node->GetOutDataAnchor(0);
  return SUCCESS;
}

///
/// @brief Create op_desc
/// @param [in] name
/// @param [in] type
/// @param [in] io_equal_flag  true: one input "x" / one output "y"; false: inputs "x1","x2" / output "y"
/// @return OpDescPtr
///
OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  OpDescBuilder op_desc_builder(name, type);
  if (io_equal_flag) {
    op_desc_builder.AddInput("x")
                   .AddOutput("y");
  } else {
    op_desc_builder.AddInput("x1")
                   .AddInput("x2")
                   .AddOutput("y");
  }

  return op_desc_builder.Build();
}

///
/// @brief Build while-info. The While data inputs are ordered
///        i, abs_delta, range, start, delta, followed by the data inputs of the For node.
/// @param [in] for_info
/// @param [in] i_input
/// @param [in] range_input
/// @param [in] abs_delta_input
/// @param [out] while_info
/// @return void
///
void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
                             const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
                             WhileInfo &while_info) {
  while_info.i = i_input;
  while_info.abs_delta = abs_delta_input;
  while_info.range = range_input;
  while_info.start = for_info.start;
  while_info.delta = for_info.delta;
  while_info.for_body_name = for_info.body_name;
  while_info.for_body = for_info.for_body;

  // The order below must stay in sync with the kWhile*InputIndex constants.
  while_info.data_inputs.emplace_back(while_info.i);
  while_info.data_inputs.emplace_back(while_info.abs_delta);
  while_info.data_inputs.emplace_back(while_info.range);
  while_info.data_inputs.emplace_back(while_info.start);
  while_info.data_inputs.emplace_back(while_info.delta);
  for (auto &item : for_info.data_inputs) {
    while_info.data_inputs.emplace_back(item);
  }

  for (auto &item : for_info.data_outputs) {
    while_info.data_outputs.emplace_back(item);
  }
  for (auto &item : for_info.ctrl_inputs) {
    while_info.ctrl_inputs.emplace_back(item);
  }
  for (auto &item : for_info.ctrl_outputs) {
    while_info.ctrl_outputs.emplace_back(item);
  }
}

///
/// @brief Insert while_node
/// @param [in] graph
/// @param [in] name
/// @param [in&out] while_info
/// @return Status
///
Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  GE_CHECK_NOTNULL(graph);
  GELOGD("Begin to create while node, name:%s.", name.c_str());

  // The While node gets as many outputs as inputs.
  size_t arg_num = while_info.data_inputs.size();
  OpDescBuilder op_desc_builder(name, WHILE);
  OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
    return FAILED;
  }

  NodePtr while_node = graph->AddNode(op_desc);
  if (while_node == nullptr) {
    GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
    return FAILED;
  }
  AddRePassNode(while_node);

  while_info.while_node = while_node;
  if (BuildWhileLink(while_info) != SUCCESS) {
    GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
    return FAILED;
  }

  GELOGD("Create while node succ, name:%s.", name.c_str());
  return SUCCESS;
}

///
/// @brief Build while link-edge: reconnect the data / control peers recorded in while_info to the While node
/// @param [in] while_info
/// @return Status
///
Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  NodePtr while_node = while_info.while_node;
  GE_CHECK_NOTNULL(while_node);

  size_t input_num = while_info.data_inputs.size();
  for (size_t i = 0; i < input_num; i++) {
    InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
    GE_CHECK_NOTNULL(in_data_anchor);
    OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
    if (peer_out_anchor == nullptr) {
      // Input was not connected on the For node: leave the While input unconnected as well.
      continue;
    }
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
                            "Add data-edge %s:%d->%s:%zu failed.",
                            peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
                            while_node->GetName().c_str(), i);
  }

  size_t output_num = while_info.data_outputs.size();
  for (size_t i = 0; i < output_num; i++) {
    OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
    GE_CHECK_NOTNULL(out_data_anchor);
    for (auto &peer_in_anchor : while_info.data_outputs[i]) {
      // Checked up front: the failure message below dereferences the anchor.
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
                              "Add data-edge %s:%zu->%s:%d failed.",
                              while_node->GetName().c_str(), i + kWhileOutputIndex,
                              peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
    }
  }

  InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  GE_CHECK_NOTNULL(in_ctrl_anchor);
  for (auto &peer_out_anchor : while_info.ctrl_inputs) {
    GE_CHECK_NOTNULL(peer_out_anchor);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
                            "Add ctrl-edge %s->%s failed.",
                            peer_out_anchor->GetOwnerNode()->GetName().c_str(),
                            in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  }

  OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(out_ctrl_anchor);
  for (auto &peer_in_anchor : while_info.ctrl_outputs) {
    GE_CHECK_NOTNULL(peer_in_anchor);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
                            "Add ctrl-edge %s->%s failed.",
                            out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
                            peer_in_anchor->GetOwnerNode()->GetName().c_str());
  }

  return SUCCESS;
}

///
/// @brief Build cond_graph for while_node: Less(Mul(i, abs_delta), range)
/// @param [in&out] while_info
/// @return ComputeGraphPtr  nullptr on failure
///
ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  if ((while_info.while_node == nullptr) || (while_info.while_node->GetOpDesc() == nullptr)) {
    GELOGE(FAILED, "Build cond_graph failed: while_node or its op_desc is NULL.");
    return nullptr;
  }

  std::string cond_name = while_info.for_body_name + "_Cond";
  CompleteGraphBuilder graph_builder(cond_name);

  // Add parent node
  graph_builder.SetParentNode(while_info.while_node);

  // Add Node
  const std::string mul_name = "Mul";
  graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  const std::string less_name = "Less";
  graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));

  // Set Input
  graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
               .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
               .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
               .SetUselessInput(kWhileStartInputIndex)
               .SetUselessInput(kWhileDeltaInputIndex);
  size_t input_num = while_info.data_inputs.size();
  for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
    graph_builder.SetUselessInput(i);
  }

  // Add Output
  graph_builder.AddOutput(less_name, 0);

  // Add Edges
  graph_builder.AddDataLink(mul_name, 0, less_name, 0);

  // Add Input-Mapping
  std::map<uint32_t, uint32_t> input_mapping;
  for (size_t i = 0; i < input_num; i++) {
    input_mapping[i] = i;
  }
  graph_builder.SetInputMapping(input_mapping);

  graphStatus error_code = GRAPH_SUCCESS;
  std::string error_msg;
  ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  if (cond_graph == nullptr) {
    GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
    return nullptr;
  }

  // Register the cond graph as the next subgraph instance of the While node.
  size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  while_info.while_cond = cond_graph;
  return cond_graph;
}

///
/// @brief Build body_graph for while_node:
///          Add_0 = i + Const(1)        -> output 0 (next i)
///          Mul   = i * delta
///          Add_1 = start + Mul         -> loop-variable input of the for-body call node
///        While inputs 1..4 are passed through as outputs ("Data_<k>"), followed by the call-node outputs.
/// @param [in&out] while_info
/// @return ComputeGraphPtr  nullptr on failure
///
ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  if ((while_info.while_node == nullptr) || (while_info.while_node->GetOpDesc() == nullptr)) {
    GELOGE(FAILED, "Build body_graph failed: while_node or its op_desc is NULL.");
    return nullptr;
  }

  std::string body_name = while_info.for_body_name + "_Body";
  CompleteGraphBuilder graph_builder(body_name);

  // Add parent node
  graph_builder.SetParentNode(while_info.while_node);

  // Add calculation nodes
  std::string const_name = "Const";
  std::string add_name_0 = "Add_0";
  std::string mul_name = "Mul";
  std::string add_name_1 = "Add_1";
  graph_builder.AddNode(CreateConstDesc(const_name, 1))
               .AddNode(CreateOpDesc(add_name_0, ADD, false))
               .AddNode(CreateOpDesc(mul_name, MUL, false))
               .AddNode(CreateOpDesc(add_name_1, ADD, false));

  // Add Subgraph node
  auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  std::string sub_graph_node_name = while_info.for_body_name;
  uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));

  // Set Input
  graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
               .SetUselessInput(kWhileAbsDeltaInputIndex)
               .SetUselessInput(kWhileRangeInputIndex)
               .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
               .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
    graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  }

  // Add Outputs
  graph_builder.AddOutput(add_name_0, 0);
  for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
    graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  }
  for (uint32_t i = 0; i < sub_graph_output_num; i++) {
    graph_builder.AddOutput(sub_graph_node_name, i);
  }

  // Add Edges
  graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
               .AddDataLink(mul_name, 0, add_name_1, 1)
               .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);

  // Add Input-Mapping
  std::map<uint32_t, uint32_t> input_mapping;
  for (size_t i = 0; i < input_num; i++) {
    input_mapping[i] = i;
  }
  graph_builder.SetInputMapping(input_mapping);

  // Add outputMapping
  std::map<uint32_t, uint32_t> output_mapping;
  for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
    output_mapping[i] = i;
  }
  graph_builder.SetOutputMapping(output_mapping);

  graphStatus error_code = GRAPH_SUCCESS;
  std::string error_msg;
  ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  if (body_graph == nullptr) {
    GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
    return nullptr;
  }

  NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  if (sub_graph_node == nullptr) {
    GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
    return nullptr;
  }
  while_info.sub_graph_node = sub_graph_node;

  // Register the body graph as the next subgraph instance of the While node.
  size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  while_info.while_body = body_graph;
  return body_graph;
}

///
/// @brief Create op_desc for subgraph node (a PARTITIONEDCALL whose subgraph instance is named `name`)
/// @param [in] name
/// @param [in] input_num
/// @param [in] output_num
/// @return OpDescPtr  nullptr on failure
///
OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  op_desc_builder.AddDynamicInput("args", input_num)
                 .AddDynamicOutput("output", output_num);

  OpDescPtr op_desc = op_desc_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
    return nullptr;
  }

  size_t index = op_desc->GetSubgraphInstanceNames().size();
  op_desc->AddSubgraphName("f");
  op_desc->SetSubgraphInstanceName(index, name);
  return op_desc;
}

///
/// @brief Update InputMapping for for-body-graph and re-parent it under the call node in the while-body graph
/// @param [in] while_info
/// @return Status
///
Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  ComputeGraphPtr for_body = while_info.for_body;
  GE_CHECK_NOTNULL(for_body);

  // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  // start keeps its index, limit / delta get no entry, every later input moves down by kIDiffValue.
  std::map<uint32_t, uint32_t> input_mapping;
  size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  for (size_t i = 0; i < input_num; i++) {
    if (i == FOR_START_INPUT) {
      input_mapping[i] = i;
    } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
      continue;
    } else {
      input_mapping[i] = i - kIDiffValue;
    }
  }
  for_body->UpdateInputMapping(input_mapping);
  for_body->SetParentNode(while_info.sub_graph_node);
  for_body->SetParentGraph(while_info.while_body);

  return SUCCESS;
}
}  // namespace ge