/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "graph/passes/subgraph_pass.h" #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" namespace ge { /** * @ingroup ge * @brief Subgraph optimizer. * @param [in] graph: Input ComputeGraph * @return: 0 for success / others for fail */ Status SubgraphPass::Run(ComputeGraphPtr graph) { const bool is_sub_graph = graph->GetParentNode() != nullptr; for (const NodePtr &node : graph->GetDirectNode()) { if (is_sub_graph && (node->GetType() == DATA)) { if (SubgraphInputNode(graph, node) != SUCCESS) { GELOGE(FAILED, "Handle input %s of subgraph failed.", node->GetName().c_str()); return FAILED; } continue; } // NetOutput in subgraph if (is_sub_graph && (node->GetType() == NETOUTPUT)) { if (SubgraphOutputNode(graph, node) != SUCCESS) { GELOGE(FAILED, "Handle output %s of subgraph failed.", node->GetName().c_str()); return FAILED; } continue; } if (kWhileOpTypes.count(node->GetType()) > 0) { // Input->While and Input link to other nodes if (WhileInputNodes(graph, node) != SUCCESS) { GELOGE(FAILED, "Handle input of while_body failed, while:%s.", node->GetName().c_str()); return FAILED; } // body subgraph of While op if (WhileBodySubgraph(graph, node) != SUCCESS) { GELOGE(FAILED, "Handle while_body failed, while:%s.", node->GetName().c_str()); return FAILED; } continue; } } return SUCCESS; } /** * @ingroup ge * @brief Check Subgraph Input node * @param [in] graph: ComputeGraph. * @param [in] node: Data node in Subgraph. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { GELOGD("Handle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); // Data has and only has one output bool input_continues_required_flag = false; OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); std::vector in_anchors; for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { input_continues_required_flag = input_continues_required_flag || IsInputContinuesRequired(peer_in_anchor->GetOwnerNode()); in_anchors.emplace_back(peer_in_anchor); } // Data->InputContinuesRequiredOp in subgraph need memcpy. if (input_continues_required_flag) { GELOGD("Data %s output_node required continues input.", node->GetName().c_str()); std::string name = node->GetName() + "_output_0_Memcpy"; if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy after %s failed.", node->GetName().c_str()); return FAILED; } } uint32_t parent_index = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); return FAILED; } // Subgraph Data Node, check for constant input. std::string const_type; if (!NodeUtils::GetConstOpType(node, const_type)) { return SUCCESS; } const NodePtr &parent_node = graph->GetParentNode(); if (kWhileOpTypes.count(parent_node->GetType()) != 0) { // Constant input to While need memcpy. const ComputeGraphPtr &parent_graph = parent_node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(parent_graph); const InDataAnchorPtr &in_data_anchor = parent_node->GetInDataAnchor(parent_index); GE_CHECK_NOTNULL(in_data_anchor); const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); GELOGD("Constant input %s links to While %s.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); std::string name = parent_node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(parent_graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); return FAILED; } } return SUCCESS; } /** * @ingroup ge * @brief Check Subgraph Output node * @param [in] graph: ComputeGraph. * @param [in] node: NetOutput node in Subgraph. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const NodePtr &node) { for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); NodePtr in_node = peer_out_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(in_node); // Need insert memcpy // 1. Const->NetOutput in subgraph // 2. AtomicOp->NetOutput in subgraph // 3. OutputContinuesRequiredOp->NetOutput in subgraph // 4. Data->NetOutput in subgraph but parent_node is not while std::string op_type; bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)); if (insert_flag) { GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } } } return SUCCESS; } /** * @ingroup ge * @brief Check is Input->While and Input link to other nodes * @param [in] graph: ComputeGraph. * @param [in] node: While node. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr &node) { for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); NodePtr in_node = peer_out_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(in_node); if (in_node->GetType() == VARIABLE || in_node->GetType() == VARHANDLEOP || in_node->GetType() == VARIABLEV2) { continue; } // Input->While and Input link to other nodes need insert memcpy if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { GELOGD("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } } } return SUCCESS; } /** * @ingroup ge * @brief Check body subgraph of While op * @param [in] graph: ComputeGraph. * @param [in] node: While node. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { // index of body_subgraph is 1 ComputeGraphPtr while_body = NodeUtils::GetSubgraph(*node, 1); if (while_body == nullptr) { GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); return FAILED; } if (GraphUtils::IsUnknownShapeGraph(while_body)) { GELOGI("Unknown shape while_body graph %s no need to insert memcpy.", while_body->GetName().c_str()); return SUCCESS; } std::vector data_nodes; std::set bypass_index; NodePtr output_node = nullptr; for (const auto &n : while_body->GetDirectNode()) { const std::string &type = n->GetType(); if (type == DATA) { if (CheckInsertInputMemcpy(n, bypass_index)) { data_nodes.emplace_back(n); } } else if (type == NETOUTPUT) { if (output_node == nullptr) { output_node = n; } else { GELOGE(FAILED, "while_body %s exists multi NetOutput nodes.", while_body->GetName().c_str()); return FAILED; } } } if (output_node == nullptr) { GELOGE(FAILED, "while_body %s has no output.", while_body->GetName().c_str()); return FAILED; } if ((InsertInputMemcpy(while_body, data_nodes) != SUCCESS) || (InsertOutputMemcpy(while_body, output_node, bypass_index) != SUCCESS)) { GELOGE(FAILED, "Insert memcpy node in while_body %s failed.", while_body->GetName().c_str()); return FAILED; } return SUCCESS; } /** * @ingroup ge * @brief Insert input memcpy node in while_body * @param [in] graph: while_body * @param [in] data_nodes: data_nodes * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes) { if (data_nodes.empty()) { GELOGD("No need to insert input memcpy node in while_body %s.", graph->GetName().c_str()); return SUCCESS; } std::string in_name = graph->GetName() + "_input_Memcpy"; OpDescBuilder in_builder(in_name, IDENTITY); for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output in_builder.AddInput("x" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)); } GELOGD("Insert memcpy after data_nodes of while_body %s.", graph->GetName().c_str()); NodePtr in_memcpy = graph->AddNode(in_builder.Build()); GE_CHECK_NOTNULL(in_memcpy); for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output OutDataAnchorPtr out_data_anchor = data_nodes[i]->GetOutDataAnchor(0); std::vector in_anchors; for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { in_anchors.emplace_back(peer_in_anchor); } if (InsertNodeBetween(out_data_anchor, in_anchors, in_memcpy, i, i) != SUCCESS) { GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", in_name.c_str(), graph->GetName().c_str()); return FAILED; } } return SUCCESS; } /** * @ingroup ge * @brief Insert output memcpy node in while_body * @param [in] graph: while_body * @param [in] output_node: NetOutput * @param [in] bypass_index * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, const std::set &bypass_index) { if (output_node->GetAllInDataAnchorsSize() == bypass_index.size()) { GELOGD("No need to insert output memcpy node in while_body %s, output_size=%zu, bypass_num=%zu.", graph->GetName().c_str(), output_node->GetAllInDataAnchorsSize(), bypass_index.size()); return SUCCESS; } std::string out_name = graph->GetName() + "_output_Memcpy"; OpDescBuilder out_builder(out_name, IDENTITY); for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { if (bypass_index.count(i) == 0) { out_builder.AddInput("x" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)) .AddOutput("y" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)); } } GELOGD("Insert memcpy before NetOutput of while_body %s.", graph->GetName().c_str()); NodePtr out_memcpy = graph->AddNode(out_builder.Build()); GE_CHECK_NOTNULL(out_memcpy); size_t cnt = 0; for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { if (bypass_index.count(i) == 0) { InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(i); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (InsertNodeBetween(peer_out_anchor, {in_data_anchor}, out_memcpy, cnt, cnt) != SUCCESS) { GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", out_name.c_str(), graph->GetName().c_str()); return FAILED; } cnt++; } } return SUCCESS; } /** * @ingroup ge * @brief Check is data->netoutput without change in while body * @param [in] node: data node * @param [out] bypass_index * @return: false for data->netoutput without change in while body / for true for others */ bool SubgraphPass::CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index) { uint32_t input_index = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { return true; } // Data node has and only has one output OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); if ((out_data_anchor == nullptr) || (out_data_anchor->GetPeerInDataAnchors().size() != 1)) { return true; } InDataAnchorPtr peer_in_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); if (peer_in_anchor->GetOwnerNode()->GetType() != NETOUTPUT) { return true; } OpDescPtr op_desc = peer_in_anchor->GetOwnerNode()->GetOpDesc(); uint32_t output_index = 0; if ((op_desc == nullptr) || !AttrUtils::GetInt(op_desc->GetInputDesc(peer_in_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, output_index)) { return true; } if (input_index != output_index) { return true; } bypass_index.insert(peer_in_anchor->GetIdx()); return false; } /** * @ingroup ge * @brief Check is AtomicOp->NetOutput * @param [in] node * @param [in] out_index * @return: true for AtomicOp->NetOutput / false for others */ bool SubgraphPass::IsAtomicRequired(const NodePtr &node, int64_t out_index) { auto op_desc = node->GetOpDesc(); if (op_desc != nullptr) { bool is_atomic = false; (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); if (is_atomic) { std::vector atomic_output_index; // If GetListInt fail, atomic_output_index is empty. (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); for (int64_t ind : atomic_output_index) { if (ind == out_index) { return true; } } } } return false; } /** * @ingroup ge * @brief Check is OutputContinuesRequiredOp->NetOutput * @param [in] node * @return: true for OutputContinuesRequiredOp->NetOutput / false for others */ bool SubgraphPass::IsOutputContinuesRequired(const NodePtr &node) { OpDescPtr op_desc = node->GetOpDesc(); if (op_desc != nullptr) { bool continuous_output_flag = false; (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_OUTPUT, continuous_output_flag); bool no_padding_continuous_output_flag = false; (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT, no_padding_continuous_output_flag); return continuous_output_flag || no_padding_continuous_output_flag; } return false; } /** * @ingroup ge * @brief Check is InputContinuesRequiredOp->NetOutput * @param [in] node * @return: true for InputContinuesRequiredOp->NetOutput / false for others */ bool SubgraphPass::IsInputContinuesRequired(const NodePtr &node) { OpDescPtr op_desc = node->GetOpDesc(); if (op_desc != nullptr) { bool continuous_input_flag = false; (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, continuous_input_flag); bool no_padding_continuous_input_flag = false; (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, no_padding_continuous_input_flag); return continuous_input_flag || no_padding_continuous_input_flag; } return false; } /** * @ingroup ge * @brief Insert memcpy node * @param [in] graph * @param [in] out_anchor * @param [in] in_anchors * @param [in] name * @return: 0 for success / others for fail */ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const std::vector &in_anchors, const std::string &name) { GE_CHECK_NOTNULL(out_anchor); NodePtr in_node = out_anchor->GetOwnerNode(); OpDescBuilder op_desc_builder(name, IDENTITY); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } return SUCCESS; } /// /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst /// @param [in] src /// @param [in] dsts /// @param [in] insert_node /// @param [in] input_index /// @param [in] output_index /// @return Status /// Status SubgraphPass::InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { if (GraphUtils::AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { GELOGE(FAILED, "Add data_edge %s:%d->%s:%u failed.", src->GetOwnerNode()->GetName().c_str(), src->GetIdx(), insert_node->GetName().c_str(), input_index); return FAILED; } for (const auto &dst : dsts) { GELOGD("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); if ((GraphUtils::RemoveEdge(src, dst) != GRAPH_SUCCESS) || (GraphUtils::AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { GELOGE(FAILED, "Replace data_edge %s:%d->%s:%d by %s:%u->%s:%d failed.", src->GetOwnerNode()->GetName().c_str(), src->GetIdx(), dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx(), insert_node->GetName().c_str(), output_index, dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx()); return FAILED; } } return SUCCESS; } } // namespace ge