|
|
|
@ -43,7 +43,7 @@ const char *const kAttrNameDstFormat = "dst_format";
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::SetRemainNode(
|
|
|
|
|
const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
|
|
|
|
|
const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
|
|
|
|
|
auto iter = nodes_anchor.begin();
|
|
|
|
|
while (iter != nodes_anchor.end()) {
|
|
|
|
|
auto in_anchor = iter->second;
|
|
|
|
@ -63,7 +63,8 @@ void TransOpWithoutReshapeFusionPass::SetRemainNode(
|
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return);
|
|
|
|
|
GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
|
|
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -158,7 +159,7 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
|
|
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
|
|
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
|
|
|
|
|
// The caller guarantees that the index is legal.
|
|
|
|
|
for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) {
|
|
|
|
|
auto nodes_anchor = sub_graph_anchors_[index][j];
|
|
|
|
@ -181,9 +182,9 @@ void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
|
|
|
|
|
const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
|
|
|
|
|
const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
|
|
|
|
|
// The caller guarantees that the index is legal.
|
|
|
|
|
for (size_t j = 1; j < sub_graph_nodes_[index].size(); ++j) {
|
|
|
|
|
for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) {
|
|
|
|
|
auto node = sub_graph_nodes_[index][j];
|
|
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(node);
|
|
|
|
|
auto in_control_anchor = node->GetInControlAnchor();
|
|
|
|
@ -208,8 +209,8 @@ void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
|
|
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
|
|
|
|
|
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
|
|
|
|
|
const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
|
|
|
|
|
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
|
|
|
|
|
for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) {
|
|
|
|
|
auto node = sub_graph_nodes_[index][j];
|
|
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(node);
|
|
|
|
@ -335,8 +336,8 @@ void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &ol
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
auto out_anchor = begin_anchors_pair.first;
|
|
|
|
|
GE_CHECK_NOTNULL(out_anchor);
|
|
|
|
|
auto out_owner_node = out_anchor->GetOwnerNode();
|
|
|
|
@ -364,8 +365,8 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
|
|
|
|
|
return GRAPH_FAILED;
|
|
|
|
|
}
|
|
|
|
@ -418,8 +419,8 @@ graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChange
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
|
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
|
|
|
|
|
auto out_anchor = begin_anchors_pair.first;
|
|
|
|
|
GE_CHECK_NOTNULL(out_anchor);
|
|
|
|
|
auto out_owner_node = out_anchor->GetOwnerNode();
|
|
|
|
@ -581,7 +582,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde
|
|
|
|
|
auto out_owner_node = out_peer_anchor->GetOwnerNode();
|
|
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
|
|
|
|
|
auto out_peer_op_desc = out_owner_node->GetOpDesc();
|
|
|
|
|
GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return);
|
|
|
|
|
GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return );
|
|
|
|
|
out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());
|
|
|
|
|
|
|
|
|
|
auto in_peer_anchor = nodes_anchor.back().first;
|
|
|
|
@ -589,7 +590,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde
|
|
|
|
|
auto in_owner_node = in_peer_anchor->GetOwnerNode();
|
|
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
|
|
|
|
|
auto in_peer_op_desc = in_owner_node->GetOpDesc();
|
|
|
|
|
GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return);
|
|
|
|
|
GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return );
|
|
|
|
|
in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -721,7 +722,7 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return);
|
|
|
|
|
GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
|
|
|
|
|
GELOGI("remove node:%s", node->GetName().c_str());
|
|
|
|
|
if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
|
|
|
|
|
GELOGW("remove node failed!node:%s", node->GetName().c_str());
|
|
|
|
@ -743,7 +744,7 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
|
|
|
|
|
if (IsTransOp(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
|
|
|
|
|
GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
|
|
|
|
|
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
|
|
|
|
|
GE_CHECK_NOTNULL(out_anchor);
|
|
|
|
|
vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
|
|
|
|
@ -887,11 +888,6 @@ graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr
|
|
|
|
|
new_trans_nodes.push_back(cast_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (new_trans_nodes.empty()) {
|
|
|
|
|
GELOGE(GRAPH_FAILED, "no new transop!this should not happen!");
|
|
|
|
|
return GRAPH_FAILED;
|
|
|
|
|
}
|
|
|
|
|
return GRAPH_SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -902,6 +898,10 @@ graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraph
|
|
|
|
|
if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
|
|
|
|
|
return GRAPH_FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (new_trans_nodes.empty()) {
|
|
|
|
|
GELOGI("No new trans node. Do not need insert new transop.");
|
|
|
|
|
return GRAPH_SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
|
|
|
|
|
pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
|
|
|
|
@ -1051,9 +1051,8 @@ bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode(
|
|
|
|
|
const OutDataAnchorPtr &out_anchor,
|
|
|
|
|
std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
|
|
|
|
|
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
|
|
|
|
|
const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
|
|
|
|
|
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
|
|
|
|
|
graphStatus ret = GRAPH_SUCCESS;
|
|
|
|
|
if (out_anchor == nullptr) {
|
|
|
|
|
return GRAPH_FAILED;
|
|
|
|
|