/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/transop_without_reshape_fusion_pass.h"
#include <algorithm>
#include <atomic>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "common/ge/ge_util.h"
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/common/transop_util.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "init/gelib.h"

namespace {
// Ext-attr key; a transop whose value is true is kept by RemoveNousedNodes.
const char *const kRemainNode = "node_remain";
// Sentinel written to fusion_op_count when a replacement op cannot be used.
const int kInvalidFusionOpCount = -1;
const char *const kAttrNameSrcFormat = "src_format";
const char *const kAttrNameDstFormat = "dst_format";
}  // namespace

namespace ge {
/**
 * Marks every transop that is the destination of an anchor pair in nodes_anchor with the
 * kRemainNode ext attribute, so RemoveNousedNodes will keep it.
 * Stops at the first null in-anchor or null owner node; non-transops and nodes without an
 * op desc are skipped.
 */
void TransOpWithoutReshapeFusionPass::SetRemainNode(
    const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
  for (const auto &anchor_pair : nodes_anchor) {
    const auto &in_anchor = anchor_pair.second;
    if (in_anchor == nullptr) {
      return;
    }
    auto in_node = in_anchor->GetOwnerNode();
    if (in_node == nullptr) {
      return;
    }
    if (!IsTransOp(in_node)) {
      continue;
    }
    auto op_desc = in_node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }
    GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
    GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
  }
}

/**
 * Checks whether the data edge out_anchor -> in_anchor can be part of a fusable chain.
 * @return false if any anchor/node/desc is missing or the shapes differ. If either end is a
 *         Cast, the result of TransOpUtil::CheckPrecisionLoss(in_node). Otherwise true only
 *         when both formats are equal, not FORMAT_ND, and accepted by FusionFormatSupport.
 */
bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorPtr &out_anchor,
                                                            const InDataAnchorPtr &in_anchor) {
  if (out_anchor == nullptr || in_anchor == nullptr) {
    return false;
  }
  auto in_node = in_anchor->GetOwnerNode();
  auto out_owner_node = out_anchor->GetOwnerNode();
  if (in_node == nullptr || out_owner_node == nullptr) {
    return false;
  }

  auto in_op = in_node->GetOpDesc();
  auto out_op = out_owner_node->GetOpDesc();
  GE_IF_BOOL_EXEC(in_op == nullptr, GELOGE(INTERNAL_ERROR, "in_op is null"); return false);
  GE_IF_BOOL_EXEC(out_op == nullptr, GELOGE(INTERNAL_ERROR, "out_op is null"); return false);

  auto in_op_desc = in_op->GetInputDescPtr(in_anchor->GetIdx());
  auto out_op_desc = out_op->GetOutputDescPtr(out_anchor->GetIdx());
  GE_IF_BOOL_EXEC(in_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_op_desc is null"); return false);
  GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_op_desc is null"); return false);

  if (!ShapeEqualCheck(in_op_desc->GetShape(), out_op_desc->GetShape())) {
    return false;
  }
  if (in_op->GetType() == CAST || out_op->GetType() == CAST) {
    return TransOpUtil::CheckPrecisionLoss(in_node);
  }
  if (in_op_desc->GetFormat() == FORMAT_ND || out_op_desc->GetFormat() == FORMAT_ND) {
    return false;
  }
  if (in_op_desc->GetFormat() != out_op_desc->GetFormat()) {
    return false;
  }
  return FusionFormatSupport(in_op_desc->GetFormat());
}

/**
 * For every sub graph in sub_graph_anchors_, collects its nodes and counts its transops.
 * The node list is the begin node plus the owner of each pair's in-anchor.
 * A sub graph that contains a Reshape, fails FormatContinuousCheck, or has no valid begin
 * anchor is flagged in sub_graph_has_reshape_node_. The later stages skip flagged sub graphs,
 * and their transops are marked as remaining.
 * @return GRAPH_SUCCESS, or the GE_CHECK_NOTNULL error for a null anchor/node. On error the
 *         member vectors are left unchanged.
 */
graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
  const size_t sub_graph_num = sub_graph_anchors_.size();
  vector<bool> sub_graph_has_reshape_node(sub_graph_num, false);
  vector<int> transop_num_count(sub_graph_num, 0);
  vector<vector<NodePtr>> sub_graph_nodes(sub_graph_num);

  for (size_t i = 0; i < sub_graph_num; ++i) {
    const auto &nodes_anchor = sub_graph_anchors_[i];
    if (nodes_anchor.empty() || nodes_anchor.front().first == nullptr) {
      // Without a begin anchor the sub graph cannot be fused. Flag it so that GetControlAnchors,
      // TransOpFuse and RemoveNousedNodes skip it; otherwise they would index an empty node list.
      sub_graph_has_reshape_node[i] = true;
      SetRemainNode(nodes_anchor);
      continue;
    }

    vector<NodePtr> nodes_tmp;
    nodes_tmp.push_back(nodes_anchor.front().first->GetOwnerNode());
    for (const auto &anchor_pair : nodes_anchor) {
      const auto &in_anchor = anchor_pair.second;
      GE_CHECK_NOTNULL(in_anchor);
      auto in_node = in_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(in_node);
      if (in_node->GetType() == RESHAPE) {
        sub_graph_has_reshape_node[i] = true;
        break;
      }
      const auto &out_anchor = anchor_pair.first;
      GE_CHECK_NOTNULL(out_anchor);
      if (!FormatContinuousCheck(out_anchor, in_anchor)) {
        // Not a reshape, but equally unfusable; the same flag is reused.
        sub_graph_has_reshape_node[i] = true;
        break;
      }
      nodes_tmp.push_back(in_node);
      if (IsTransOp(in_node)) {
        // count transop num
        transop_num_count[i]++;
      }
    }
    sub_graph_nodes[i].swap(nodes_tmp);
    if (sub_graph_has_reshape_node[i]) {
      SetRemainNode(nodes_anchor);
    }
  }

  sub_graph_has_reshape_node_.swap(sub_graph_has_reshape_node);
  transop_num_count_.swap(transop_num_count);
  sub_graph_nodes_.swap(sub_graph_nodes);
  return GRAPH_SUCCESS;
}

