/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "switch_data_edges_bypass.h"

#include <atomic>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "common/util.h"
#include "graph/utils/node_utils.h"

namespace ge {
namespace {
/**
 * Decides whether the bypass must be skipped for `node`.
 *
 * Returns true when the predicate input of the switch is produced by a LOOPCOND node. It also returns true
 * (with a warning) when the predicate in-anchor, its peer out-anchor, or the producing node cannot be resolved,
 * so that malformed switches are left untouched rather than rewired.
 */
bool IsSwitchInWhileLoop(const NodePtr &node) {
  auto pred_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
  if (pred_anchor == nullptr) {
    GELOGW("The switch node %s does not have a pred in anchor, the node may be invalid", node->GetName().c_str());
    return true;
  }
  auto pred_node_anchor = pred_anchor->GetPeerOutAnchor();
  if (pred_node_anchor == nullptr) {
    GELOGW("The switch node %s does not have a pred in node, the graph may be invalid", node->GetName().c_str());
    return true;
  }
  auto pred_node = pred_node_anchor->GetOwnerNode();
  if (pred_node == nullptr) {
    GELOGW("The switch node %s does not have a pred in node, the pred-anchor may be invalid",
           node->GetName().c_str());
    return true;
  }
  if (pred_node->GetType() == LOOPCOND) {
    GELOGD("The switch node %s is in a while loop, skip the bypass process", node->GetName().c_str());
    return true;
  }
  return false;
}

/**
 * Collects every (consumer node, consumer in-data-anchor) pair connected to output `index` of `node`.
 *
 * Consumers whose owner node is null are skipped. Returns an empty vector (and logs an error) when the
 * output anchor itself does not exist.
 */
std::vector<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesByIndex(const NodePtr &node, int index) {
  auto out_anchor = node->GetOutDataAnchor(index);
  if (out_anchor == nullptr) {
    GELOGE(PARAM_INVALID, "Failed to get out data nodes of index %d from node %s, the anchor does not exists", index,
           node->GetName().c_str());
    return {};
  }
  std::vector<std::pair<NodePtr, InDataAnchorPtr>> nodes_and_anchors;
  for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
    auto out_node = in_anchor->GetOwnerNode();
    if (out_node != nullptr) {
      nodes_and_anchors.emplace_back(out_node, in_anchor);
    }
  }
  return nodes_and_anchors;
}

/**
 * Returns the (producer node, producer out-data-anchor) feeding input `index` of `node`.
 *
 * Returns a pair of nulls when the in-anchor does not exist or is not connected.
 */
std::pair<NodePtr, OutDataAnchorPtr> GetInDataNodeByIndex(const NodePtr &node, int index) {
  auto in_anchor = node->GetInDataAnchor(index);
  if (in_anchor == nullptr) {
    GELOGD("Failed to get in data node of index %d from node %s, the anchor does not exists", index,
           node->GetName().c_str());
    return {};
  }
  auto out_anchor = in_anchor->GetPeerOutAnchor();
  if (out_anchor == nullptr) {
    GELOGD("Failed to get in data node of index %d from node %s, the data input does not exists", index,
           node->GetName().c_str());
    return {};
  }
  return {out_anchor->GetOwnerNode(), out_anchor};
}

/**
 * Creates an Identity node, adds it to the owner graph of `node`, and links output `index` of `node`
 * to the Identity's input 0.
 *
 * The Identity's input ("x") and output ("y") descs are copies of the output tensor desc at `index`.
 * Node names are made unique with a process-wide atomic counter.
 *
 * @return the new Identity node, or nullptr on any failure (logged).
 */
NodePtr AddIdentityAfterNode(const NodePtr &node, int index) {
  static std::atomic_long atomic_identity_counter(0);
  auto identity_counter = atomic_identity_counter.fetch_add(1);

  auto node_desc = node->GetOpDesc();
  if (node_desc == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add identity after node %s index %d, the op desc is null",
           node->GetName().c_str(), index);
    return nullptr;
  }
  auto tensor = node_desc->GetOutputDescPtr(index);
  if (tensor == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to find the tensor by index %d from node %s, can not add the identity node", index,
           node->GetName().c_str());
    return nullptr;
  }
  auto anchor = node->GetOutDataAnchor(index);
  if (anchor == nullptr) {
    // A missing anchor is a graph-structure problem, not an allocation failure; match AddMemcpyBeforeNode.
    GELOGE(INTERNAL_ERROR, "Failed to add identity after node %s index %d, the out anchor does not exists",
           node->GetName().c_str(), index);
    return nullptr;
  }
  auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add identity after node %s index %d, the node does not belong to a graph",
           node->GetName().c_str(), index);
    return nullptr;
  }

  auto identity_opdesc =
      MakeShared<OpDesc>("SwitchDataEdgesByPass_Identity_" + std::to_string(identity_counter), IDENTITY);
  if (identity_opdesc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to add identity after node %s index %d", node->GetName().c_str(), index);
    return nullptr;
  }
  auto ret1 = identity_opdesc->AddInputDesc("x", *tensor);
  auto ret2 = identity_opdesc->AddOutputDesc("y", *tensor);
  auto identity = owner_graph->AddNode(identity_opdesc);
  if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || identity == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to add identity after node %s index %d", node->GetName().c_str(), index);
    return nullptr;
  }
  (void)anchor->LinkTo(identity->GetInDataAnchor(0));

  return identity;
}

