You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/identity_pass.cc

137 lines
4.8 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/identity_pass.h"
#include <string>
#include <vector>
#include "framework/common/debug/ge_log.h"
#include "graph/common/omg_util.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
namespace ge {
namespace {
///
/// @brief Decide whether an identity-like node has to stay in the graph.
///
/// `usable` is set to true (node must be kept) in any of these cases:
/// 1. The node carries the attr ATTR_NO_NEED_CONSTANT_FOLDING.
/// 2. The node is next to a subgraph boundary, i.e.
///      - one of its data inputs is a Data node that is a subgraph input, or
///      - one of its data inputs owns subgraph instances, or
///      - one of its data outputs is a subgraph output, or
///      - one of its data outputs owns subgraph instances.
/// 3. The node follows a Switch/RefSwitch/SwitchN node and has outgoing control
///    edges, or it precedes a Merge/RefMerge node and has incoming control edges.
///    Such identity nodes express control dependencies inside a condition branch.
/// Otherwise `usable` is set to false.
///
/// @param [in]  node    node to inspect
/// @param [out] usable  true if the node must not be deleted; left unchanged on failure
/// @return SUCCESS when `usable` was decided; otherwise the error status produced by
///         GE_CHECK_NOTNULL / GE_CHK_STATUS_RET
///
Status CheckIdentityUsable(const NodePtr &node, bool &usable) {
  GE_CHECK_NOTNULL(node);
  const auto node_opdesc = node->GetOpDesc();
  GE_CHECK_NOTNULL(node_opdesc);

  if (node_opdesc->HasAttr(ge::ATTR_NO_NEED_CONSTANT_FOLDING)) {
    usable = true;
    return SUCCESS;
  }

  std::string node_type;
  for (const auto &in_node : node->GetInDataNodes()) {
    const auto in_node_opdesc = in_node->GetOpDesc();
    GE_CHECK_NOTNULL(in_node_opdesc);
    // near entrance of subgraph || near subgraph
    if ((in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node))
        || !in_node_opdesc->GetSubgraphInstanceNames().empty()) {
      usable = true;
      return SUCCESS;
    }

    // Report the node whose type lookup actually failed (the input, not the identity).
    GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type),
                      "Failed to get node type from node %s", in_node->GetName().c_str());
    const bool need_skip = (node_type != SWITCH) && (node_type != REFSWITCH) && (node_type != SWITCHN);
    if (need_skip) {
      // Logged when this input is NOT a switch; the remaining inputs are not inspected.
      // NOTE(review): for multi-input nodes `continue` may have been intended - confirm.
      GELOGD("skip identity %s connected to switch", node->GetName().c_str());
      break;
    }

    // Input is a switch: outgoing control edges make this identity a control-dependency carrier.
    GE_CHECK_NOTNULL(node->GetOutControlAnchor());
    if (!node->GetOutControlAnchor()->GetPeerInControlAnchors().empty()) {
      usable = true;
      return SUCCESS;
    }
  }

  for (const auto &out_node : node->GetOutDataNodes()) {
    const auto out_node_opdesc = out_node->GetOpDesc();
    GE_CHECK_NOTNULL(out_node_opdesc);
    // near output of subgraph || near subgraph
    if (NodeUtils::IsSubgraphOutput(out_node)
        || !out_node_opdesc->GetSubgraphInstanceNames().empty()) {
      usable = true;
      return SUCCESS;
    }

    // Report the node whose type lookup actually failed (the output, not the identity).
    GE_CHK_STATUS_RET(GetOriginalType(out_node, node_type),
                      "Failed to get node type from node %s", out_node->GetName().c_str());
    if ((node_type != MERGE) && (node_type != REFMERGE)) {
      // Logged when this output is NOT a merge; the remaining outputs are not inspected.
      // NOTE(review): for multi-output nodes `continue` may have been intended - confirm.
      GELOGD("skip identity %s connected to merge", node->GetName().c_str());
      break;
    }

    // Output is a merge: incoming control edges make this identity a control-dependency carrier.
    GE_CHECK_NOTNULL(node->GetInControlAnchor());
    if (!node->GetInControlAnchor()->GetPeerOutControlAnchors().empty()) {
      usable = true;
      return SUCCESS;
    }
  }

  usable = false;
  return SUCCESS;
}
} // namespace
///
/// @brief Remove `node` from the graph when it is a deletable identity-like node.
///
/// Only nodes whose original type is IDENTITY, IDENTITYN or READVARIABLEOP are handled;
/// every other node is left untouched and SUCCESS is returned. Unless `force_` is set,
/// the node is kept when CheckIdentityUsable() reports it as usable. A node to be removed
/// must have the same number of input and output descs; input i is then mapped to
/// output i and the node is handed to IsolateAndDeleteNode().
///
/// @param [in] node  node visited by the pass
/// @return SUCCESS if the node was skipped; PARAM_INVALID if the input and output desc
///         counts differ; otherwise the status of the failing helper or of
///         IsolateAndDeleteNode()
///
Status IdentityPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);

  std::string type;
  const Status status_ret = GetOriginalType(node, type);
  if (status_ret != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Get original type for node:%s failed",
                      node->GetName().c_str());
    GELOGE(status_ret, "Identity pass get original type fail.");
    return status_ret;
  }
  if ((type != IDENTITY) && (type != IDENTITYN) && (type != READVARIABLEOP)) {
    return SUCCESS;
  }

  if (!force_) {
    bool usable = false;
    const auto ret = CheckIdentityUsable(node, usable);
    if (ret != SUCCESS) {
      return ret;
    }
    if (usable) {
      return SUCCESS;
    }
  }

  const size_t out_num = op_desc->GetOutputsSize();
  const size_t in_num = op_desc->GetInputsSize();
  if (in_num != out_num) {
    REPORT_CALL_ERROR("E19999", "Num:%zu of input desc node:%s(%s) not equal to it's output desc num:%zu, "
                      "check invalid", in_num,
                      node->GetName().c_str(), node->GetType().c_str(), out_num);
    // %zu is the portable conversion for size_t; %lu mismatches where size_t is not unsigned long.
    GELOGE(PARAM_INVALID, "Identity input / output size must be equal. in size:%zu, out size:%zu",
           in_num, out_num);
    return PARAM_INVALID;
  }

  // Identity mapping: output i is fed by input i.
  std::vector<int> io_map;
  io_map.reserve(out_num);
  for (size_t i = 0; i < out_num; ++i) {
    io_map.push_back(static_cast<int>(i));
  }
  return IsolateAndDeleteNode(node, io_map);
}
} // namespace ge