/**
 * For sub graph `index`, gathers the in-control anchors fed by the out-data anchors of pairs 1..N-1.
 * Anchors outside the sub graph are appended to out_data_peer_in_control_anchors[index].
 * Anchors inside the sub graph set sub_graph_has_out_data_peer_in_control_edge_[index].
 * The caller guarantees that the index is legal.
 */
void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
    const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
  const auto &sub_graph = sub_graph_anchors_[index];
  const auto &nodes = sub_graph_nodes_[index];
  // Pair 0 starts at the begin node and is not inspected here.
  for (size_t j = 1; j < sub_graph.size(); ++j) {
    const auto &out_data_anchor = sub_graph[j].first;
    GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
    for (const auto &peer_in_control_anchor : out_data_anchor->GetPeerInControlAnchors()) {
      GE_CHECK_NOTNULL_JUST_RETURN(peer_in_control_anchor);
      auto peer_node = peer_in_control_anchor->GetOwnerNode();
      if (peer_node == nullptr) {
        continue;
      }
      if (std::find(nodes.begin(), nodes.end(), peer_node) == nodes.end()) {
        out_data_peer_in_control_anchors[index].push_back(peer_in_control_anchor);
      } else {
        sub_graph_has_out_data_peer_in_control_edge_[index] = true;
      }
    }
  }
}

/**
 * For the inner nodes of sub graph `index` (all except the first and last), gathers the peers of
 * their in-control anchors. Peers outside the sub graph are appended to
 * in_control_peer_out_control_anchors[index]. Peers inside it set sub_graph_has_control_edge_[index].
 * The caller guarantees that the index is legal.
 */
void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
    const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
  const auto &nodes = sub_graph_nodes_[index];
  // `j + 1 < size` instead of `j < size - 1` so an empty node list cannot underflow.
  for (size_t j = 1; j + 1 < nodes.size(); ++j) {
    const auto &node = nodes[j];
    GE_CHECK_NOTNULL_JUST_RETURN(node);
    auto in_control_anchor = node->GetInControlAnchor();
    if (in_control_anchor == nullptr) {
      continue;
    }
    for (const auto &peer_out_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
      GE_CHECK_NOTNULL_JUST_RETURN(peer_out_anchor);
      auto peer_node = peer_out_anchor->GetOwnerNode();
      if (peer_node == nullptr) {
        continue;
      }
      if (std::find(nodes.begin(), nodes.end(), peer_node) == nodes.end()) {
        in_control_peer_out_control_anchors[index].push_back(peer_out_anchor);
      } else {
        sub_graph_has_control_edge_[index] = true;
      }
    }
  }
}

/**
 * For every node of sub graph `index` except the last, inspects its out-control anchor's peers,
 * both in-control and in-data anchors.
 * Peers outside the sub graph are recorded in the matching output vector. Peers of the begin
 * node (j == 0) are not recorded.
 * Peers inside the sub graph set sub_graph_has_control_edge_[index].
 */
void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
    const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
    vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
  const auto &nodes = sub_graph_nodes_[index];
  // `j + 1 < size` instead of `j < size - 1` so an empty node list cannot underflow.
  for (size_t j = 0; j + 1 < nodes.size(); ++j) {
    const auto &node = nodes[j];
    GE_CHECK_NOTNULL_JUST_RETURN(node);
    auto out_control_anchor = node->GetOutControlAnchor();
    GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);

    for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
      GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
      auto peer_node = peer_in_anchor->GetOwnerNode();
      if (peer_node == nullptr) {
        continue;
      }
      if (std::find(nodes.begin(), nodes.end(), peer_node) == nodes.end()) {
        if (j > 0) {
          out_control_peer_in_control_anchors[index].push_back(peer_in_anchor);
        }
      } else {
        sub_graph_has_control_edge_[index] = true;
      }
    }

    for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
      auto peer_node = peer_in_anchor->GetOwnerNode();
      if (peer_node == nullptr) {
        continue;
      }
      if (std::find(nodes.begin(), nodes.end(), peer_node) == nodes.end()) {
        if (j > 0) {
          out_control_peer_in_data_anchors[index].push_back(peer_in_anchor);
        }
      } else {
        sub_graph_has_control_edge_[index] = true;
      }
    }
  }
}

/**
 * Rebuilds all control-edge bookkeeping members for the current sub graphs.
 * Sub graphs flagged in sub_graph_has_reshape_node_ keep empty entries.
 */
void TransOpWithoutReshapeFusionPass::GetControlAnchors() {
  const size_t sub_graph_num = sub_graph_nodes_.size();
  vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors(sub_graph_num);
  vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors(sub_graph_num);
  vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors(sub_graph_num);
  vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors(sub_graph_num);

  // The edge flags are written directly by the Get* helpers, so reset them before the loop.
  vector<bool> sub_graph_has_control_edge(sub_graph_num, false);
  sub_graph_has_control_edge_.swap(sub_graph_has_control_edge);
  vector<bool> sub_graph_has_out_data_peer_in_control_edge(sub_graph_num, false);
  sub_graph_has_out_data_peer_in_control_edge_.swap(sub_graph_has_out_data_peer_in_control_edge);

  for (size_t i = 0; i < sub_graph_num; ++i) {
    if (sub_graph_has_reshape_node_[i]) {
      continue;
    }
    GetOutDataPeerInControlAnchors(i, out_data_peer_in_control_anchors);
    GetInControlPeerOutControlAnchors(i, in_control_peer_out_control_anchors);
    GetOutControlPeerAnchors(i, out_control_peer_in_control_anchors, out_control_peer_in_data_anchors);
  }

  in_control_peer_out_control_anchors_.swap(in_control_peer_out_control_anchors);
  out_control_peer_in_control_anchors_.swap(out_control_peer_in_control_anchors);
  out_control_peer_in_data_anchors_.swap(out_control_peer_in_data_anchors);
  out_data_peer_in_control_anchors_.swap(out_data_peer_in_control_anchors);
}

