You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.8 KiB
137 lines
4.8 KiB
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "graph/passes/identity_pass.h"
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
#include "framework/common/debug/ge_log.h"
|
|
#include "graph/common/omg_util.h"
|
|
#include "graph/utils/node_utils.h"
|
|
#include "graph/utils/attr_utils.h"
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
namespace ge {
|
|
namespace {
|
|
///
|
|
/// 1. A `Identity` node may after a `Switch` node and has control-dependency-out nodes.
|
|
/// Or a `Identity` node may before a `Merge` node and has control-dependency-in nodes.
|
|
/// The identity nodes are used to represent control dependencies in condition branch, and can not be deleted.
|
|
/// 2. Check identity is near subgraph.
|
|
/// Eg. As output of Data node in subgraph
|
|
/// or as input of Netoutput of subgraph
|
|
/// or as input of one node with subgraph
|
|
/// or as output of one node with subgraph
|
|
/// 3. identity with attr no_need_constant_folding should not be deleted too
|
|
Status CheckIdentityUsable(const NodePtr &node, bool &usable) {
|
|
std::string node_type;
|
|
if (node->GetOpDesc()->HasAttr(ge::ATTR_NO_NEED_CONSTANT_FOLDING)) {
|
|
usable = true;
|
|
return SUCCESS;
|
|
}
|
|
|
|
for (auto &in_node : node->GetInDataNodes()) {
|
|
auto in_node_opdesc = in_node->GetOpDesc();
|
|
GE_CHECK_NOTNULL(in_node_opdesc);
|
|
// near entrance of subgraph || near subgraph
|
|
if ((in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node))
|
|
|| !in_node_opdesc->GetSubgraphInstanceNames().empty()) {
|
|
usable = true;
|
|
return SUCCESS;
|
|
}
|
|
|
|
GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type),
|
|
"Failed to get node type from node %s", node->GetName().c_str());
|
|
bool need_skip = (node_type != SWITCH) && (node_type != REFSWITCH) && (node_type != SWITCHN);
|
|
if (need_skip) {
|
|
GELOGD("skip identity %s connected to switch", node->GetName().c_str());
|
|
break;
|
|
}
|
|
GE_CHECK_NOTNULL(node->GetOutControlAnchor());
|
|
if (!node->GetOutControlAnchor()->GetPeerInControlAnchors().empty()) {
|
|
usable = true;
|
|
return SUCCESS;
|
|
}
|
|
}
|
|
for (auto &out_node : node->GetOutDataNodes()) {
|
|
auto out_node_opdesc = out_node->GetOpDesc();
|
|
GE_CHECK_NOTNULL(out_node_opdesc);
|
|
// near output of subgraph || near subgraph
|
|
if (NodeUtils::IsSubgraphOutput(out_node)
|
|
|| !out_node_opdesc->GetSubgraphInstanceNames().empty()) {
|
|
usable = true;
|
|
return SUCCESS;
|
|
}
|
|
GE_CHK_STATUS_RET(GetOriginalType(out_node, node_type),
|
|
"Failed to get node type from node %s", node->GetName().c_str());
|
|
if ((node_type != MERGE) && (node_type != REFMERGE)) {
|
|
GELOGD("skip identity %s connected to merge", node->GetName().c_str());
|
|
break;
|
|
}
|
|
GE_CHECK_NOTNULL(node->GetInControlAnchor());
|
|
if (!node->GetInControlAnchor()->GetPeerOutControlAnchors().empty()) {
|
|
usable = true;
|
|
return SUCCESS;
|
|
}
|
|
}
|
|
usable = false;
|
|
return SUCCESS;
|
|
}
|
|
} // namespace
|
|
|
|
///
/// @brief Isolate and delete an Identity / IdentityN / ReadVariableOp node when it is safe to do so.
///
/// Nodes whose original type is none of IDENTITY, IDENTITYN or READVARIABLEOP are left untouched.
/// Unless `force_` is set, the node is also left untouched when CheckIdentityUsable reports
/// that it must be preserved.
///
/// @param [in] node  node visited by the pass
/// @return SUCCESS when the node was skipped;
///         PARAM_INVALID when the node is null, has no op desc, or its input and output
///         desc counts differ;
///         the failing status of GetOriginalType / CheckIdentityUsable;
///         otherwise the result of IsolateAndDeleteNode
///
Status IdentityPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);

  std::string type;
  const Status status_ret = GetOriginalType(node, type);
  if (status_ret != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Get original type for node:%s failed",
                      node->GetName().c_str());
    GELOGE(status_ret, "Identity pass get original type fail.");
    return status_ret;
  }
  if ((type != IDENTITY) && (type != IDENTITYN) && (type != READVARIABLEOP)) {
    return SUCCESS;
  }

  if (!force_) {
    bool usable = false;
    const auto ret = CheckIdentityUsable(node, usable);
    if (ret != SUCCESS) {
      return ret;
    }
    if (usable) {
      return SUCCESS;
    }
  }

  const size_t input_num = op_desc->GetInputsSize();
  const size_t output_num = op_desc->GetOutputsSize();
  if (input_num != output_num) {
    REPORT_CALL_ERROR("E19999", "Num:%zu of input desc node:%s(%s) not equal to it's output desc num:%zu, "
                      "check invalid", input_num,
                      node->GetName().c_str(), node->GetType().c_str(), output_num);
    // size_t must be printed with %zu; %lu is only correct where size_t is unsigned long.
    GELOGE(PARAM_INVALID, "Identity input / output size must be equal. in size:%zu, out size:%zu",
           input_num, output_num);
    return PARAM_INVALID;
  }

  // Build the index map {0, 1, ..., output_num - 1} handed to IsolateAndDeleteNode
  // (presumably "output i is fed by input i" -- confirm against IsolateAndDeleteNode).
  std::vector<int> io_map;
  io_map.reserve(output_num);
  for (size_t i = 0; i < output_num; ++i) {
    io_map.push_back(static_cast<int>(i));
  }
  return IsolateAndDeleteNode(node, io_map);
}
|
|
} // namespace ge
|