/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "graph/passes/unused_op_remove_pass.h" #include #include #include #include #include "common/debug/log.h" #include "common/op/ge_op_utils.h" #include "common/types.h" #include "common/util.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "inc/pass_manager.h" #include "graph/passes/isolated_op_remove_pass.h" using domi::SUCCESS; namespace ge { const std::set kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; const std::set kOtherRemoveOpSet = {DROPOUT}; Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); std::set remove_op_set; vector nodes_to_be_deleted; if (fmktype_ == TENSORFLOW) { remove_op_set = kRemoveOpSet; } else { remove_op_set = kOtherRemoveOpSet; } for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string op_type_str = node->GetOpDesc()->GetType(); if (remove_op_set.count(op_type_str)) { if (IsExceptions(node)) { continue; } for (auto &out_anchor : node->GetAllOutDataAnchors()) { for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { NodePtr dst_node = in_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(dst_node->GetOpDesc()); int dst_index = in_anchor->GetIdx(); std::vector list_bool; GE_CHECK_NOTNULL(dst_node->GetOpDesc()); list_bool = dst_node->GetOpDesc()->GetIsInputConst(); GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); list_bool.erase(list_bool.begin() + dst_index); dst_node->GetOpDesc()->SetIsInputConst(list_bool); } } if (op_type_str == ASSERT) { GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); } else { GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); } } } for (auto &node : nodes_to_be_deleted) { for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { inAnchor->UnlinkAll(); } for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { outAnchorPtr->UnlinkAll(); } if (node->GetOutControlAnchor() != nullptr) { node->GetOutControlAnchor()->UnlinkAll(); } GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); } return SUCCESS; } Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, vector &node_vec) { GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(node); node_vec.push_back(node); std::queue node_queue; for (auto &src_node : node->GetInDataNodes()) { if (src_node->GetOutDataNodesSize() == 1) { node_queue.push(src_node); } } while (!node_queue.empty()) { NodePtr temp = node_queue.front(); node_queue.pop(); for (auto &src_node : temp->GetInDataNodes()) { if (src_node->GetOutDataNodesSize() == 1) { node_queue.push(src_node); } } node_vec.push_back(temp); } return SUCCESS; } bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); auto op_def = node->GetOpDesc(); GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); // permute optimised in permute_pass.cpp if (op_def->GetType() == PERMUTE) { GE_IF_BOOL_EXEC( (node->GetInDataNodes().size() != 0 && (node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), return false); return true; } return false; } } // namespace ge