/** Drops sub graphs with at most one anchor pair; the order of the others is preserved. */
void TransOpWithoutReshapeFusionPass::EraseInvalidAnchorsPair() {
  auto is_invalid = [](const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &sub_graph) {
    return sub_graph.size() <= 1;
  };
  sub_graph_anchors_.erase(std::remove_if(sub_graph_anchors_.begin(), sub_graph_anchors_.end(), is_invalid),
                           sub_graph_anchors_.end());
}

/**
 * In the output-name map of out_anchor's owner op:
 * - removes the entry keyed by the old peer node's name;
 * - maps in_owner_node's name to out_anchor's index.
 * Returns silently on null input; logs a warning if the op rejects the update.
 */
void TransOpWithoutReshapeFusionPass::UpdateOutputName(const OutDataAnchorPtr &out_anchor,
                                                       const InDataAnchorPtr &old_peer_in_anchor,
                                                       const NodePtr &in_owner_node) {
  if (out_anchor == nullptr || old_peer_in_anchor == nullptr || in_owner_node == nullptr) {
    GELOGI("out_anchor or old_peer_in_anchor or in_owner_node is nullptr");
    return;
  }
  auto out_owner_node = out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  GE_CHECK_NOTNULL_JUST_RETURN(old_peer_in_anchor->GetOwnerNode());
  auto old_peer_in_name = old_peer_in_anchor->GetOwnerNode()->GetName();
  auto output_op = out_owner_node->GetOpDesc();
  GE_CHECK_NOTNULL_JUST_RETURN(output_op);
  auto output_names = output_op->GetAllOutputName();
  auto old_peer_in_name_iter = output_names.find(old_peer_in_name);
  if (old_peer_in_name_iter != output_names.end()) {
    output_names.erase(old_peer_in_name_iter);
  }
  output_names[in_owner_node->GetName()] = out_anchor->GetIdx();
  if (!output_op->UpdateOutputName(output_names)) {
    GELOGW("output_op UpdateOutputName failed");
  }
}

/**
 * In the input-name map of in_anchor's owner op:
 * - removes the entry keyed by the old peer node's name;
 * - maps out_owner_node's name to in_anchor's index.
 * Returns silently on null input; logs a warning if the op rejects the update.
 */
void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor,
                                                      const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node) {
  if (old_peer_out_anchor == nullptr || in_anchor == nullptr || out_owner_node == nullptr) {
    GELOGI("old_peer_out_anchor or in_anchor or out_owner_node is nullptr");
    return;
  }
  auto old_node = old_peer_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(old_node);
  auto old_peer_out_name = old_node->GetName();
  auto in_owner_node = in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  auto input_op = in_owner_node->GetOpDesc();
  GE_CHECK_NOTNULL_JUST_RETURN(input_op);
  auto input_names = input_op->GetAllInputName();
  auto old_peer_out_name_iter = input_names.find(old_peer_out_name);
  if (old_peer_out_name_iter != input_names.end()) {
    input_names.erase(old_peer_out_name_iter);
  }
  input_names[out_owner_node->GetName()] = in_anchor->GetIdx();
  // Previously the result was ignored; warn like UpdateOutputName does.
  if (!input_op->UpdateInputName(input_names)) {
    GELOGW("input_op UpdateInputName failed");
  }
}

/**
 * Re-creates the sub graph's internal control dependencies directly between the begin node and
 * the end node, when GetControlAnchors flagged any.
 * Adds a control->control edge if sub_graph_has_control_edge_ is set, and a data->control edge if
 * sub_graph_has_out_data_peer_in_control_edge_ is set.
 */
graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  auto out_anchor = begin_anchors_pair.first;
  GE_CHECK_NOTNULL(out_anchor);
  auto out_owner_node = out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(out_owner_node);
  auto in_anchor = end_anchors_pair.second;
  GE_CHECK_NOTNULL(in_anchor);
  auto in_owner_node = in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(in_owner_node);
  if (sub_graph_has_control_edge_[index]) {
    GELOGI("add control edge.src:%s, dst:%s", out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
    if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), in_owner_node->GetInControlAnchor()) !=
        GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
    GELOGI("add out data 2 in contorl edge.src:%s, dst:%s", out_owner_node->GetName().c_str(),
           in_owner_node->GetName().c_str());
    if (GraphUtils::AddEdge(out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}

/**
 * After the data edge begin -> end has been relinked, rewires the control edges that touched the
 * bypassed transops:
 * - the begin node now drives the recorded external in-control and in-data anchors;
 * - the recorded external out-control anchors now drive the end node.
 * Old edges are not removed here.
 */
graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  auto out_anchor = begin_anchors_pair.first;
  GE_CHECK_NOTNULL(out_anchor);
  auto out_owner_node = out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(out_owner_node);
  auto in_anchor = end_anchors_pair.second;
  GE_CHECK_NOTNULL(in_anchor);
  auto in_owner_node = in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(in_owner_node);
  // can not remove old control edge
  for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
    GE_CHECK_NOTNULL(peer_in_anchor);
    GELOGI("add control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
           peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
    if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
    GE_CHECK_NOTNULL(peer_out_anchor);
    GELOGI("add control edge.src:%s, src idx:%d, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
           peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str());
    if (GraphUtils::AddEdge(peer_out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
    GE_CHECK_NOTNULL(peer_in_anchor);
    GELOGI("add out control 2 in data edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
           peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
    if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
    GE_CHECK_NOTNULL(peer_in_anchor);
    GELOGI("add out data 2 in control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
           peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
    if (GraphUtils::AddEdge(out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}

/**
 * Bypasses a whole transop chain whose endpoint descs are equal.
 * Removes the last transop -> end edge and links begin_out -> end_in directly.
 * Then fixes the input/output name maps and relinks the control edges.
 */
graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
    const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  auto out_anchor = begin_anchors_pair.first;
  GE_CHECK_NOTNULL(out_anchor);
  auto out_owner_node = out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(out_owner_node);
  auto in_anchor = end_anchors_pair.second;
  GE_CHECK_NOTNULL(in_anchor);
  auto in_owner_node = in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(in_owner_node);
  // Previously dereferenced without any check.
  auto old_peer_out_anchor = end_anchors_pair.first;
  GE_CHECK_NOTNULL(old_peer_out_anchor);
  auto old_peer_out_node = old_peer_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(old_peer_out_node);

  GELOGI("remove edge.src %s, src idx:%d, dst:%s, dst idx:%d", old_peer_out_node->GetName().c_str(),
         old_peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(old_peer_out_anchor, in_anchor), "remove edge failed");

  GELOGI("relink node.src node:%s, src idx:%d, dst node:%s, dst idx:%d", out_owner_node->GetName().c_str(),
         out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  if (GraphUtils::AddEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "add edge failed!src:%s, src idx:%d, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
           out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
    return GRAPH_FAILED;
  }
  UpdateOutputName(out_anchor, begin_anchors_pair.second, in_owner_node);
  UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);

  return RelinkControlEdgesWhenDescNotChanged(begin_anchors_pair, end_anchors_pair, index);
}

/**
 * Creates a uniquely named TransData op desc converting format_trans_input_desc into
 * format_trans_output_desc.
 * Sets the input/output format attrs (numeric and serialized) and ATTR_NEED_COMPILE.
 * @return the new op desc, or nullptr (after logging) on any failure.
 */
OpDescPtr TransOpWithoutReshapeFusionPass::GetFormatTransferOp(const GeTensorDesc &format_trans_input_desc,
                                                               const GeTensorDesc &format_trans_output_desc) {
  static std::atomic_long atomic_fusion_format_transfer_op_count(1);
  auto fusion_format_transfer_op_count = atomic_fusion_format_transfer_op_count.fetch_add(1);

  std::stringstream format_transfer_op_name;
  format_transfer_op_name << "fusion_format_transfer_" << fusion_format_transfer_op_count;
  OpDescPtr format_transfer_op = MakeShared<OpDesc>(format_transfer_op_name.str().c_str(), TRANSDATA);
  if (format_transfer_op == nullptr) {
    GELOGE(INTERNAL_ERROR, "new format transfer op failed!");
    return nullptr;
  }

  GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_INPUT_FORMAT,
                                     static_cast<int64_t>(format_trans_input_desc.GetFormat())),
                  GELOGE(INTERNAL_ERROR, "set ATTR_NAME_INPUT_FORMAT failed");
                  return nullptr);
  GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_OUTPUT_FORMAT,
                                     static_cast<int64_t>(format_trans_output_desc.GetFormat())),
                  GELOGE(INTERNAL_ERROR, "set ATTR_NAME_OUTPUT_FORMAT failed");
                  return nullptr);

  string src_format = TypeUtils::FormatToSerialString(format_trans_input_desc.GetFormat());
  string dst_format = TypeUtils::FormatToSerialString(format_trans_output_desc.GetFormat());

  GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameSrcFormat, src_format),
                  GELOGE(INTERNAL_ERROR, "set kAttrNameSrcFormat failed");
                  return nullptr);
  GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameDstFormat, dst_format),
                  GELOGE(INTERNAL_ERROR, "set kAttrNameDstFormat failed");
                  return nullptr);

  GE_IF_BOOL_EXEC(format_transfer_op->AddInputDesc(format_trans_input_desc) != GRAPH_SUCCESS,
                  GELOGE(INTERNAL_ERROR, "add input desc failed");
                  return nullptr);
  GE_IF_BOOL_EXEC(format_transfer_op->AddOutputDesc(format_trans_output_desc) != GRAPH_SUCCESS,
                  GELOGE(INTERNAL_ERROR, "add output desc failed");
                  return nullptr);

  GE_IF_BOOL_EXEC(!ge::AttrUtils::SetBool(format_transfer_op, ATTR_NEED_COMPILE, true),
                  GELOGE(INTERNAL_ERROR, "set ext attr failed");
                  return nullptr);
  return format_transfer_op;
}

/**
 * Creates a uniquely named Cast op desc through OperatorFactory.
 * Input 0 gets cast_input_desc and output 0 gets cast_output_desc; each slot is added if the
 * factory op has none, otherwise updated.
 * Sets CAST_ATTR_DST_TYPE and ATTR_NEED_COMPILE.
 * @return the op desc, or nullptr (after logging) on any failure.
 */
