You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
792 lines
35 KiB
792 lines
35 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "graph/passes/same_transdata_breadth_fusion_pass.h"
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
#include "common/ge_inner_error_codes.h"
|
|
#include "common/types.h"
|
|
#include "graph/debug/ge_attr_define.h"
|
|
#include "graph/utils/graph_utils.h"
|
|
#include "graph/utils/op_desc_utils.h"
|
|
#include "init/gelib.h"
|
|
|
|
namespace {
// Size of a "before transdata" node list that holds only the head node, i.e. no trans op
// lies between the head node and the transdata node.
// Unsigned so that it compares against container sizes without a sign conversion.
constexpr size_t kNoTransOp = 1;
}  // namespace
|
|
|
|
namespace ge {
|
|
void SameTransdataBreadthFusionPass::GetSubGraphNodesInfo() {
|
|
vector<vector<NodePtr>> before_transdata_nodes(sub_graph_anchors_.size());
|
|
vector<pair<int, InDataAnchorPtr>> all_transdata_nodes;
|
|
for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
|
|
auto nodes_anchor = sub_graph_anchors_[i];
|
|
auto iter = nodes_anchor.begin();
|
|
auto first_out_anchor = iter->first;
|
|
GE_CHECK_NOTNULL_JUST_RETURN(first_out_anchor);
|
|
before_transdata_nodes[i].push_back(first_out_anchor->GetOwnerNode());
|
|
GELOGD("index:%zu, node:%s, type:%s", i, first_out_anchor->GetOwnerNode()->GetName().c_str(),
|
|
first_out_anchor->GetOwnerNode()->GetType().c_str());
|
|
while (iter != nodes_anchor.end()) {
|
|
auto in_anchor = iter->second;
|
|
GE_CHECK_NOTNULL_JUST_RETURN(in_anchor);
|
|
auto in_node = in_anchor->GetOwnerNode();
|
|
GELOGD("index:%zu, node:%s, type:%s", i, first_out_anchor->GetOwnerNode()->GetName().c_str(),
|
|
first_out_anchor->GetOwnerNode()->GetType().c_str());
|
|
if (in_node->GetType() == TRANSDATA) {
|
|
all_transdata_nodes.emplace_back(i, in_anchor);
|
|
} else {
|
|
before_transdata_nodes[i].push_back(in_node);
|
|
}
|
|
++iter;
|
|
}
|
|
GELOGD("index:%zu, before trandata node size:%zu", i, before_transdata_nodes[i].size());
|
|
}
|
|
before_transdata_nodes_.swap(before_transdata_nodes);
|
|
all_transdata_nodes_.swap(all_transdata_nodes);
|
|
}
|
|
|
|
// Creates the op desc of a new CAST operator named "fusion_cast_<n>".
//
// in_desc / out_desc become input 0 / output 0 of the cast: they are added when the created op
// has no input / output yet and written over index 0 otherwise. A failure to add or update a
// desc is only logged as a warning. The destination data type attribute is taken from out_desc.
// Returns nullptr when the op desc cannot be created or the attribute cannot be set.
OpDescPtr SameTransdataBreadthFusionPass::GetCastOp(const GeTensorDesc &in_desc, const GeTensorDesc &out_desc) {
  // Counter shared by all calls; each call consumes one value for the op name.
  static std::atomic_long atomic_fusion_cast_op_count(1);
  const auto cast_serial = atomic_fusion_cast_op_count.fetch_add(1);
  const std::string cast_name = "fusion_cast_" + std::to_string(cast_serial);

  auto cast_operator = ge::OperatorFactory::CreateOperator(cast_name.c_str(), CAST);
  auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(cast_operator);
  cast_operator.BreakConnect();
  if (cast_op == nullptr) {
    GELOGE(INTERNAL_ERROR, "new fusion cast op failed!");
    return nullptr;
  }

  const int first_input_index = 0;
  const int first_output_index = 0;

  const bool has_input = cast_op->GetInputsSize() != 0;
  if (has_input) {
    if (cast_op->UpdateInputDesc(first_input_index, in_desc) != GRAPH_SUCCESS) {
      GELOGW("UpdateInputDesc fail");
    }
  } else if (cast_op->AddInputDesc(in_desc) != GRAPH_SUCCESS) {
    GELOGW("AddInputDesc fail.");
  }

  const bool has_output = cast_op->GetOutputsSize() != 0;
  if (has_output) {
    if (cast_op->UpdateOutputDesc(first_output_index, out_desc) != GRAPH_SUCCESS) {
      GELOGW("UpdateOutputDesc fail");
    }
  } else if (cast_op->AddOutputDesc(out_desc) != GRAPH_SUCCESS) {
    GELOGW("AddOutputDesc fail.");
  }

  const auto dst_type = static_cast<int64_t>(out_desc.GetDataType());
  if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, dst_type)) {
    GELOGE(INTERNAL_ERROR, "set dst_type attr failed");
    return nullptr;
  }
  return cast_op;
}
|
|
|
|
void SameTransdataBreadthFusionPass::InsertSameTransdataNodeIndex(int anchors_index,
|
|
vector<int> &same_transdata_nodes) {
|
|
auto same_iter = same_transdata_nodes.begin();
|
|
while (same_iter != same_transdata_nodes.end()) {
|
|
if (before_transdata_nodes_[anchors_index].size() <= before_transdata_nodes_[*same_iter].size()) {
|
|
same_transdata_nodes.insert(same_iter, anchors_index);
|
|
return;
|
|
}
|
|
++same_iter;
|
|
}
|
|
|
|
same_transdata_nodes.push_back(anchors_index);
|
|
}
|
|
|
|
std::set<std::string> SameTransdataBreadthFusionPass::GetInControlIdentityNodes(const NodePtr &node,
|
|
int subgraph_index) {
|
|
std::set<std::string> in_node_names;
|
|
for (const auto &in_node : node->GetInControlNodes()) {
|
|
if (in_node->GetType() == IDENTITY) {
|
|
in_node_names.insert(in_node->GetName());
|
|
}
|
|
}
|
|
for (const auto &subgraph_node : before_transdata_nodes_[subgraph_index]) {
|
|
for (const auto &in_node : subgraph_node->GetInControlNodes()) {
|
|
if (in_node->GetType() == IDENTITY) {
|
|
in_node_names.insert(in_node->GetName());
|
|
}
|
|
}
|
|
}
|
|
GELOGD("control in nodes for %s(%d): %zu", node->GetName().c_str(), subgraph_index, in_node_names.size());
|
|
return in_node_names;
|
|
}
|
|
|
|
// Takes the first pending entry of all_transdata_nodes_ as the reference transdata and collects
// the sub graph indexes of all pending transdata nodes that are equivalent to it.
//
// Two transdata nodes are treated as equivalent when their stream label, input format, output
// format and set of IDENTITY control inputs are equal. The reference index and every equivalent
// index are put into same_transdata_nodes (ordered by InsertSameTransdataNodeIndex) and the
// matching entries are removed from all_transdata_nodes_.
//
// The reference entry is always removed from all_transdata_nodes_, even when the function
// returns early on a null pointer, so every call on a non-empty list shrinks the list.
void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector<int> &same_transdata_nodes) {
  if (all_transdata_nodes_.empty()) {
    return;
  }

  // Detach the reference entry before any check can return: callers loop until the pending
  // list is exhausted and would spin forever if a failed check left the list unchanged.
  const int compare_index = all_transdata_nodes_.begin()->first;
  auto node_for_compare_in_anchor = all_transdata_nodes_.begin()->second;
  (void)all_transdata_nodes_.erase(all_transdata_nodes_.begin());
  same_transdata_nodes.push_back(compare_index);

  GE_CHECK_NOTNULL_JUST_RETURN(node_for_compare_in_anchor);
  auto node_for_compare = node_for_compare_in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL_JUST_RETURN(node_for_compare);

  // Get op-desc, input/output desc, in-control-edges-from-identity, as the compare-key
  auto op_desc_for_compare = node_for_compare->GetOpDesc();
  GE_CHECK_NOTNULL_JUST_RETURN(op_desc_for_compare);
  string op_compare_stream_label;
  (void)AttrUtils::GetStr(op_desc_for_compare, ATTR_NAME_STREAM_LABEL, op_compare_stream_label);
  auto op_compare_in_ctrl_nodes = GetInControlIdentityNodes(node_for_compare, compare_index);
  auto input_desc_for_compare = op_desc_for_compare->GetInputDescPtr(node_for_compare_in_anchor->GetIdx());
  GE_CHECK_NOTNULL_JUST_RETURN(input_desc_for_compare);
  auto output_desc_for_compare = op_desc_for_compare->GetOutputDescPtr(0);
  GE_CHECK_NOTNULL_JUST_RETURN(output_desc_for_compare);

  auto iter = all_transdata_nodes_.begin();
  while (iter != all_transdata_nodes_.end()) {
    auto in_anchor = iter->second;
    if (in_anchor == nullptr) {
      // Skip the entry but advance, otherwise the same entry is examined forever.
      ++iter;
      continue;
    }
    auto node_tmp = in_anchor->GetOwnerNode();
    if (node_tmp == node_for_compare) {
      // Same transdata node as the reference one; it is left in the pending list.
      ++iter;
      continue;
    }
    GE_CHECK_NOTNULL_JUST_RETURN(node_tmp);
    auto op_desc_tmp = node_tmp->GetOpDesc();
    GE_CHECK_NOTNULL_JUST_RETURN(op_desc_tmp);
    auto input_desc_tmp = op_desc_tmp->GetInputDescPtr(in_anchor->GetIdx());
    GE_CHECK_NOTNULL_JUST_RETURN(input_desc_tmp);
    auto output_desc_tmp = op_desc_tmp->GetOutputDescPtr(0);
    GE_CHECK_NOTNULL_JUST_RETURN(output_desc_tmp);
    string op_tmp_stream_label;
    (void)AttrUtils::GetStr(op_desc_tmp, ATTR_NAME_STREAM_LABEL, op_tmp_stream_label);
    auto op_tmp_in_ctrl_nodes = GetInControlIdentityNodes(node_tmp, iter->first);

    const bool is_same_transdata = (op_compare_stream_label == op_tmp_stream_label) &&
                                   (input_desc_tmp->GetFormat() == input_desc_for_compare->GetFormat()) &&
                                   (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat()) &&
                                   (op_compare_in_ctrl_nodes == op_tmp_in_ctrl_nodes);
    if (is_same_transdata) {
      GELOGD("same transdata node:%s, src node:%s", node_tmp->GetName().c_str(), node_for_compare->GetName().c_str());
      InsertSameTransdataNodeIndex(iter->first, same_transdata_nodes);
      // erase() returns the iterator of the next entry, so no extra increment is needed.
      iter = all_transdata_nodes_.erase(iter);
    } else {
      ++iter;
    }
  }
}
|
|
|
|
// Moves every data edge that leaves `transdata_node` so that it starts at `pre_out_anchor`.
//
// Edges whose destination node is `relink_node` are left untouched.
// Returns GRAPH_FAILED as soon as removing or adding an edge fails, GRAPH_SUCCESS otherwise.
graphStatus SameTransdataBreadthFusionPass::ReLinkDataOutput2PreNode(const NodePtr &transdata_node,
                                                                     const OutDataAnchorPtr &pre_out_anchor,
                                                                     const NodePtr &relink_node) {
  GE_CHECK_NOTNULL(pre_out_anchor);
  GE_CHECK_NOTNULL(transdata_node);
  for (const auto &out_anchor : transdata_node->GetAllOutDataAnchors()) {
    // relink data edge
    for (const auto &transdata_peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      if (transdata_peer_in_anchor->GetOwnerNode() == relink_node) {
        continue;
      }
      GELOGI("remove edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
             transdata_peer_in_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::RemoveEdge(out_anchor, transdata_peer_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "remove edge failed!src node:%s, dst node:%s", transdata_node->GetName().c_str(),
               transdata_peer_in_anchor->GetOwnerNode()->GetName().c_str());
        return GRAPH_FAILED;
      }
      GELOGI("add edge.src:%s, dst:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(),
             transdata_peer_in_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::AddEdge(pre_out_anchor, transdata_peer_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "add edge failed!src node:%s, dst node:%s",
               pre_out_anchor->GetOwnerNode()->GetName().c_str(),
               transdata_peer_in_anchor->GetOwnerNode()->GetName().c_str());
        return GRAPH_FAILED;
      }
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Re-sources the edges that run from the out data anchors of `transdata_node` to in control
// anchors of other nodes.
//
// Each such edge is removed and re-created from the out control anchor of the node owning
// `pre_out_anchor`; when that node has no out control anchor the edge starts at `pre_out_anchor`.
// Returns GRAPH_FAILED on the first edge operation that fails, GRAPH_SUCCESS otherwise.
graphStatus SameTransdataBreadthFusionPass::ReLinkOutDataPeerInControlNodes2PreNode(
    const NodePtr &transdata_node, const OutDataAnchorPtr &pre_out_anchor) {
  GE_CHECK_NOTNULL(pre_out_anchor);
  GE_CHECK_NOTNULL(transdata_node);
  const auto pre_node_out_ctrl_anchor = pre_out_anchor->GetOwnerNode()->GetOutControlAnchor();
  for (const auto &out_anchor : transdata_node->GetAllOutDataAnchors()) {
    for (const auto &peer_in_ctrl_anchor : out_anchor->GetPeerInControlAnchors()) {
      GELOGD("remove edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
             peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::RemoveEdge(out_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "remove edge failed!src node:%s, dst node:%s", transdata_node->GetName().c_str(),
               peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
        return GRAPH_FAILED;
      }

      if (pre_node_out_ctrl_anchor != nullptr) {
        GELOGD("add edge.src node:%s, dst node:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(),
               peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
        if (GraphUtils::AddEdge(pre_node_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
          GELOGE(GRAPH_FAILED, "add edge failed!src node:%s, dst node:%s",
                 pre_out_anchor->GetOwnerNode()->GetName().c_str(),
                 peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
          return GRAPH_FAILED;
        }
        continue;
      }

      // No out control anchor on the pre node: fall back to its out data anchor.
      GELOGD("add edge.src:%s, dst:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(),
             peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::AddEdge(pre_out_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "add edge failed!src node:%s, dst node:%s",
               pre_out_anchor->GetOwnerNode()->GetName().c_str(),
               peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
        return GRAPH_FAILED;
      }
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Moves all outgoing edges of `transdata_node` to the node behind `pre_out_anchor`, in three
// steps: data edges (skipping those into `relink_node`), edges from out data anchors to in
// control anchors, and finally the edges of the out control anchor.
// Returns GRAPH_FAILED if one of the first two steps fails, else the status of the last step.
graphStatus SameTransdataBreadthFusionPass::ReLinkTransdataOutput2PreNode(const NodePtr &transdata_node,
                                                                          const OutDataAnchorPtr &pre_out_anchor,
                                                                          const NodePtr &relink_node) {
  GE_CHECK_NOTNULL(pre_out_anchor);
  // The second step only runs when the first one succeeded.
  const bool data_outputs_relinked =
      (ReLinkDataOutput2PreNode(transdata_node, pre_out_anchor, relink_node) == GRAPH_SUCCESS) &&
      (ReLinkOutDataPeerInControlNodes2PreNode(transdata_node, pre_out_anchor) == GRAPH_SUCCESS);
  if (!data_outputs_relinked) {
    return GRAPH_FAILED;
  }

  const auto pre_node_out_ctrl_anchor = pre_out_anchor->GetOwnerNode()->GetOutControlAnchor();
  return ReLinkTransdataControlOutput2PreNode(transdata_node, pre_out_anchor, pre_node_out_ctrl_anchor);
}
|
|
|
|
// Re-sources the control edges that run from the out control anchor of `transdata_node_keep`
// to in control anchors of other nodes.
//
// Every such edge is removed and re-created from `transdata_peer_out_control_anchor`, or from
// `pre_out_anchor` when that control anchor is null. A node without out control anchor is a
// no-op. Returns GRAPH_FAILED on the first edge operation that fails.
graphStatus SameTransdataBreadthFusionPass::ReLinkOutControlPeerInControlAnchors(
    const NodePtr &transdata_node_keep, const OutDataAnchorPtr &pre_out_anchor,
    const OutControlAnchorPtr &transdata_peer_out_control_anchor) {
  GE_CHECK_NOTNULL(transdata_node_keep);
  GE_CHECK_NOTNULL(pre_out_anchor);
  const auto keep_out_ctrl_anchor = transdata_node_keep->GetOutControlAnchor();
  if (keep_out_ctrl_anchor == nullptr) {
    return GRAPH_SUCCESS;
  }

  for (const auto &peer_in_ctrl_anchor : keep_out_ctrl_anchor->GetPeerInControlAnchors()) {
    GELOGD("remove edge.src:%s, dst:%s", transdata_node_keep->GetName().c_str(),
           peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
    if (GraphUtils::RemoveEdge(keep_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "remove transdata control edge failed!");
      return GRAPH_FAILED;
    }

    if (transdata_peer_out_control_anchor != nullptr) {
      GELOGD("add edge.src:%s, dst:%s", transdata_peer_out_control_anchor->GetOwnerNode()->GetName().c_str(),
             peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::AddEdge(transdata_peer_out_control_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "add control edge failed!");
        return GRAPH_FAILED;
      }
      continue;
    }

    // No control anchor to start from: use the data anchor of the pre node instead.
    GELOGD("add edge.src:%s, dst:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(),
           peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
    if (GraphUtils::AddEdge(pre_out_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "add control edge failed!");
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Re-sources the edges that run from the out control anchor of `transdata_node_keep` to in
// data anchors of other nodes.
//
// Every such edge is removed and re-created from `transdata_peer_out_control_anchor`, or from
// `pre_out_anchor` when that control anchor is null. Null peers and peers without owner node
// are skipped; a node without out control anchor is a no-op.
// Returns GRAPH_FAILED on the first edge operation that fails.
graphStatus SameTransdataBreadthFusionPass::ReLinkOutControlPeerInDataAnchors(
    const NodePtr &transdata_node_keep, const OutDataAnchorPtr &pre_out_anchor,
    const OutControlAnchorPtr &transdata_peer_out_control_anchor) {
  GE_CHECK_NOTNULL(transdata_node_keep);
  GE_CHECK_NOTNULL(pre_out_anchor);
  const auto keep_out_ctrl_anchor = transdata_node_keep->GetOutControlAnchor();
  if (keep_out_ctrl_anchor == nullptr) {
    return GRAPH_SUCCESS;
  }

  for (const auto &peer_in_data_anchor : keep_out_ctrl_anchor->GetPeerInDataAnchors()) {
    const bool peer_usable = (peer_in_data_anchor != nullptr) && (peer_in_data_anchor->GetOwnerNode() != nullptr);
    if (!peer_usable) {
      continue;
    }
    GELOGD("remove edge.src:%s, dst:%s", transdata_node_keep->GetName().c_str(),
           peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
    if (GraphUtils::RemoveEdge(keep_out_ctrl_anchor, peer_in_data_anchor) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "remove transdata control edge failed!");
      return GRAPH_FAILED;
    }

    if (transdata_peer_out_control_anchor != nullptr) {
      GELOGD("add edge.src:%s, dst:%s", transdata_peer_out_control_anchor->GetOwnerNode()->GetName().c_str(),
             peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
      if (GraphUtils::AddEdge(transdata_peer_out_control_anchor, peer_in_data_anchor) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "add control edge failed!");
        return GRAPH_FAILED;
      }
      continue;
    }

    // No control anchor to start from: use the data anchor of the pre node instead.
    GELOGD("add edge.src:%s, dst:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(),
           peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
    if (GraphUtils::AddEdge(pre_out_anchor, peer_in_data_anchor) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "add control edge failed!");
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Re-sources everything that hangs on the out control anchor of `transdata_node_keep`:
// first its peer in control anchors, then its peer in data anchors.
// Returns GRAPH_FAILED if the first step fails, otherwise the status of the second step.
graphStatus SameTransdataBreadthFusionPass::ReLinkTransdataControlOutput2PreNode(
    const NodePtr &transdata_node_keep, const OutDataAnchorPtr &pre_out_anchor,
    const OutControlAnchorPtr &transdata_peer_out_control_anchor) {
  const graphStatus ctrl_peers_status =
      ReLinkOutControlPeerInControlAnchors(transdata_node_keep, pre_out_anchor, transdata_peer_out_control_anchor);
  if (ctrl_peers_status != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return ReLinkOutControlPeerInDataAnchors(transdata_node_keep, pre_out_anchor, transdata_peer_out_control_anchor);
}
|
|
|
|
// Pass entry point.
//
// For every direct node of `graph` that is not a trans op and has more than one data consumer,
// and for each of its out data anchors, this gathers the anchor chains that lead from the anchor
// through CAST/TRANSDATA nodes and, when more than one chain exists, tries to fuse equivalent
// transdata nodes. A null graph is accepted and returns GRAPH_SUCCESS.
// Returns GRAPH_FAILED only when the fusion itself (ExtractTransNode) fails.
graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) {
  GELOGI("[SameTransdataBreadthFusionPass]: optimize begin.");
  if (graph == nullptr) {
    return GRAPH_SUCCESS;
  }

  using AnchorChain = std::vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>;
  for (auto &node : graph->GetDirectNode()) {
    const bool is_fusion_root = !IsTransOp(node) && (node->GetOutDataNodesSize() > 1);
    if (!is_fusion_root) {
      continue;
    }

    GELOGD("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
    for (auto &out_anchor : node->GetAllOutDataAnchors()) {
      vector<AnchorChain> found_chains;
      AnchorChain walk_stack;
      if (GetSubGraphsBetweenNormalAndTransdataNode(out_anchor, found_chains, walk_stack) != GRAPH_SUCCESS) {
        GELOGW("get transop failed!");
        continue;
      }

      // A single chain leaves nothing to fuse.
      if (found_chains.size() <= 1) {
        continue;
      }
      sub_graph_anchors_.swap(found_chains);

      // check reshape node
      GetSubGraphNodesInfo();
      GELOGD("all trandata node size:%zu", all_transdata_nodes_.size());
      if (ExtractTransNode(graph) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }
  }

  GELOGI("[SameTransdataBreadthFusionPass]: Optimize success.");
  return GRAPH_SUCCESS;
}
|
|
|
|
// Repeatedly groups equivalent transdata nodes out of all_transdata_nodes_ and fuses each group
// that has more than one member.
//
// Within a group the transdata of the first index (same_transdata_nodes[0]) is kept; the other
// members are re-linked behind it by LinkNewCastNode2RemainTransdata. When the kept chain has
// nodes in front of the transdata, RelinkRemainTransdata is run first.
// Returns GRAPH_FAILED when a re-link step fails, GRAPH_SUCCESS otherwise.
graphStatus SameTransdataBreadthFusionPass::ExtractTransNode(const ComputeGraphPtr &graph) {
  while (all_transdata_nodes_.size() > 1) {
    const size_t pending_before = all_transdata_nodes_.size();
    vector<int> same_transdata_nodes;
    GetSameTransdataNode(same_transdata_nodes);
    GELOGD("same transdata node size:%zu", same_transdata_nodes.size());
    // reuse transdata ,new cast
    if (same_transdata_nodes.size() <= 1) {
      if (all_transdata_nodes_.size() >= pending_before) {
        // Nothing was consumed from the pending list: looping on would never terminate.
        GELOGW("no transdata node was consumed, stop fusion. remaining size:%zu", all_transdata_nodes_.size());
        break;
      }
      continue;
    }

    int anchors_index = same_transdata_nodes[0];
    auto transdata_in_anchor = sub_graph_anchors_[anchors_index].back().second;
    GE_CHECK_NOTNULL(transdata_in_anchor);
    auto transdata_node_keep = transdata_in_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(transdata_node_keep);
    auto transdata_out_anchor = transdata_node_keep->GetOutDataAnchor(0);
    GELOGD("anchor index %d, before transdata node size:%zu", anchors_index,
           before_transdata_nodes_[anchors_index].size());
    // More than the head node in front of the kept transdata: move the transdata up first.
    if (before_transdata_nodes_[anchors_index].size() > 1) {
      if (RelinkRemainTransdata(graph, same_transdata_nodes) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }

    if (LinkNewCastNode2RemainTransdata(graph, same_transdata_nodes, transdata_out_anchor, transdata_node_keep) !=
        GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Moves the kept transdata (last node of chain same_transdata_nodes[0]) directly behind the
// head node of that chain.
//
// The transdata's data types are aligned with the head output, its input edge is re-sourced to
// the head out anchor, and its consumers are re-linked: either behind the existing chain nodes
// (when every one of them has a single data input/output) or behind newly created cast nodes.
// Returns GRAPH_FAILED when adding the head edge or any re-link step fails.
graphStatus SameTransdataBreadthFusionPass::RelinkRemainTransdata(const ComputeGraphPtr &graph,
                                                                  const vector<int> &same_transdata_nodes) {
  const int anchors_index = same_transdata_nodes[0];
  auto head_node_anchor = sub_graph_anchors_[anchors_index][0].first;
  GE_CHECK_NOTNULL(head_node_anchor);
  auto head_node = head_node_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(head_node->GetOpDesc());
  auto head_output_desc = head_node->GetOpDesc()->GetOutputDescPtr(head_node_anchor->GetIdx());

  auto transdata_in_anchor = sub_graph_anchors_[anchors_index].back().second;
  GE_CHECK_NOTNULL(transdata_in_anchor);
  auto transdata_node_keep = transdata_in_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(transdata_node_keep->GetOpDesc());
  auto transdata_out_anchor = transdata_node_keep->GetOutDataAnchor(0);
  GELOGD("head node:%s, transdata node keep:%s", head_node->GetName().c_str(), transdata_node_keep->GetName().c_str());

  // Must be evaluated before any edge is changed below.
  const bool can_reuse_chain = AllNodeBeforeTransdataHasOneDataOut(anchors_index);
  UpdateTransdataDesc(transdata_in_anchor, transdata_node_keep->GetOpDesc(), head_output_desc);

  auto transdata_peer_out_anchor = sub_graph_anchors_[anchors_index].back().first;
  GE_CHECK_NOTNULL(transdata_peer_out_anchor);
  auto former_producer = transdata_peer_out_anchor->GetOwnerNode();
  GELOGI("remove edge.src:%s, dst:%s", former_producer->GetName().c_str(), transdata_node_keep->GetName().c_str());
  if (GraphUtils::RemoveEdge(transdata_peer_out_anchor, transdata_in_anchor) != GRAPH_SUCCESS) {
    // Best effort: a failed removal is only reported.
    GELOGW("remove edge failed!out node %s, in node %s", former_producer->GetName().c_str(),
           transdata_node_keep->GetName().c_str());
  }

  GELOGI("add edge.out node %s, in node %s", head_node->GetName().c_str(), transdata_node_keep->GetName().c_str());
  if (GraphUtils::AddEdge(head_node_anchor, transdata_in_anchor) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "add edge failed!out node %s, in node %s", head_node->GetName().c_str(),
           transdata_node_keep->GetName().c_str());
    return GRAPH_FAILED;
  }

  // relink to transdata output nodes
  NodePtr relink_node;
  OutDataAnchorPtr pre_out_anchor = transdata_peer_out_anchor;
  if (can_reuse_chain) {
    if (ReuseNodesBeforeTransdata(anchors_index, transdata_out_anchor, relink_node) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  } else {
    // AddCastNode advances pre_out_anchor to the output of the last cast it created.
    pre_out_anchor = transdata_out_anchor;
    if (AddCastNode(graph, anchors_index, pre_out_anchor, relink_node) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  if (ReLinkTransdataOutput2PreNode(transdata_node_keep, pre_out_anchor, relink_node) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
void SameTransdataBreadthFusionPass::UpdateTransdataDesc(const InDataAnchorPtr &transdata_in_anchor,
|
|
const OpDescPtr &transdata_op_desc,
|
|
const ConstGeTensorDescPtr &head_output_desc) {
|
|
if (transdata_op_desc == nullptr || transdata_in_anchor == nullptr || head_output_desc == nullptr) {
|
|
return;
|
|
}
|
|
auto mutable_input_desc = transdata_op_desc->MutableInputDesc(transdata_in_anchor->GetIdx());
|
|
GE_CHECK_NOTNULL_JUST_RETURN(mutable_input_desc);
|
|
mutable_input_desc->SetDataType(head_output_desc->GetDataType());
|
|
mutable_input_desc->SetOriginDataType(head_output_desc->GetOriginDataType());
|
|
auto mutable_output_desc = transdata_op_desc->MutableOutputDesc(0);
|
|
GE_CHECK_NOTNULL_JUST_RETURN(mutable_output_desc);
|
|
mutable_output_desc->SetDataType(head_output_desc->GetDataType());
|
|
mutable_output_desc->SetOriginDataType(head_output_desc->GetOriginDataType());
|
|
// maybe need to check support
|
|
}
|
|
|
|
// Tells whether every node of before_transdata_nodes_[anchors_index] after the first one is
// non-null and has at most one data input node and at most one data output node.
// The first entry (the head node) is not inspected.
bool SameTransdataBreadthFusionPass::AllNodeBeforeTransdataHasOneDataOut(int anchors_index) {
  const auto &chain_nodes = before_transdata_nodes_[anchors_index];
  for (size_t pos = 1; pos < chain_nodes.size(); ++pos) {
    const auto &chain_node = chain_nodes[pos];
    if (chain_node == nullptr) {
      return false;
    }
    const bool single_in_and_out =
        (chain_node->GetOutDataNodes().size() <= 1) && (chain_node->GetInDataNodes().size() <= 1);
    if (!single_in_and_out) {
      return false;
    }
  }
  return true;
}
|
|
|
|
// Hooks the existing nodes of chain `anchors_index` behind the kept transdata.
//
// The edge from the chain's head out anchor to its first node is removed (if still linked).
// When the chain has nodes in front of its transdata, the first of them is connected to
// `transdata_out_anchor` and reported through `relink_node`; every chain node except the last
// then gets format, shape and real dim count of the transdata output on the traversed input
// and output desc, and its in control edges are moved to the kept transdata node.
// When the chain consists of the head node only, nothing else is done and `relink_node` is
// left unchanged.
// Returns GRAPH_SUCCESS, or a failure status when a check, edge or desc update fails.
graphStatus SameTransdataBreadthFusionPass::ReuseNodesBeforeTransdata(int anchors_index,
                                                                      const OutDataAnchorPtr &transdata_out_anchor,
                                                                      NodePtr &relink_node) {
  auto head_node_anchor = sub_graph_anchors_[anchors_index][0].first;
  auto head_node_peer_anchor = sub_graph_anchors_[anchors_index][0].second;
  GE_CHECK_NOTNULL(head_node_anchor);
  GE_CHECK_NOTNULL(head_node_peer_anchor);
  GE_CHECK_NOTNULL(transdata_out_anchor);
  GELOGI("remove edge.src:%s, dst:%s", head_node_anchor->GetOwnerNode()->GetName().c_str(),
         head_node_peer_anchor->GetOwnerNode()->GetName().c_str());
  if (head_node_anchor->IsLinkedWith(head_node_peer_anchor)) {
    if (GraphUtils::RemoveEdge(head_node_anchor, head_node_peer_anchor) != GRAPH_SUCCESS) {
      GELOGW("remove edge failed!src:%s, dst:%s", head_node_anchor->GetOwnerNode()->GetName().c_str(),
             head_node_peer_anchor->GetOwnerNode()->GetName().c_str());
    }
  } else {
    GELOGW("edge not link now. src:%s, dst:%s", head_node_anchor->GetOwnerNode()->GetName().c_str(),
           head_node_peer_anchor->GetOwnerNode()->GetName().c_str());
  }

  NodePtr transdata_node_keep = transdata_out_anchor->GetOwnerNode();
  // Only the head node in front of the transdata: there is no chain node to reuse.
  if (before_transdata_nodes_[anchors_index].size() == static_cast<size_t>(kNoTransOp)) {
    return GRAPH_SUCCESS;
  }
  GELOGI("add edge.src:%s, dst:%s", transdata_node_keep->GetName().c_str(),
         head_node_peer_anchor->GetOwnerNode()->GetName().c_str());
  if (GraphUtils::AddEdge(transdata_out_anchor, head_node_peer_anchor) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "add edge.src:%s, dst:%s", transdata_node_keep->GetName().c_str(),
           head_node_peer_anchor->GetOwnerNode()->GetName().c_str());
    return GRAPH_FAILED;
  }
  relink_node = head_node_peer_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(transdata_node_keep->GetOpDesc());
  auto transdata_output_desc = transdata_node_keep->GetOpDesc()->GetOutputDescPtr(0);
  GE_CHECK_NOTNULL(transdata_output_desc);

  const auto &chain = sub_graph_anchors_[anchors_index];
  // The last pair ends in the transdata itself, so it is excluded.
  for (size_t i = 0; i + 1 < chain.size(); ++i) {
    auto in_data_anchor = chain[i].second;
    GE_CHECK_NOTNULL(in_data_anchor);
    auto in_owner_node = in_data_anchor->GetOwnerNode();
    auto in_op_desc = in_owner_node->GetOpDesc();
    GE_CHECK_NOTNULL(in_op_desc);
    auto input_desc = in_op_desc->GetInputDesc(in_data_anchor->GetIdx());
    CopyTensorDesc(transdata_output_desc, input_desc);
    if (in_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), input_desc) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "UpdateInputDesc fail.");
      return GRAPH_FAILED;
    }

    // The out anchor of the next pair belongs to this node and selects the traversed output.
    auto next_out_anchor = chain[i + 1].first;
    GE_CHECK_NOTNULL(next_out_anchor);
    int output_idx = next_out_anchor->GetIdx();
    auto output_desc = in_op_desc->GetOutputDesc(output_idx);
    CopyTensorDesc(transdata_output_desc, output_desc);
    if (in_op_desc->UpdateOutputDesc(output_idx, output_desc) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "update output desc failed");
      return GRAPH_FAILED;
    }
    // relink control edge
    if (RelinkInControlEdge(in_owner_node, transdata_node_keep) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
void SameTransdataBreadthFusionPass::CopyTensorDesc(const ConstGeTensorDescPtr &src_desc, GeTensorDesc &dst_desc) {
|
|
if (src_desc == nullptr) {
|
|
return;
|
|
}
|
|
dst_desc.SetFormat(src_desc->GetFormat());
|
|
dst_desc.SetOriginFormat(src_desc->GetOriginFormat());
|
|
dst_desc.SetShape(src_desc->GetShape());
|
|
dst_desc.SetOriginShape(src_desc->GetOriginShape());
|
|
uint32_t real_dim = 0;
|
|
if (TensorUtils::GetRealDimCnt(*src_desc, real_dim) == GRAPH_SUCCESS) {
|
|
TensorUtils::SetRealDimCnt(dst_desc, real_dim);
|
|
}
|
|
}
|
|
|
|
// Replaces every transdata of the group except the first one by the kept transdata.
//
// For each index after same_transdata_nodes[0]: the transdata at the end of that chain is cut
// from its producer, the chain in front of it is re-attached behind `transdata_out_anchor`
// (reusing the chain nodes when each has a single data input/output, otherwise through new
// cast nodes), the transdata's outputs and in control edges are moved, and the transdata node
// is removed from `graph`. Transdata nodes with more than one data input node are skipped.
// Returns GRAPH_FAILED on the first step that fails.
graphStatus SameTransdataBreadthFusionPass::LinkNewCastNode2RemainTransdata(
    const ComputeGraphPtr &graph, const vector<int> &same_transdata_nodes, const OutDataAnchorPtr &transdata_out_anchor,
    const NodePtr &transdata_node_keep) {
  for (size_t i = 1; i < same_transdata_nodes.size(); ++i) {
    const int anchors_index = same_transdata_nodes[i];
    // Evaluated before the edges of this chain are modified.
    const bool can_reuse_chain = AllNodeBeforeTransdataHasOneDataOut(anchors_index);
    auto transdata_peer_out_anchor = sub_graph_anchors_[anchors_index].back().first;
    GE_CHECK_NOTNULL(transdata_peer_out_anchor);
    auto transdata_remove_in_anchor = sub_graph_anchors_[anchors_index].back().second;
    GE_CHECK_NOTNULL(transdata_remove_in_anchor);
    auto transdata_node_remove = transdata_remove_in_anchor->GetOwnerNode();
    if (transdata_node_remove->GetInDataNodes().size() > 1) {
      continue;
    }
    GELOGI("remove edge.src:%s, dst:%s", transdata_peer_out_anchor->GetOwnerNode()->GetName().c_str(),
           transdata_remove_in_anchor->GetOwnerNode()->GetName().c_str());
    if (GraphUtils::RemoveEdge(transdata_peer_out_anchor, transdata_remove_in_anchor) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }

    NodePtr relink_node = nullptr;
    OutDataAnchorPtr pre_out_anchor = nullptr;
    if (!can_reuse_chain) {
      // miss cast control edge
      pre_out_anchor = transdata_out_anchor;
      if (AddCastNode(graph, anchors_index, pre_out_anchor, relink_node) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    } else {
      // reuse nodes before transdata
      if (ReuseNodesBeforeTransdata(anchors_index, transdata_out_anchor, relink_node) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
      // With chain nodes in front, consumers move behind the chain's last node; otherwise
      // they hang directly on the kept transdata.
      const bool has_chain_nodes = before_transdata_nodes_[anchors_index].size() > kNoTransOp;
      if (has_chain_nodes) {
        pre_out_anchor = transdata_peer_out_anchor;
      } else {
        pre_out_anchor = transdata_out_anchor;
      }
    }

    if (ReLinkTransdataOutput2PreNode(transdata_node_remove, pre_out_anchor, relink_node) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }

    if (RelinkInControlEdge(transdata_node_remove, transdata_node_keep) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }

    if (graph->RemoveNode(transdata_node_remove) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "remove node %s failed!", transdata_node_remove->GetName().c_str());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Moves all incoming control edges of `node_src` to the in control anchor of `node_dst`.
// A source node without control input nodes is a no-op.
// Returns GRAPH_FAILED on the first edge that cannot be removed or added.
graphStatus SameTransdataBreadthFusionPass::RelinkInControlEdge(const NodePtr &node_src, const NodePtr &node_dst) {
  GE_CHECK_NOTNULL(node_dst);
  GE_CHECK_NOTNULL(node_src);
  const bool has_ctrl_inputs = !node_src->GetInControlNodes().empty();
  if (!has_ctrl_inputs) {
    return GRAPH_SUCCESS;
  }
  GE_CHECK_NOTNULL(node_src->GetInControlAnchor());
  for (const auto &ctrl_producer_anchor : node_src->GetInControlAnchor()->GetPeerOutControlAnchors()) {
    GELOGD("remove edge.src:%s, dst:%s", ctrl_producer_anchor->GetOwnerNode()->GetName().c_str(),
           node_src->GetName().c_str());
    const auto remove_status = GraphUtils::RemoveEdge(ctrl_producer_anchor, node_src->GetInControlAnchor());
    if (remove_status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "remove edge faliled!src:%s, dst:%s",
             ctrl_producer_anchor->GetOwnerNode()->GetName().c_str(), node_src->GetName().c_str());
      return GRAPH_FAILED;
    }

    GELOGD("add edge.src:%s, dst:%s", ctrl_producer_anchor->GetOwnerNode()->GetName().c_str(),
           node_dst->GetName().c_str());
    const auto add_status = GraphUtils::AddEdge(ctrl_producer_anchor, node_dst->GetInControlAnchor());
    if (add_status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "add edge failed!src:%s, dst:%s", ctrl_producer_anchor->GetOwnerNode()->GetName().c_str(),
             node_dst->GetName().c_str());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Appends one new cast node behind `pre_out_anchor` for every pair of chain `anchors_index`
// except the last one.
//
// Each cast takes the input/output desc of the corresponding chain node, with format, shape
// and real dim count replaced by those of the output that `pre_out_anchor` referred to on
// entry. The casts are linked one after another; on return `pre_out_anchor` is the out anchor
// of the last cast created and `first_link_node` is the first cast created (both unchanged
// when no cast was created).
// Returns a failure status when a check, the cast creation, the edge or the attribute fails.
graphStatus SameTransdataBreadthFusionPass::AddCastNode(const ComputeGraphPtr &graph, int anchors_index,
                                                        OutDataAnchorPtr &pre_out_anchor, NodePtr &first_link_node) {
  GE_CHECK_NOTNULL(pre_out_anchor);
  GE_CHECK_NOTNULL(graph);
  auto pre_node = pre_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(pre_node->GetOpDesc());
  auto pre_output_desc = pre_node->GetOpDesc()->GetOutputDescPtr(pre_out_anchor->GetIdx());
  GE_CHECK_NOTNULL(pre_output_desc);

  // Stamps layout information of the initial pre output onto a chain node desc.
  const auto apply_pre_layout = [&pre_output_desc](GeTensorDesc &desc, uint32_t real_dim_cnt) {
    desc.SetFormat(pre_output_desc->GetFormat());
    desc.SetOriginFormat(pre_output_desc->GetOriginFormat());
    desc.SetShape(pre_output_desc->GetShape());
    desc.SetOriginShape(pre_output_desc->GetOriginShape());
    TensorUtils::SetRealDimCnt(desc, real_dim_cnt);
  };

  for (size_t i = 0; i < sub_graph_anchors_[anchors_index].size() - 1; ++i) {
    auto in_data_anchor = sub_graph_anchors_[anchors_index][i].second;
    GE_CHECK_NOTNULL(in_data_anchor);
    auto in_owner_node = in_data_anchor->GetOwnerNode();
    auto in_op_desc = in_owner_node->GetOpDesc();
    GE_CHECK_NOTNULL(in_op_desc);

    // A failed read keeps the count at 0, which is still written to both descs.
    uint32_t real_dim = 0;
    auto input_desc = in_op_desc->GetInputDesc(in_data_anchor->GetIdx());
    if (TensorUtils::GetRealDimCnt(*pre_output_desc, real_dim) != GRAPH_SUCCESS) {
      GELOGW("get %s real dim cnt failed!", pre_node->GetName().c_str());
    }
    apply_pre_layout(input_desc, real_dim);
    auto output_desc = in_op_desc->GetOutputDesc(sub_graph_anchors_[anchors_index][i + 1].first->GetIdx());
    apply_pre_layout(output_desc, real_dim);

    auto cast_op_desc = GetCastOp(input_desc, output_desc);
    if (cast_op_desc == nullptr) {
      return GRAPH_FAILED;
    }

    auto cast_node = graph->AddNode(cast_op_desc);
    if (cast_node == nullptr) {
      return GRAPH_FAILED;
    }
    GELOGD("add edge.src:%s, dst:%s", pre_out_anchor->GetOwnerNode()->GetName().c_str(), cast_node->GetName().c_str());
    if (GraphUtils::AddEdge(pre_out_anchor, cast_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
    if (i == 0) {
      first_link_node = cast_node;
    }

    if (!AttrUtils::SetBool(cast_op_desc, ATTR_NEED_COMPILE, true)) {
      GELOGE(FAILED, "SetExtAttr fail.");
      return FAILED;
    }
    // The next cast (if any) is chained behind this one.
    pre_out_anchor = cast_node->GetOutDataAnchor(0);
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
// Depth-first walk from `out_anchor` through CAST nodes that records every anchor chain.
//
// `nodes_list` is the chain walked so far. For each consumer of `out_anchor`:
//   - a TRANSDATA with data output nodes, or any node that is neither CAST nor TRANSDATA,
//     ends the chain: a copy of the chain including this consumer goes to `sub_graphs_out`;
//   - a TRANSDATA without data output nodes is dropped;
//   - a CAST is descended into through all of its out data anchors.
// `nodes_list` has its entry size restored when the function returns GRAPH_SUCCESS.
// Returns GRAPH_FAILED for a null `out_anchor` or when a nested walk fails.
graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdataNode(
    OutDataAnchorPtr &out_anchor,
    std::vector<std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
    std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
  if (out_anchor == nullptr) {
    GELOGE(GRAPH_FAILED, "out data anchor is null!This should not happen!");
    return GRAPH_FAILED;
  }

  for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
    const bool peer_usable = (peer_in_anchor != nullptr) && (peer_in_anchor->GetOwnerNode() != nullptr) &&
                             (peer_in_anchor->GetOwnerNode()->GetOpDesc() != nullptr);
    if (!peer_usable) {
      continue;
    }

    nodes_list.emplace_back(out_anchor, peer_in_anchor);
    auto peer_in_node = peer_in_anchor->GetOwnerNode();
    const bool is_transdata = peer_in_node->GetType() == TRANSDATA;
    const bool is_consumed_transdata = is_transdata && (peer_in_node->GetOutDataNodes().size() > 0);

    if (is_consumed_transdata || !IsHandleOp(peer_in_node)) {
      // End of a chain: store it and step back.
      sub_graphs_out.push_back(nodes_list);
      nodes_list.pop_back();
      continue;
    }

    if (is_transdata) {
      // Reaching here means the transdata has no data output node; it is ignored.
      nodes_list.pop_back();
      continue;
    }

    for (auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
      if (GetSubGraphsBetweenNormalAndTransdataNode(peer_out_anchor, sub_graphs_out, nodes_list) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "get all transop between normal node failed!node:%s", peer_in_node->GetName().c_str());
        return GRAPH_FAILED;
      }
    }
    nodes_list.pop_back();
  }
  return GRAPH_SUCCESS;
}
|
|
|
|
bool SameTransdataBreadthFusionPass::IsTransOp(const NodePtr &node) {
|
|
if (node == nullptr) {
|
|
return false;
|
|
}
|
|
return node->GetType() == CAST || node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED ||
|
|
node->GetType() == RESHAPE || node->GetType() == TRANSDATA;
|
|
}
|
|
|
|
bool SameTransdataBreadthFusionPass::IsHandleOp(const NodePtr &node) {
|
|
if (node == nullptr) {
|
|
return false;
|
|
}
|
|
return node->GetType() == CAST || node->GetType() == TRANSDATA;
|
|
}
|
|
} // namespace ge
|