/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/transop_depth_fusion_pass.h"

#include <algorithm>
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/compute_graph.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/utils/graph_utils.h"
#include "graph/common/transop_util.h"
#include "graph/utils/node_utils.h"

namespace ge {
/// Runs the in-depth trans op fusion over every data branch that starts at a
/// non-trans-op node of |graph|.
/// @param graph graph to optimize; a null graph is a no-op.
/// @return GRAPH_SUCCESS, or the error status produced by GE_CHECK_NOTNULL when
///         a null node / out anchor is met. A branch whose recursion fails is
///         only logged and the pass continues with the next branch (best effort).
graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) {
  GELOGI("[TransOpDepthFusionPass]: optimize in depth begin...");
  if (graph == nullptr) {
    return GRAPH_SUCCESS;
  }

  // RecursiveInDepth tracks pending trans ops in trans_op_ / offset_op_. A
  // successful walk leaves both stacks exactly as it found them, but an error
  // return unwinds without popping. Since a failed branch is skipped rather
  // than aborting the pass, drop any leftovers before starting a new branch so
  // a stale entry can never be "offset" against a trans op on an unrelated
  // branch. On the success path the stacks are already empty here.
  auto reset_trans_op_stacks = [this]() {
    while (!trans_op_.empty()) {
      trans_op_.pop();
    }
    while (!offset_op_.empty()) {
      offset_op_.pop();
    }
  };

  for (const auto &node : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    // Branches are rooted at "normal" nodes; trans ops are reached from them.
    if (TransOpUtil::IsTransOp(node)) {
      continue;
    }
    GELOGD("Current normal node is: %s, type: %s, begin in-depth recursive", node->GetName().c_str(),
           node->GetType().c_str());
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_anchor);
      for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
        reset_trans_op_stacks();
        if (RecursiveInDepth(peer_in_anchor, graph) != GRAPH_SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Recursive failed, root node is: %s, type: %s", node->GetName().c_str(),
                 node->GetType().c_str());
        }
      }
    }
  }
  reset_trans_op_stacks();
  GELOGI("[TransOpDepthFusionPass]: Optimize in depth success...");
  return GRAPH_SUCCESS;
}