OpDescPtr TransOpWithoutReshapeFusionPass::GetCastOp(const GeTensorDesc &cast_input_desc,
                                                     const GeTensorDesc &cast_output_desc) {
  static std::atomic_long atomic_fusion_cast_op_count(1);
  auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1);

  std::stringstream cast_op_name;
  cast_op_name << "fusion_cast_op_" << fusion_cast_op_count;
  auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST);
  auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  node_op.BreakConnect();
  if (cast_op == nullptr) {
    GELOGE(INTERNAL_ERROR, "new cast op failed!");
    return nullptr;
  }
  const int default_input_index = 0;
  const int default_output_index = 0;
  if (cast_op->GetInputsSize() == 0) {
    GE_IF_BOOL_EXEC(cast_op->AddInputDesc(cast_input_desc) != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "add input desc failed");
                    return nullptr);
  } else {
    GE_IF_BOOL_EXEC(cast_op->UpdateInputDesc(default_input_index, cast_input_desc) != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "update input desc failed");
                    return nullptr);
  }

  if (cast_op->GetOutputsSize() == 0) {
    GE_IF_BOOL_EXEC(cast_op->AddOutputDesc(cast_output_desc) != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "add output desc failed");
                    return nullptr);
  } else {
    GE_IF_BOOL_EXEC(cast_op->UpdateOutputDesc(default_output_index, cast_output_desc) != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "update output desc failed");
                    return nullptr);
  }

  if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, static_cast<int64_t>(cast_output_desc.GetDataType()))) {
    GELOGE(INTERNAL_ERROR, "set dst_type attr failed");
    return nullptr;
  }
  if (!AttrUtils::SetBool(cast_op, ATTR_NEED_COMPILE, true)) {
    GELOGE(INTERNAL_ERROR, "set need_compile attr failed");
    return nullptr;
  }
  return cast_op;
}

/**
 * Returns true when the Cast must come before the format transfer.
 * That is the case when the data types differ, the source is not DT_FLOAT16, and the destination
 * is DT_FLOAT16.
 */
bool TransOpWithoutReshapeFusionPass::InsertCastFirstCheck(const GeTensorDesc &out_desc,
                                                           const GeTensorDesc &in_desc) const {
  return out_desc.GetDataType() != in_desc.GetDataType() && out_desc.GetDataType() != DT_FLOAT16 &&
         in_desc.GetDataType() == DT_FLOAT16;
}

/**
 * Computes the TransData input/output descs.
 * If the Cast runs first, the transfer consumes out_desc's layout with in_desc's data type.
 * Otherwise it produces in_desc's layout with out_desc's data type.
 */
void TransOpWithoutReshapeFusionPass::GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
                                                            GeTensorDesc &format_transfer_input,
                                                            GeTensorDesc &format_transfer_output) {
  bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  if (insert_cast_first) {
    format_transfer_input = out_desc;
    format_transfer_input.SetDataType(in_desc.GetDataType());
    format_transfer_output = in_desc;
  } else {
    format_transfer_input = out_desc;
    format_transfer_output = in_desc;
    format_transfer_output.SetDataType(out_desc.GetDataType());
  }
}

/**
 * Computes the Cast input/output descs.
 * If the Cast runs first it works in out_desc's layout; otherwise it works in in_desc's layout.
 */
void TransOpWithoutReshapeFusionPass::GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
                                                    GeTensorDesc &cast_input, GeTensorDesc &cast_output) {
  bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  if (insert_cast_first) {
    cast_input = out_desc;
    cast_output = out_desc;
    cast_output.SetDataType(in_desc.GetDataType());
  } else {
    cast_input = in_desc;
    cast_input.SetDataType(out_desc.GetDataType());
    cast_output = in_desc;
  }
}

/**
 * Reads the two endpoint descs of the transop chain in sub graph `index`.
 * out_desc is the input desc of the first transop (front().second).
 * in_desc is the output desc of the last transop (back().first).
 * Leaves the outputs untouched if a pointer on the path is null.
 */
void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc,
                                                                  GeTensorDesc &in_desc) {
  const auto &nodes_anchor = sub_graph_anchors_[index];
  auto out_peer_anchor = nodes_anchor.front().second;
  GE_CHECK_NOTNULL_JUST_RETURN(out_peer_anchor);
  auto out_owner_node = out_peer_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  auto out_peer_op_desc = out_owner_node->GetOpDesc();
  GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return );
  out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());

  auto in_peer_anchor = nodes_anchor.back().first;
  GE_CHECK_NOTNULL_JUST_RETURN(in_peer_anchor);
  auto in_owner_node = in_peer_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  auto in_peer_op_desc = in_owner_node->GetOpDesc();
  GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return );
  in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
}

/**
 * Decides whether the format part of sub graph `index` can be fused.
 * - If the formats are equal but the shapes differ, or the formats differ and one of them is
 *   unsupported: the sub graph's transops are marked as remaining and fusion_continue stays false.
 * - If the formats differ and both are supported: a TransData op is created into
 *   format_transfer_op. fusion_op_count is incremented if the accuracy check passes, otherwise
 *   it is set to kInvalidFusionOpCount.
 * @return GRAPH_FAILED only if the TransData op cannot be created.
 */
graphStatus TransOpWithoutReshapeFusionPass::FormatFusion(const int index, OpDescPtr &format_transfer_op,
                                                          int32_t &fusion_op_count, bool &fusion_continue) {
  GeTensorDesc out_desc;
  GeTensorDesc in_desc;
  GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);

  GeTensorDesc format_transfer_input;
  GeTensorDesc format_transfer_output;
  GetFormatTransferDesc(out_desc, in_desc, format_transfer_input, format_transfer_output);

  if (out_desc.GetFormat() == in_desc.GetFormat() &&
      (!ShapeEqualCheck(out_desc.GetShape(), in_desc.GetShape()) ||
       !ShapeEqualCheck(out_desc.GetOriginShape(), in_desc.GetOriginShape()))) {
    SetRemainNode(sub_graph_anchors_[index]);
    return GRAPH_SUCCESS;
  }

  if (out_desc.GetFormat() != in_desc.GetFormat() && FusionFormatSupport(out_desc.GetFormat()) &&
      FusionFormatSupport(in_desc.GetFormat())) {
    // create format transop
    format_transfer_op = GetFormatTransferOp(format_transfer_input, format_transfer_output);
    if (format_transfer_op == nullptr) {
      return GRAPH_FAILED;
    }

    if (OpAccuracyAbilityCheck(format_transfer_op)) {
      ++fusion_op_count;
      GELOGI("support format transfer op %s", format_transfer_op->GetName().c_str());
    } else {
      GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
             format_transfer_input.GetFormat(), format_transfer_input.GetDataType(),
             format_transfer_output.GetFormat(), format_transfer_output.GetDataType());
      fusion_op_count = kInvalidFusionOpCount;
    }
  } else if (out_desc.GetFormat() != in_desc.GetFormat()) {
    SetRemainNode(sub_graph_anchors_[index]);
    return GRAPH_SUCCESS;
  }
  fusion_continue = true;
  return GRAPH_SUCCESS;
}