/**
 * Creates a MemcpyAsync node, adds it to the owner graph of `node`, and links its output 0 to input
 * `index` of `node`.
 *
 * The memcpy's input and output descs are copies of the input tensor desc at `index`. The memcpy's own
 * input is left unconnected; the caller is expected to wire it. Node names are made unique with a
 * process-wide atomic counter.
 *
 * @return the new MemcpyAsync node, or nullptr on any failure (logged).
 */
NodePtr AddMemcpyBeforeNode(const NodePtr &node, int index) {
  static std::atomic_long atomic_counter(0);
  auto counter = atomic_counter.fetch_add(1);

  auto node_desc = node->GetOpDesc();
  if (node_desc == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add memcpy before node %s index %d, null op desc", node->GetName().c_str(),
           index);
    return nullptr;
  }
  auto tensor = node_desc->GetInputDescPtr(index);
  if (tensor == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to find the tensor by index %d from node %s, can not add the memcpy node", index,
           node->GetName().c_str());
    return nullptr;
  }
  auto anchor = node->GetInDataAnchor(index);
  if (anchor == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add memcpy before node %s index %d, the in anchor does not exists",
           node->GetName().c_str(), index);
    return nullptr;
  }
  auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add memcpy before node %s index %d, the node does not belong to a graph",
           node->GetName().c_str(), index);
    return nullptr;
  }

  auto memcpy_opdesc = MakeShared<OpDesc>("SwitchDataEdgesByPass_Memcpy_" + std::to_string(counter), MEMCPYASYNC);
  if (memcpy_opdesc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to add memcpy before node %s index %d", node->GetName().c_str(), index);
    return nullptr;
  }
  auto ret1 = memcpy_opdesc->AddInputDesc(*tensor);
  auto ret2 = memcpy_opdesc->AddOutputDesc(*tensor);
  auto memcpy_node = owner_graph->AddNode(memcpy_opdesc);
  if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || memcpy_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to add memcpy before node %s index %d", node->GetName().c_str(), index);
    return nullptr;
  }
  (void)memcpy_node->GetOutDataAnchor(0)->LinkTo(anchor);

  return memcpy_node;
}