/// @@ Method:
/// Depth-first recursive strategy was utilized to traverse all the trans ops.
/// Both trans ops will be offset when the back one's output desc is consistent /// with it's former neighbor's input. /// @@ Limitation: /// The current method only judge the neighbors. Trans ops separated by some /// other ops which can't be offset are not taken into account in current /// @@ Recursive depth /// To ensure that the stack does not overflow, the maximum depth in recursive is /// set to be maxRecursiveDepth = 20. More trans ops are seen abnormally. graphStatus TransOpDepthFusionPass::RecursiveInDepth(const InDataAnchorPtr &dst_in_anchor, const ge::ComputeGraphPtr &graph) { static unsigned int temp_depth = 0; static const unsigned int max_recursive_depth = 20; temp_depth++; if (temp_depth >= max_recursive_depth) { GELOGI( "Caution: recursive depth is become %u." "It's abnormally to have so many trans ops between two normal ops" "Please check your graph in detail!" "The search terminate here and continue to another branch.", temp_depth); temp_depth--; return GRAPH_SUCCESS; } if (dst_in_anchor == nullptr || dst_in_anchor->GetOwnerNode() == nullptr || dst_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGE(FAILED, "parameter is null."); return GRAPH_FAILED; } auto node = dst_in_anchor->GetOwnerNode(); if (!TransOpUtil::IsTransOp(node) || dst_in_anchor->GetIdx() != TransOpUtil::GetTransOpDataIndex(node)) { GELOGD("Now the end of this branch, node: %s, type: %s, recursive depth: %u", node->GetName().c_str(), node->GetType().c_str(), temp_depth); temp_depth--; return GRAPH_SUCCESS; } else if (CheckNodeCanBeDeleted(node)) { GELOGD("node: %s, type: %s does not change memory, just delete", node->GetName().c_str(), node->GetType().c_str()); auto out_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_anchor); auto in_anchors = out_anchor->GetPeerInDataAnchors(); GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove edge failed"); GELOGI("remove node: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str()); for (auto &in_anchor : in_anchors) 
{ GE_CHECK_NOTNULL(in_anchor); GE_CHK_STATUS_RET(UpdateSrcAttr(in_anchor->GetPeerOutAnchor(), out_anchor, in_anchor), "UpdateSrcAttr failed"); GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed"); } } else if (trans_op_.empty() || !DescAreSymmetry(trans_op_.top(), node)) { GELOGD("node: %s, type: %s can't be offset, push to trans_op_", node->GetName().c_str(), node->GetType().c_str()); trans_op_.push(node); auto out_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_anchor); for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed"); } if (node->GetOutDataNodesSize() == 0) { GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove node failed"); GELOGI("backtracking, trans op: %s, type: %s will be removed", node->GetName().c_str(), node->GetType().c_str()); } GELOGD("backtracking, trans_op_ fall back. pop node: %s, type: %s.", trans_op_.top()->GetName().c_str(), trans_op_.top()->GetType().c_str()); trans_op_.pop(); } else if (DescAreSymmetry(trans_op_.top(), node)) { GELOGD("current node: %s, type: %s can be offset with node: %s, type %s", node->GetName().c_str(), node->GetType().c_str(), trans_op_.top()->GetName().c_str(), trans_op_.top()->GetType().c_str()); GELOGD("offset_op_ push node: %s, type: %s.", trans_op_.top()->GetName().c_str(), trans_op_.top()->GetType().c_str()); offset_op_.push(trans_op_.top()); auto in_data_anchor = node->GetInDataAnchor(0); GE_CHECK_NOTNULL(in_data_anchor); auto old_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(old_out_anchor); auto new_out_anchor = trans_op_.top()->GetInDataAnchor(0)->GetPeerOutAnchor(); GE_CHECK_NOTNULL(new_out_anchor); GE_IF_BOOL_EXEC(RelinkEdges(new_out_anchor, old_out_anchor, in_data_anchor) != GRAPH_SUCCESS, GELOGE(FAILED, "RelinkEdges fail."); return FAILED) auto out_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_anchor); auto in_anchors = 
out_anchor->GetPeerInDataAnchors(); GELOGD("begin offset,trans_op_ pop node: %s, type: %s.", trans_op_.top()->GetName().c_str(), trans_op_.top()->GetType().c_str()); GELOGI("the offset node : %s, type: %s will be removed.", node->GetName().c_str(), node->GetType().c_str()); GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove node failed"); trans_op_.pop(); for (const auto &in_anchor : in_anchors) { GE_CHECK_NOTNULL(in_anchor); GE_CHK_STATUS_RET(UpdateSrcAttr(in_anchor->GetPeerOutAnchor(), out_anchor, in_anchor), "UpdateSrcAttr failed"); GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed"); } GELOGD("backtracking, trans_op_ push node: %s, type: %s.", offset_op_.top()->GetName().c_str(), offset_op_.top()->GetType().c_str()); trans_op_.push(offset_op_.top()); offset_op_.pop(); } temp_depth--; return GRAPH_SUCCESS; } bool TransOpDepthFusionPass::CheckNodeCanBeDeleted(const NodePtr &node) { bool is_shape_unknown = false; if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { if (is_shape_unknown) { GELOGI("op:%s is unknown shape, can not be deleted.", node->GetName().c_str()); return false; } } return node->GetType() == RESHAPE || node->GetType() == REFORMAT || node->GetType() == SQUEEZE || node->GetType() == EXPANDDIMS; } bool TransOpDepthFusionPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) { if (src_node == nullptr || dst_node == nullptr || src_node->GetOpDesc() == nullptr || dst_node->GetOpDesc() == nullptr) { return false; } const auto &src_input_desc = src_node->GetOpDesc()->MutableInputDesc(0); const auto &dst_output_desc = dst_node->GetOpDesc()->MutableOutputDesc(0); GE_CHECK_NOTNULL_EXEC(src_input_desc, return false); GE_CHECK_NOTNULL_EXEC(dst_output_desc, return false); const auto &src_input_dtype = src_input_desc->GetDataType(); const auto &src_input_format = src_input_desc->GetFormat(); const auto &src_input_shape = src_input_desc->GetShape().GetDims(); const auto 
&dst_output_dtype = dst_output_desc->GetDataType(); const auto &dst_output_format = dst_output_desc->GetFormat(); const auto &dst_output_shape = dst_output_desc->GetShape().GetDims(); if (src_node->GetType() == CAST && dst_node->GetType() == CAST) { return src_input_dtype == dst_output_dtype && src_input_format == dst_output_format; } else { return src_input_dtype == dst_output_dtype && src_input_shape == dst_output_shape && src_input_format == dst_output_format; } } // If the relationship was changed, the input and src name will be update graphStatus TransOpDepthFusionPass::UpdateSrcAttr(const OutDataAnchorPtr &new_out_anchor, const OutDataAnchorPtr &ori_out_anchor, const InDataAnchorPtr &dst_in_anchor) { if (dst_in_anchor == nullptr || dst_in_anchor->GetOwnerNode() == nullptr || dst_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGW("dst_in_anchor or it's owner node and op_desc is nullptr"); return GRAPH_SUCCESS; } GE_CHECK_NOTNULL(new_out_anchor); GE_CHECK_NOTNULL(new_out_anchor->GetOwnerNode()); GE_CHECK_NOTNULL(ori_out_anchor); GE_CHECK_NOTNULL(ori_out_anchor->GetOwnerNode()); auto new_name = new_out_anchor->GetOwnerNode()->GetName(); auto ori_name = ori_out_anchor->GetOwnerNode()->GetName(); auto dst_desc = dst_in_anchor->GetOwnerNode()->GetOpDesc(); auto ori_src_name = dst_desc->GetSrcName(); auto ori_input_name = dst_desc->GetInputName(); std::vector new_src_name; std::vector new_input_name; if (ori_src_name.empty()) { new_src_name.push_back(new_name); } else { for (auto &src_name : ori_src_name) { if (src_name == ori_name) { new_src_name.push_back(new_name); } else { new_src_name.push_back(src_name); } } } if (ori_input_name.empty()) { new_input_name.push_back(new_name); } else { for (auto &input_name : ori_input_name) { if (input_name == ori_name) { new_input_name.push_back(new_name); } else { new_input_name.push_back(input_name); } } } dst_desc->SetSrcName(new_src_name); dst_desc->SetInputName(new_input_name); return GRAPH_SUCCESS; } /// Relink 
the offset trans op with it's former neighbor's father node. /// Note: control edge will be added to link the two offset ops, if the former op /// has in control nodes graphStatus TransOpDepthFusionPass::RelinkEdges(const OutDataAnchorPtr &new_out_anchor, const OutDataAnchorPtr &old_out_anchor, const InDataAnchorPtr &in_data_anchor) { if (new_out_anchor == nullptr || old_out_anchor == nullptr || in_data_anchor == nullptr) { GELOGE(INTERNAL_ERROR, "new_out_anchor or old_out_anchor or in_data_anchor is nullptr"); return GRAPH_FAILED; } if (new_out_anchor->GetOwnerNode() == nullptr || old_out_anchor->GetOwnerNode() == nullptr || in_data_anchor->GetOwnerNode() == nullptr) { GELOGE(INTERNAL_ERROR, "anchor's owner node is nullptr"); return GRAPH_FAILED; } GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(old_out_anchor, in_data_anchor), "remove edge failed"); GE_CHK_STATUS_RET(GraphUtils::AddEdge(new_out_anchor, in_data_anchor), "add edge failed"); GELOGD( "relink edges before remove node, remove data edge between node: %s, " "type: %s and node: %s, type: %s.", old_out_anchor->GetOwnerNode()->GetName().c_str(), old_out_anchor->GetOwnerNode()->GetType().c_str(), in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetOwnerNode()->GetType().c_str()); GELOGD( "relink edges before remove node, add data edge between node: %s, " "type: %s and node: %s, type: %s.", new_out_anchor->GetOwnerNode()->GetName().c_str(), new_out_anchor->GetOwnerNode()->GetType().c_str(), in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetOwnerNode()->GetType().c_str()); bool is_linked = false; auto dst_node = in_data_anchor->GetOwnerNode(); auto src_node = old_out_anchor->GetOwnerNode(); auto in_ctrl_nodes = dst_node->GetInControlNodes(); if (!in_ctrl_nodes.empty()) { auto iter = std::find(in_ctrl_nodes.begin(), in_ctrl_nodes.end(), src_node); is_linked = iter != in_ctrl_nodes.end(); } if (!src_node->GetInControlNodes().empty() && !is_linked) { auto out_ctrl_anchor = 
src_node->GetOutControlAnchor(); auto in_ctrl_anchor = dst_node->GetInControlAnchor(); GE_CHK_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor), "add edge failed"); GELOGD( "relink edges before remove node, add control edge between node: %s," " type: %s and node: %s, type: %s.", src_node->GetName().c_str(), src_node->GetType().c_str(), dst_node->GetName().c_str(), dst_node->GetType().c_str()); } return GRAPH_SUCCESS; } // Remove trans op by using interface: IsolateNode & RemoveNodeWithoutRelink graphStatus TransOpDepthFusionPass::RemoveNode(const NodePtr &node, const ge::ComputeGraphPtr &graph) { if (node == nullptr || graph == nullptr) { return GRAPH_FAILED; } if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", node->GetName().c_str(), node->GetType().c_str()); return GRAPH_FAILED; } if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", node->GetName().c_str(), node->GetType().c_str()); return GRAPH_FAILED; } return GRAPH_SUCCESS; } } // namespace ge