/**
 * If fusion is still valid and the endpoint data types differ, creates a Cast op into cast_op.
 * fusion_op_count is incremented if the accuracy check passes, otherwise it is set to
 * kInvalidFusionOpCount.
 * @return GRAPH_FAILED (with fusion_op_count invalidated) if the Cast op cannot be created.
 */
graphStatus TransOpWithoutReshapeFusionPass::DataTypeFusion(const int index, OpDescPtr &cast_op,
                                                            int32_t &fusion_op_count) {
  GeTensorDesc out_desc;
  GeTensorDesc in_desc;
  GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);

  GeTensorDesc cast_input;
  GeTensorDesc cast_output;
  GetCastOpDesc(out_desc, in_desc, cast_input, cast_output);

  if (fusion_op_count != kInvalidFusionOpCount && out_desc.GetDataType() != in_desc.GetDataType()) {
    // create cast op
    cast_op = GetCastOp(cast_input, cast_output);
    if (cast_op == nullptr) {
      fusion_op_count = kInvalidFusionOpCount;
      return GRAPH_FAILED;
    }

    if (OpAccuracyAbilityCheck(cast_op)) {
      ++fusion_op_count;
      GELOGI("support cast op %s. src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
             cast_op->GetName().c_str(), cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(),
             cast_output.GetDataType());
    } else {
      GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
             cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(), cast_output.GetDataType());
      fusion_op_count = kInvalidFusionOpCount;
    }
  }
  return GRAPH_SUCCESS;
}

/**
 * Fuses sub graph `index` whose endpoint descs differ.
 * The new ops (TransData and/or Cast) are inserted only when they are all valid and fewer than
 * the transops they replace. Otherwise every transop of the sub graph is marked as remaining.
 * @return GRAPH_FAILED only if inserting the new ops fails.
 */
graphStatus TransOpWithoutReshapeFusionPass::TransOpFuseHandle(const ComputeGraphPtr &graph, const int index) {
  bool fusion_continue = false;
  OpDescPtr format_transfer_op = nullptr;
  int32_t fusion_op_count = 0;
  auto format_fusion_ret = FormatFusion(index, format_transfer_op, fusion_op_count, fusion_continue);
  if (format_fusion_ret != GRAPH_SUCCESS || !fusion_continue) {
    SetRemainNode(sub_graph_anchors_[index]);
    return GRAPH_SUCCESS;
  }

  OpDescPtr cast_op = nullptr;
  if (DataTypeFusion(index, cast_op, fusion_op_count) != GRAPH_SUCCESS) {
    SetRemainNode(sub_graph_anchors_[index]);
    return GRAPH_SUCCESS;
  }

  if (fusion_op_count != kInvalidFusionOpCount && fusion_op_count < transop_num_count_[index]) {
    GeTensorDesc out_desc;
    GeTensorDesc in_desc;
    GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
    bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
    if (InsertNewTransOp(graph, cast_op, format_transfer_op, index, insert_cast_first) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  } else {
    // remain all nodes
    SetRemainNode(sub_graph_anchors_[index]);
  }
  return GRAPH_SUCCESS;
}

/**
 * Removes from `graph` every transop of the non-flagged sub graphs that is not marked with
 * kRemainNode. Each removed node is marked first, so a node shared by several sub graphs is only
 * removed once.
 */
void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &graph) {
  if (graph == nullptr) {
    return;
  }
  for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
    if (sub_graph_has_reshape_node_[i]) {
      continue;
    }

    for (const auto &node : sub_graph_nodes_[i]) {
      GE_CHECK_NOTNULL_JUST_RETURN(node);
      // remove nodes
      if (!IsTransOp(node)) {
        continue;
      }

      auto op_desc = node->GetOpDesc();
      GE_CHECK_NOTNULL_JUST_RETURN(op_desc);
      bool node_remain_flag = op_desc->TryGetExtAttr(kRemainNode, false);
      if (node_remain_flag) {
        continue;
      }

      GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
      GELOGI("remove node:%s", node->GetName().c_str());
      if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
        GELOGW("remove node failed!node:%s", node->GetName().c_str());
        continue;
      }
    }
  }
}

/**
 * Pass entry point.
 * For each non-transop, known-shape direct node and each of its out-data anchors, it:
 * - collects the transop chains that start at the anchor;
 * - filters out unfusable chains;
 * - records the chains' control edges;
 * - fuses the rest.
 * @return GRAPH_SUCCESS, GRAPH_FAILED if TransOpFuse fails, or the GE_CHECK_NOTNULL error for a
 *         null node or anchor.
 */
graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
  GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin.");
  if (graph == nullptr) {
    return GRAPH_SUCCESS;
  }

  for (const auto &node : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (IsTransOp(node)) {
      continue;
    }
    bool is_unknown = false;
    auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
    if (ret != GRAPH_SUCCESS) {
      GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
             node->GetType().c_str());
      continue;
    }
    if (is_unknown) {
      GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
             node->GetType().c_str());
      continue;
    }
    GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_anchor);
      vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
      vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> nodes_list;
      if (GetSubGraphsBetweenNormalNode(out_anchor, sub_graph_anchors, nodes_list) != GRAPH_SUCCESS) {
        GELOGW("get transops failed!");
        continue;
      }

      sub_graph_anchors_.swap(sub_graph_anchors);
      EraseInvalidAnchorsPair();
      if (sub_graph_anchors_.empty()) {
        continue;
      }

      // check reshape node
      if (GetSubGraphNodesInfo() != GRAPH_SUCCESS) {
        continue;
      }

      // save control edge
      GetControlAnchors();

      if (TransOpFuse(graph) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }
  }
  GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end.");
  return GRAPH_SUCCESS;
}

/**
 * Returns true when both descs are non-null and have equal format, data type, shape and origin
 * shape.
 */
bool TransOpWithoutReshapeFusionPass::DescEqualCheck(ConstGeTensorDescPtr &desc_src,
                                                     ConstGeTensorDescPtr &desc_dst) const {
  if (desc_src == nullptr || desc_dst == nullptr) {
    return false;
  }
  if (desc_src->GetFormat() != desc_dst->GetFormat() || desc_src->GetDataType() != desc_dst->GetDataType()) {
    return false;
  }

  if (!ShapeEqualCheck(desc_src->GetShape(), desc_dst->GetShape())) {
    return false;
  }

  return ShapeEqualCheck(desc_src->GetOriginShape(), desc_dst->GetOriginShape());
}

/**
 * Returns true when both shapes have the same rank and identical dims.
 * GetDims() is fetched once per shape; the previous loop re-fetched it on every iteration.
 */
bool TransOpWithoutReshapeFusionPass::ShapeEqualCheck(const GeShape &src, const GeShape &dst) const {
  const auto src_dims = src.GetDims();
  const auto dst_dims = dst.GetDims();
  return src_dims == dst_dims;
}

/**
 * Processes every non-flagged sub graph.
 * - Endpoint descs equal (and format supported): the chain is bypassed with a direct edge.
 * - Otherwise: TransOpFuseHandle tries to replace it with new ops.
 * Afterwards the transops that are no longer used are removed.
 */
