/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "graph/passes/identity_pass.h" #include #include #include "framework/common/debug/ge_log.h" #include "graph/common/omg_util.h" #include "graph/utils/node_utils.h" #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" namespace ge { namespace { /// /// 1. A `Identity` node may after a `Switch` node and has control-dependency-out nodes. /// Or a `Identity` node may before a `Merge` node and has control-dependency-in nodes. /// The identity nodes are used to represent control dependencies in condition branch, and can not be deleted. /// 2. Check identity is near subgraph. /// Eg. As output of Data node in subgraph /// or as input of Netoutput of subgraph /// or as input of one node with subgraph /// or as output of one node with subgraph /// 3. identity with attr no_need_constant_folding should not be deleted too Status CheckIdentityUsable(const NodePtr &node, bool &usable) { std::string node_type; if (node->GetOpDesc()->HasAttr(ge::ATTR_NO_NEED_CONSTANT_FOLDING)) { usable = true; return SUCCESS; } for (auto &in_node : node->GetInDataNodes()) { auto in_node_opdesc = in_node->GetOpDesc(); GE_CHECK_NOTNULL(in_node_opdesc); // near entrance of subgraph || near subgraph if ((in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) || !in_node_opdesc->GetSubgraphInstanceNames().empty()) { usable = true; return SUCCESS; } GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type), "Failed to get node type from node %s", node->GetName().c_str()); bool need_skip = (node_type != SWITCH) && (node_type != REFSWITCH) && (node_type != SWITCHN); if (need_skip) { GELOGD("skip identity %s connected to switch", node->GetName().c_str()); break; } GE_CHECK_NOTNULL(node->GetOutControlAnchor()); if (!node->GetOutControlAnchor()->GetPeerInControlAnchors().empty()) { usable = true; return SUCCESS; } } for (auto &out_node : node->GetOutDataNodes()) { auto out_node_opdesc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(out_node_opdesc); // near output of subgraph || near subgraph if (NodeUtils::IsSubgraphOutput(out_node) || !out_node_opdesc->GetSubgraphInstanceNames().empty()) { usable = true; return SUCCESS; } GE_CHK_STATUS_RET(GetOriginalType(out_node, node_type), "Failed to get node type from node %s", node->GetName().c_str()); if ((node_type != MERGE) && (node_type != REFMERGE)) { GELOGD("skip identity %s connected to merge", node->GetName().c_str()); break; } GE_CHECK_NOTNULL(node->GetInControlAnchor()); if (!node->GetInControlAnchor()->GetPeerOutControlAnchors().empty()) { usable = true; return SUCCESS; } } usable = false; return SUCCESS; } } // namespace Status IdentityPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); string type; Status status_ret = GetOriginalType(node, type); if (status_ret != SUCCESS) { GELOGE(status_ret, "Identity pass get original type fail."); return status_ret; } if ((type != IDENTITY) && (type != IDENTITYN) && (type != READVARIABLEOP)) { return SUCCESS; } if (!force_) { bool usable = false; auto ret = CheckIdentityUsable(node, usable); if (ret != SUCCESS) { return ret; } if (usable) { return SUCCESS; } } size_t n = node->GetOpDesc()->GetOutputsSize(); if (node->GetOpDesc()->GetInputsSize() != n) { GELOGE(PARAM_INVALID, "Identity input / output size must be equal. in size:%lu, out size:%lu", node->GetOpDesc()->GetInputsSize(), n); return PARAM_INVALID; } std::vector io_map; for (size_t i = 0; i < n; i++) { io_map.push_back(i); } return IsolateAndDeleteNode(node, io_map); } } // namespace ge