You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/switch_data_edges_bypass.cc

225 lines
8.9 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "switch_data_edges_bypass.h"
#include <atomic>
#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "common/util.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
bool IsSwitchInWhileLoop(const NodePtr &node) {
auto pred_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
if (pred_anchor == nullptr) {
GELOGW("The switch node %s does not have a pred in anchor, the node may be invalid", node->GetName().c_str());
return true;
}
auto pred_node_anchor = pred_anchor->GetPeerOutAnchor();
if (pred_node_anchor == nullptr) {
GELOGW("The switch node %s does not have a pred in node, the graph may be invalid", node->GetName().c_str());
return true;
}
auto pred_node = pred_node_anchor->GetOwnerNode();
if (pred_node == nullptr) {
GELOGW("The switch node %s does not have a pred in node, the pred-anchor may be invalid", node->GetName().c_str());
return true;
}
if (pred_node->GetType() == LOOPCOND) {
GELOGD("The switch node %s is in a while loop, skip the bypass process", node->GetName().c_str());
return true;
}
return false;
}
std::vector<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesByIndex(const NodePtr &node, int index) {
auto out_anchor = node->GetOutDataAnchor(index);
if (out_anchor == nullptr) {
GELOGE(PARAM_INVALID, "Failed to get out data nodes of index %d from node %s, the anchor does not exists", index,
node->GetName().c_str());
return {};
}
std::vector<std::pair<NodePtr, InDataAnchorPtr>> nodes_and_anchors;
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
auto out_node = in_anchor->GetOwnerNode();
if (out_node != nullptr) {
nodes_and_anchors.emplace_back(out_node, in_anchor);
}
}
return nodes_and_anchors;
}
std::pair<NodePtr, OutDataAnchorPtr> GetInDataNodeByIndex(const NodePtr &node, int index) {
auto in_anchor = node->GetInDataAnchor(index);
if (in_anchor == nullptr) {
GELOGD("Failed to get in data node of index %d from node %s, the anchor does not exists", index,
node->GetName().c_str());
return {};
}
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
GELOGD("Failed to get in data node of index %d from node %s, the data input does not exists", index,
node->GetName().c_str());
return {};
}
return {out_anchor->GetOwnerNode(), out_anchor};
}
NodePtr AddIdentityAfterNode(const NodePtr &node, int index) {
static std::atomic_long atomic_identity_counter(0);
auto identity_counter = atomic_identity_counter.fetch_add(1);
auto node_desc = node->GetOpDesc();
if (node_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to add identity after node %s index %d, the op desc is null",
node->GetName().c_str(), index);
return nullptr;
}
auto tensor = node_desc->GetOutputDescPtr(index);
if (tensor == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to find the tensor by index %d from node %s, can not add the identity node", index,
node->GetName().c_str());
return nullptr;
}
auto anchor = node->GetOutDataAnchor(index);
if (anchor == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to add identity after node %s index %d, the out anchor does not exists",
node->GetName().c_str(), index);
return nullptr;
}
auto identity_opdesc =
MakeShared<OpDesc>("SwitchDataEdgesByPass_Identity_" + std::to_string(identity_counter), IDENTITY);
if (identity_opdesc == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to add identity after node %s index %d", node->GetName().c_str(), index);
return nullptr;
}
auto ret1 = identity_opdesc->AddInputDesc("x", *tensor);
auto ret2 = identity_opdesc->AddOutputDesc("y", *tensor);
auto identity = node->GetOwnerComputeGraph()->AddNode(identity_opdesc);
if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || identity == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to add identity after node %s index %d", node->GetName().c_str(), index);
return nullptr;
}
(void)anchor->LinkTo(identity->GetInDataAnchor(0));
return identity;
}
NodePtr AddMemcpyBeforeNode(const NodePtr &node, int index) {
static std::atomic_long atomic_counter(0);
auto counter = atomic_counter.fetch_add(1);
auto node_desc = node->GetOpDesc();
if (node_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to add memcpy before node %s index %d, null op desc", node->GetName().c_str(),
index);
return nullptr;
}
auto tensor = node_desc->GetInputDescPtr(index);
if (tensor == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to find the tensor by index %d from node %s, can not add the memcpy node", index,
node->GetName().c_str());
return nullptr;
}
auto anchor = node->GetInDataAnchor(index);
if (anchor == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to add memcpy before node %s index %d, the in anchor does not exists",
node->GetName().c_str(), index);
return nullptr;
}
auto memcpy_opdesc = MakeShared<OpDesc>("SwitchDataEdgesByPass_Memcpy_" + std::to_string(counter), MEMCPYASYNC);
if (memcpy_opdesc == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to add memcpy before node %s index %d", node->GetName().c_str(), index);
return nullptr;
}
auto ret1 = memcpy_opdesc->AddInputDesc(*tensor);
auto ret2 = memcpy_opdesc->AddOutputDesc(*tensor);
auto memcpy_node = node->GetOwnerComputeGraph()->AddNode(memcpy_opdesc);
if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || memcpy_node == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to add memcpy before node %s index %d", node->GetName().c_str(), index);
return nullptr;
}
(void)memcpy_node->GetOutDataAnchor(0)->LinkTo(anchor);
return memcpy_node;
}
Status BypassSwitchOut(const NodePtr &switch_node, int out_index) {
auto nodes_and_anchors = GetOutDataNodesByIndex(switch_node, out_index);
if (nodes_and_anchors.empty()) {
GELOGD("The switch node %s does not has out branch %d, skip the bypass process", switch_node->GetName().c_str(),
out_index);
return SUCCESS;
}
auto data_node_and_anchor = GetInDataNodeByIndex(switch_node, SWITCH_DATA_INPUT);
if (data_node_and_anchor.first == nullptr) {
GELOGW("Can not bypass switch node %s, the node does not has a data input", switch_node->GetName().c_str());
return SUCCESS;
}
auto identity = AddIdentityAfterNode(switch_node, out_index);
GE_CHECK_NOTNULL(identity);
std::set<Node *> connected_nodes;
for (const auto &node_and_anchor : nodes_and_anchors) {
auto head_anchor = node_and_anchor.second;
head_anchor->UnlinkAll();
auto head_node = node_and_anchor.first;
auto head_node_type = NodeUtils::GetNodeType(*head_node);
if (head_node_type == MERGE || head_node_type == REFMERGE) {
// if the switch connect to the merge directly, insert memcpy before merge
auto memcpy_node = AddMemcpyBeforeNode(head_node, head_anchor->GetIdx());
GE_CHECK_NOTNULL(memcpy_node);
GELOGD("Add memcpy %s before merge node %s", memcpy_node->GetName().c_str(), head_node->GetName().c_str());
head_node = memcpy_node;
head_anchor = memcpy_node->GetInDataAnchor(0);
}
(void)data_node_and_anchor.second->LinkTo(head_anchor);
if (connected_nodes.insert(head_node.get()).second) {
(void)identity->GetOutControlAnchor()->LinkTo(head_node->GetInControlAnchor());
}
}
GELOGI("Bypass switch %s out index %d success", switch_node->GetName().c_str(), out_index);
return SUCCESS;
}
} // namespace
Status SwitchDataEdgesBypass::Run(ComputeGraphPtr graph) {
for (const auto &node : graph->GetDirectNode()) {
auto ret = BypassSwitch(node);
GE_CHK_STATUS_RET(ret, "By pass switch node %s failed", node->GetName().c_str())
}
return SUCCESS;
}
Status SwitchDataEdgesBypass::BypassSwitch(const NodePtr &node) {
auto node_type = NodeUtils::GetNodeType(*node);
if ((node_type != SWITCH) && (node_type != REFSWITCH)) {
return SUCCESS;
}
if (IsSwitchInWhileLoop(node)) {
return SUCCESS;
}
auto ret = BypassSwitchOut(node, SWITCH_FALSE_OUTPUT);
GE_CHK_STATUS_RET(ret, "By pass switch node %s false output failed", node->GetName().c_str())
ret = BypassSwitchOut(node, SWITCH_TRUE_OUTPUT);
GE_CHK_STATUS_RET(ret, "By pass switch node %s true output failed", node->GetName().c_str())
return SUCCESS;
}
} // namespace ge