/**
 * Rewires output `out_index` of `switch_node` so that its consumers read the switch's data input directly.
 *
 * Steps:
 * - An Identity node is attached to the switch output.
 * - Each former consumer's in-anchor is unlinked, then linked to the producer of the switch's data input.
 * - Each former consumer gets a control edge from the Identity; duplicate consumer nodes are linked only once.
 * - Consumers of type MERGE get a MemcpyAsync inserted in front of them, and the bypassed data (and the control
 *   edge) go to the memcpy instead.
 *
 * Returns SUCCESS without changes when the output has no consumers or the switch has no data input.
 */
Status BypassSwitchOut(const NodePtr &switch_node, int out_index) {
  auto nodes_and_anchors = GetOutDataNodesByIndex(switch_node, out_index);
  if (nodes_and_anchors.empty()) {
    GELOGD("The switch node %s does not has out branch %d, skip the bypass process", switch_node->GetName().c_str(),
           out_index);
    return SUCCESS;
  }

  auto data_node_and_anchor = GetInDataNodeByIndex(switch_node, SWITCH_DATA_INPUT);
  if (data_node_and_anchor.first == nullptr) {
    GELOGW("Can not bypass switch node %s, the node does not has a data input", switch_node->GetName().c_str());
    return SUCCESS;
  }

  auto identity = AddIdentityAfterNode(switch_node, out_index);
  GE_CHECK_NOTNULL(identity);

  // Tracks consumers that already received a control edge from the identity, to avoid duplicate links.
  std::set<Node *> connected_nodes;
  for (const auto &node_and_anchor : nodes_and_anchors) {
    auto head_anchor = node_and_anchor.second;
    head_anchor->UnlinkAll();

    auto head_node = node_and_anchor.first;
    auto head_node_type = NodeUtils::GetNodeType(*head_node);
    if (head_node_type == MERGE) {
      // if the switch connect to the merge directly, insert memcpy before merge
      auto memcpy_node = AddMemcpyBeforeNode(head_node, head_anchor->GetIdx());
      GE_CHECK_NOTNULL(memcpy_node);
      GELOGD("Add memcpy %s before merge node %s", memcpy_node->GetName().c_str(), head_node->GetName().c_str());
      head_node = memcpy_node;
      head_anchor = memcpy_node->GetInDataAnchor(0);
    }
    (void)data_node_and_anchor.second->LinkTo(head_anchor);
    if (connected_nodes.insert(head_node.get()).second) {
      (void)identity->GetOutControlAnchor()->LinkTo(head_node->GetInControlAnchor());
    }
  }
  GELOGI("Bypass switch %s out index %d success", switch_node->GetName().c_str(), out_index);
  return SUCCESS;
}
}  // namespace

/**
 * Applies BypassSwitch to every direct node of `graph`; stops at and returns the first failure.
 */
Status SwitchDataEdgesBypass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  for (const auto &node : graph->GetDirectNode()) {
    auto ret = BypassSwitch(node);
    GE_CHK_STATUS_RET(ret, "By pass switch node %s failed", node->GetName().c_str())
  }
  return SUCCESS;
}

/**
 * Bypasses the false and then the true output of a SWITCH/REFSWITCH node.
 *
 * Nodes of any other type, and switches for which IsSwitchInWhileLoop returns true, are left unchanged.
 */
Status SwitchDataEdgesBypass::BypassSwitch(const NodePtr &node) {
  auto node_type = NodeUtils::GetNodeType(*node);
  if ((node_type != SWITCH) && (node_type != REFSWITCH)) {
    return SUCCESS;
  }
  if (IsSwitchInWhileLoop(node)) {
    return SUCCESS;
  }

  auto ret = BypassSwitchOut(node, SWITCH_FALSE_OUTPUT);
  GE_CHK_STATUS_RET(ret, "By pass switch node %s false output failed", node->GetName().c_str())
  ret = BypassSwitchOut(node, SWITCH_TRUE_OUTPUT);
  GE_CHK_STATUS_RET(ret, "By pass switch node %s true output failed", node->GetName().c_str())

  return SUCCESS;
}
}  // namespace ge