graphStatus TransOpWithoutReshapeFusionPass::TransOpFuse(const ComputeGraphPtr &graph) {
  for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
    if (sub_graph_has_reshape_node_[i]) {
      continue;
    }

    const auto &nodes_anchor = sub_graph_anchors_[i];
    auto out_anchor = nodes_anchor.front().first;
    GE_CHECK_NOTNULL(out_anchor);
    auto out_owner_node = out_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(out_owner_node);
    auto out_op_desc = out_owner_node->GetOpDesc();
    GE_CHECK_NOTNULL(out_op_desc);
    auto out_desc = out_op_desc->GetOutputDescPtr(out_anchor->GetIdx());
    GE_CHECK_NOTNULL(out_desc);
    auto in_anchor = nodes_anchor.back().second;
    GE_CHECK_NOTNULL(in_anchor);
    auto in_owner_node = in_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(in_owner_node);
    auto in_op_desc = in_owner_node->GetOpDesc();
    GE_CHECK_NOTNULL(in_op_desc);
    auto in_desc = in_op_desc->GetInputDescPtr(in_anchor->GetIdx());
    GE_CHECK_NOTNULL(in_desc);
    if (FusionFormatSupport(out_desc->GetFormat()) && DescEqualCheck(out_desc, in_desc)) {
      // relink begin_out to end_in
      if (RelinkNodesWhenDescNotChanged(nodes_anchor.front(), nodes_anchor.back(), static_cast<int>(i)) !=
          GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    } else {
      if (TransOpFuseHandle(graph, static_cast<int>(i)) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }
  }
  RemoveNousedNodes(graph);
  return GRAPH_SUCCESS;
}

/**
 * Adds `transop` to `graph` and stores the new node in trans_node.
 * A null graph or null op is a no-op success, and trans_node is left untouched.
 */
graphStatus TransOpWithoutReshapeFusionPass::AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop,
                                                          NodePtr &trans_node) {
  if (graph == nullptr) {
    return GRAPH_SUCCESS;
  }
  if (transop == nullptr) {
    return GRAPH_SUCCESS;
  }

  trans_node = graph->AddNode(transop);
  if (trans_node == nullptr) {
    GELOGE(GRAPH_FAILED, "add node failed!");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

/**
 * Adds the (optional) TransData and Cast ops to the graph.
 * Appends the created nodes to new_trans_nodes in execution order: Cast first if
 * insert_cast_first is set, otherwise TransData first.
 */
graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
                                                          const OpDescPtr &format_transfer_op,
                                                          const bool insert_cast_first,
                                                          std::vector<NodePtr> &new_trans_nodes) {
  NodePtr format_transfer_node;
  if (AddTransNode(graph, format_transfer_op, format_transfer_node) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  NodePtr cast_node;
  if (AddTransNode(graph, cast_op, cast_node) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  if (insert_cast_first) {
    if (cast_node != nullptr) {
      new_trans_nodes.push_back(cast_node);
    }
    if (format_transfer_node != nullptr) {
      new_trans_nodes.push_back(format_transfer_node);
    }
  } else {
    if (format_transfer_node != nullptr) {
      new_trans_nodes.push_back(format_transfer_node);
    }
    if (cast_node != nullptr) {
      new_trans_nodes.push_back(cast_node);
    }
  }
  return GRAPH_SUCCESS;
}

/**
 * Replaces the transop chain of sub graph `index` with the new nodes:
 * - removes the last transop -> end edge;
 * - links begin -> first new node -> (second new node) -> end;
 * - fixes the name maps and relinks the control edges.
 * Does nothing if no new node was created.
 */
graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
                                                              const OpDescPtr &format_transfer_op, const int index,
                                                              const bool insert_cast_first) {
  std::vector<NodePtr> new_trans_nodes;
  if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  if (new_trans_nodes.empty()) {
    GELOGI("No new trans node. Do not need insert new transop.");
    return GRAPH_SUCCESS;
  }

  pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
  pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
  auto out_anchor = begin_out.first;
  GE_CHECK_NOTNULL(out_anchor);
  auto out_owner_node = out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(out_owner_node);
  auto in_anchor = end_in.second;
  GE_CHECK_NOTNULL(in_anchor);
  auto in_owner_node = in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(in_owner_node);
  // Previously dereferenced without any check.
  auto old_peer_out_anchor = end_in.first;
  GE_CHECK_NOTNULL(old_peer_out_anchor);
  auto old_peer_out_node = old_peer_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(old_peer_out_node);

  GELOGI("remove edge.src:%s, src idx:%d, dst:%s, dst idx:%d", old_peer_out_node->GetName().c_str(),
         old_peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(old_peer_out_anchor, in_anchor), "remove edge failed");

  GELOGI("add edge.src:%s, src idx:%d, dst:%s", out_owner_node->GetName().c_str(), out_anchor->GetIdx(),
         new_trans_nodes.front()->GetName().c_str());
  if (GraphUtils::AddEdge(out_anchor, new_trans_nodes.front()->GetInAnchor(0)) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  } else {
    auto old_peer_in_anchor = begin_out.second;
    GE_CHECK_NOTNULL(old_peer_in_anchor);
    UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  }

  if (new_trans_nodes.size() > 1) {
    GELOGI("add edge.src:%s, dst:%s", new_trans_nodes.front()->GetName().c_str(),
           new_trans_nodes.back()->GetName().c_str());
    if (GraphUtils::AddEdge(new_trans_nodes.front()->GetOutAnchor(0), new_trans_nodes.back()->GetInAnchor(0)) !=
        GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    } else {
      UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
    }
  }

  GELOGI("add edge.src:%s, dst:%s, dst idx:%d", new_trans_nodes.back()->GetName().c_str(),
         in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutAnchor(0), in_anchor) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  return RelinkControlEdge(index, out_anchor, new_trans_nodes);
}

graphStatus
TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor, const vector &new_trans_nodes) { GE_CHECK_NOTNULL(out_anchor); if (new_trans_nodes.front() == nullptr || new_trans_nodes.back() == nullptr) { return GRAPH_FAILED; } if (sub_graph_has_control_edge_[index]) { GELOGI("add control edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(), new_trans_nodes.front()->GetName().c_str()); if (GraphUtils::AddEdge(out_anchor->GetOwnerNode()->GetOutControlAnchor(), new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) { GE_CHECK_NOTNULL(peer_in_anchor); GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetName().c_str()); if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) { GE_CHECK_NOTNULL(peer_out_anchor); GELOGI("add control edge.src:%s, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(), new_trans_nodes.front()->GetName().c_str()); if (GraphUtils::AddEdge(peer_out_anchor, new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) { GE_CHECK_NOTNULL(peer_in_anchor); GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetName().c_str()); if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) { GE_CHECK_NOTNULL(peer_in_anchor); GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(), 
peer_in_anchor->GetOwnerNode()->GetName().c_str()); if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), peer_in_anchor) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } if (sub_graph_has_out_data_peer_in_control_edge_[index]) { auto in_anchor = sub_graph_anchors_[index].back().second; GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } return GRAPH_SUCCESS; } bool TransOpWithoutReshapeFusionPass::OpAccuracyAbilityCheck(const OpDescPtr &op_desc) { auto instance = GELib::GetInstance(); if ((instance == nullptr) || (!instance->InitFlag())) { GELOGW("GELib is not initialized!"); return false; } if (op_desc == nullptr) { return false; } OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); vector op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); if (op_infos.empty()) { GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str()); return false; } std::string unsupported_reason; for (const auto &it : op_infos) { auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); auto &kernel_name = it.opKernelLib; auto kernel_info_store = kernel_map.find(kernel_name); if (kernel_info_store != kernel_map.end()) { if (kernel_info_store->second != nullptr && kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); GELOGI("Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(), op_desc->GetName().c_str()); return true; } } } GELOGI("op %s CheckAccuracySupported failed!reason:%s", op_desc->GetType().c_str(), unsupported_reason.c_str()); return false; } bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) { return 
format == FORMAT_NCHW || format == FORMAT_NHWC || format == FORMAT_FRACTAL_Z || format == FORMAT_NC1HWC0; } graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode( const OutDataAnchorPtr &out_anchor, std::vector>> &sub_graphs_out, vector> &nodes_list) { graphStatus ret = GRAPH_SUCCESS; if (out_anchor == nullptr) { return GRAPH_FAILED; } for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr || peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { continue; } nodes_list.emplace_back(out_anchor, peer_in_anchor); auto peer_in_node = peer_in_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(peer_in_node); if (!IsTransOp(peer_in_node)) { sub_graphs_out.push_back(nodes_list); nodes_list.pop_back(); } else { for (const auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) { ret = GetSubGraphsBetweenNormalNode(peer_out_anchor, sub_graphs_out, nodes_list); if (ret != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "get all transops between normal node failed!node:%s", peer_in_node->GetName().c_str()); return GRAPH_FAILED; } } nodes_list.pop_back(); } } return GRAPH_SUCCESS; } bool TransOpWithoutReshapeFusionPass::IsTransOp(const NodePtr &node) { // The caller guarantees that the pointer is not null. return node->GetType() == CAST || node->GetType() == RESHAPE || node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED || node->GetType() == TRANSDATA; } } // namespace ge