You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/transop_depth_fusion_pass.cc

320 lines
14 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/transop_depth_fusion_pass.h"
#include <algorithm>
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/compute_graph.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/utils/graph_utils.h"
#include "graph/common/transop_util.h"
#include "graph/utils/node_utils.h"
namespace ge {
/// Entry point of the pass. For every non-trans node of graph, each data consumer of each of its
/// outputs is handed to RecursiveInDepth, which removes / offsets chains of trans ops.
/// A failure in one branch is logged and the pass moves on to the next branch; the pass returns
/// GRAPH_SUCCESS unless GE_CHECK_NOTNULL meets a null node or anchor. A null graph is a no-op.
graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) {
  GELOGI("[TransOpDepthFusionPass]: optimize in depth begin...");
  if (graph == nullptr) {
    return GRAPH_SUCCESS;
  }
  // RecursiveInDepth pushes/pops trans_op_ and offset_op_ symmetrically only on its success paths;
  // its early error returns can leave stale entries behind. Start from empty stacks and drop any
  // leftovers after a failed branch, so an unrelated branch is never compared against (or relinked
  // to) a node that is not its neighbor.
  auto reset_stacks = [this]() {
    while (!trans_op_.empty()) {
      trans_op_.pop();
    }
    while (!offset_op_.empty()) {
      offset_op_.pop();
    }
  };
  reset_stacks();
  for (const auto &node : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (TransOpUtil::IsTransOp(node)) {
      continue;
    }
    GELOGD("Current normal node is: %s, type: %s, begin in-depth recursive", node->GetName().c_str(),
           node->GetType().c_str());
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_anchor);
      for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
        if (RecursiveInDepth(peer_in_anchor, graph) != GRAPH_SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Recursive failed, root node is: %s, type: %s", node->GetName().c_str(),
                 node->GetType().c_str());
          // Best effort: keep optimizing other branches, but without the half-processed state.
          reset_stacks();
        }
      }
    }
  }
  GELOGI("[TransOpDepthFusionPass]: Optimize in depth success...");
  return GRAPH_SUCCESS;
}
/// @@ Method:
/// A depth-first recursive strategy is used to traverse all the trans ops.
/// Two trans ops are offset when the back one's output desc is consistent with
/// its former neighbor's input desc (see DescAreSymmetry).
/// @@ Limitation:
/// The current method only judges direct neighbors. Trans ops separated by
/// other ops which can't be offset are not taken into account.
/// @@ Recursive depth
/// To ensure that the stack does not overflow, the search stops once the recursion
/// depth reaches max_recursive_depth (20); that many chained trans ops is treated as
/// abnormal and the rest of the branch is left as is.
/// @param dst_in_anchor  input anchor of the node visited by this call
/// @param graph          graph owning the nodes; nodes are removed from it in place
/// @return GRAPH_SUCCESS, or a failure status when a null anchor/node is met or a graph edit fails.
///         On failure trans_op_ / offset_op_ may be left unbalanced and must be reset by the caller.
graphStatus TransOpDepthFusionPass::RecursiveInDepth(const InDataAnchorPtr &dst_in_anchor,
                                                     const ge::ComputeGraphPtr &graph) {
  static unsigned int temp_depth = 0;
  static const unsigned int max_recursive_depth = 20;
  // temp_depth is shared by all calls (function-local static). The guard increments it on entry and
  // decrements it on EVERY exit, including the early returns hidden inside GE_CHECK_NOTNULL and
  // GE_CHK_STATUS_RET, which previously leaked one level each and shrank the depth budget for good.
  struct DepthGuard {
    explicit DepthGuard(unsigned int &depth) : depth_(depth) { ++depth_; }
    ~DepthGuard() { --depth_; }
    DepthGuard(const DepthGuard &) = delete;
    DepthGuard &operator=(const DepthGuard &) = delete;
    unsigned int &depth_;
  };
  const DepthGuard depth_guard(temp_depth);
  if (temp_depth >= max_recursive_depth) {
    GELOGI(
        "Caution: recursive depth has become %u. "
        "It's abnormal to have so many trans ops between two normal ops. "
        "Please check your graph in detail! "
        "The search terminates here and continues to another branch.",
        temp_depth);
    return GRAPH_SUCCESS;
  }
  if (dst_in_anchor == nullptr || dst_in_anchor->GetOwnerNode() == nullptr ||
      dst_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
    GELOGE(FAILED, "parameter is null.");
    return GRAPH_FAILED;
  }
  auto node = dst_in_anchor->GetOwnerNode();
  if (!TransOpUtil::IsTransOp(node) || dst_in_anchor->GetIdx() != TransOpUtil::GetTransOpDataIndex(node)) {
    // A normal op, or a trans op reached through something other than its data input, ends the branch.
    GELOGD("Now the end of this branch, node: %s, type: %s, recursive depth: %u", node->GetName().c_str(),
           node->GetType().c_str(), temp_depth);
  } else if (CheckNodeCanBeDeleted(node)) {
    GELOGD("node: %s, type: %s does not change memory, just delete", node->GetName().c_str(), node->GetType().c_str());
    auto out_anchor = node->GetOutDataAnchor(0);
    GE_CHECK_NOTNULL(out_anchor);
    // Snapshot the consumers before the node is isolated and removed.
    auto in_anchors = out_anchor->GetPeerInDataAnchors();
    GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove edge failed");
    GELOGI("remove node: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
    for (auto &in_anchor : in_anchors) {
      GE_CHECK_NOTNULL(in_anchor);
      GE_CHK_STATUS_RET(UpdateSrcAttr(in_anchor->GetPeerOutAnchor(), out_anchor, in_anchor), "UpdateSrcAttr failed");
      GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed");
    }
  } else if (trans_op_.empty() || !DescAreSymmetry(trans_op_.top(), node)) {
    // node cannot cancel the pending trans op: remember it and keep descending.
    GELOGD("node: %s, type: %s can't be offset, push to trans_op_", node->GetName().c_str(), node->GetType().c_str());
    trans_op_.push(node);
    auto out_anchor = node->GetOutDataAnchor(0);
    GE_CHECK_NOTNULL(out_anchor);
    for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed");
    }
    // No data consumer left after the recursion (e.g. they were offset away): drop this trans op too.
    if (node->GetOutDataNodesSize() == 0) {
      GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove node failed");
      GELOGI("backtracking, trans op: %s, type: %s will be removed", node->GetName().c_str(), node->GetType().c_str());
    }
    GELOGD("backtracking, trans_op_ fall back. pop node: %s, type: %s.", trans_op_.top()->GetName().c_str(),
           trans_op_.top()->GetType().c_str());
    trans_op_.pop();
  } else {
    // Here DescAreSymmetry(trans_op_.top(), node) holds: node undoes trans_op_.top(). node's data input
    // is relinked to the producer of trans_op_.top()'s input and node is removed; trans_op_.top()
    // itself is removed while backtracking if it ends up with no data consumers.
    GELOGD("current node: %s, type: %s can be offset with node: %s, type %s", node->GetName().c_str(),
           node->GetType().c_str(), trans_op_.top()->GetName().c_str(), trans_op_.top()->GetType().c_str());
    GELOGD("offset_op_ push node: %s, type: %s.", trans_op_.top()->GetName().c_str(),
           trans_op_.top()->GetType().c_str());
    offset_op_.push(trans_op_.top());
    auto in_data_anchor = node->GetInDataAnchor(0);
    GE_CHECK_NOTNULL(in_data_anchor);
    auto old_out_anchor = in_data_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(old_out_anchor);
    // Previously dereferenced without a null check.
    auto former_in_anchor = trans_op_.top()->GetInDataAnchor(0);
    GE_CHECK_NOTNULL(former_in_anchor);
    auto new_out_anchor = former_in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(new_out_anchor);
    if (RelinkEdges(new_out_anchor, old_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "RelinkEdges fail.");
      return GRAPH_FAILED;
    }
    auto out_anchor = node->GetOutDataAnchor(0);
    GE_CHECK_NOTNULL(out_anchor);
    auto in_anchors = out_anchor->GetPeerInDataAnchors();
    GELOGD("begin offset,trans_op_ pop node: %s, type: %s.", trans_op_.top()->GetName().c_str(),
           trans_op_.top()->GetType().c_str());
    GELOGI("the offset node : %s, type: %s will be removed.", node->GetName().c_str(), node->GetType().c_str());
    GE_CHK_STATUS_RET(RemoveNode(node, graph), "remove node failed");
    trans_op_.pop();
    for (const auto &in_anchor : in_anchors) {
      GE_CHECK_NOTNULL(in_anchor);
      GE_CHK_STATUS_RET(UpdateSrcAttr(in_anchor->GetPeerOutAnchor(), out_anchor, in_anchor), "UpdateSrcAttr failed");
      GE_CHK_STATUS_RET(RecursiveInDepth(in_anchor, graph), "RecursiveInDepth failed");
    }
    GELOGD("backtracking, trans_op_ push node: %s, type: %s.", offset_op_.top()->GetName().c_str(),
           offset_op_.top()->GetType().c_str());
    trans_op_.push(offset_op_.top());
    offset_op_.pop();
  }
  return GRAPH_SUCCESS;
}
/// Decides whether a trans op can simply be dropped from the graph.
/// @param node  node to inspect (dereferenced; must not be null)
/// @return true for RESHAPE / REFORMAT / SQUEEZE / EXPANDDIMS nodes, unless the node is reported
///         as unknown-shape. If the unknown-shape query itself fails, only the type decides.
bool TransOpDepthFusionPass::CheckNodeCanBeDeleted(const NodePtr &node) {
  bool unknown_shape = false;
  const bool query_ok = NodeUtils::GetNodeUnknownShapeStatus(*node, unknown_shape) == GRAPH_SUCCESS;
  if (query_ok && unknown_shape) {
    GELOGI("op:%s is unknown shape, can not be deleted.", node->GetName().c_str());
    return false;
  }
  const auto &op_type = node->GetType();
  return op_type == RESHAPE || op_type == REFORMAT || op_type == SQUEEZE || op_type == EXPANDDIMS;
}
/// Checks whether dst_node undoes src_node: src_node's input desc 0 and dst_node's output desc 0
/// must agree on data type and format, and (unless both nodes are CAST) on shape dims too.
/// @return false when either node, op desc or tensor desc is null.
bool TransOpDepthFusionPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) {
  if (src_node == nullptr || dst_node == nullptr) {
    return false;
  }
  const auto src_op_desc = src_node->GetOpDesc();
  const auto dst_op_desc = dst_node->GetOpDesc();
  if (src_op_desc == nullptr || dst_op_desc == nullptr) {
    return false;
  }
  const auto src_input_desc = src_op_desc->MutableInputDesc(0);
  const auto dst_output_desc = dst_op_desc->MutableOutputDesc(0);
  GE_CHECK_NOTNULL_EXEC(src_input_desc, return false);
  GE_CHECK_NOTNULL_EXEC(dst_output_desc, return false);
  const bool dtype_matches = src_input_desc->GetDataType() == dst_output_desc->GetDataType();
  const bool format_matches = src_input_desc->GetFormat() == dst_output_desc->GetFormat();
  // A pair of casts is compared on dtype/format only; every other pair must also agree on shape.
  const bool is_cast_pair = src_node->GetType() == CAST && dst_node->GetType() == CAST;
  if (is_cast_pair) {
    return dtype_matches && format_matches;
  }
  const bool shape_matches = src_input_desc->GetShape().GetDims() == dst_output_desc->GetShape().GetDims();
  return dtype_matches && shape_matches && format_matches;
}
/// After an edge was rewired, rewrites the src-name and input-name lists of dst_in_anchor's owner op:
/// each entry equal to ori_out_anchor's owner name becomes new_out_anchor's owner name, and an empty
/// list becomes a single-entry list holding the new name.
/// @return GRAPH_SUCCESS (also when dst_in_anchor, its owner node or op desc is null, after a warning);
///         the GE_CHECK_NOTNULL error status when a source anchor or its owner node is null.
graphStatus TransOpDepthFusionPass::UpdateSrcAttr(const OutDataAnchorPtr &new_out_anchor,
                                                  const OutDataAnchorPtr &ori_out_anchor,
                                                  const InDataAnchorPtr &dst_in_anchor) {
  const bool dst_missing = dst_in_anchor == nullptr || dst_in_anchor->GetOwnerNode() == nullptr ||
                           dst_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr;
  if (dst_missing) {
    GELOGW("dst_in_anchor or it's owner node and op_desc is nullptr");
    return GRAPH_SUCCESS;
  }
  GE_CHECK_NOTNULL(new_out_anchor);
  GE_CHECK_NOTNULL(new_out_anchor->GetOwnerNode());
  GE_CHECK_NOTNULL(ori_out_anchor);
  GE_CHECK_NOTNULL(ori_out_anchor->GetOwnerNode());
  const auto new_name = new_out_anchor->GetOwnerNode()->GetName();
  const auto ori_name = ori_out_anchor->GetOwnerNode()->GetName();
  // Same renaming rule for both lists: seed an empty list, otherwise substitute matching entries.
  auto renamed = [&new_name, &ori_name](const std::vector<string> &names) -> std::vector<string> {
    std::vector<string> result;
    if (names.empty()) {
      result.push_back(new_name);
      return result;
    }
    result.reserve(names.size());
    for (const auto &name : names) {
      result.push_back(name == ori_name ? new_name : name);
    }
    return result;
  };
  auto dst_desc = dst_in_anchor->GetOwnerNode()->GetOpDesc();
  // Read both lists before writing either one, as the original ordering did.
  const std::vector<string> src_names = renamed(dst_desc->GetSrcName());
  const std::vector<string> input_names = renamed(dst_desc->GetInputName());
  dst_desc->SetSrcName(src_names);
  dst_desc->SetInputName(input_names);
  return GRAPH_SUCCESS;
}
/// Moves the data input in_data_anchor from old_out_anchor to new_out_anchor (the output feeding the
/// former neighbor of the offset trans op).
/// Note: when old_out_anchor's owner has in-control nodes and is not already a control predecessor
/// of in_data_anchor's owner, a control edge between the two is added as well.
/// @return GRAPH_FAILED on a null anchor/owner; the failing status if an edge edit fails.
graphStatus TransOpDepthFusionPass::RelinkEdges(const OutDataAnchorPtr &new_out_anchor,
                                                const OutDataAnchorPtr &old_out_anchor,
                                                const InDataAnchorPtr &in_data_anchor) {
  const bool anchor_missing = new_out_anchor == nullptr || old_out_anchor == nullptr || in_data_anchor == nullptr;
  if (anchor_missing) {
    GELOGE(INTERNAL_ERROR, "new_out_anchor or old_out_anchor or in_data_anchor is nullptr");
    return GRAPH_FAILED;
  }
  const auto new_src_node = new_out_anchor->GetOwnerNode();
  const auto src_node = old_out_anchor->GetOwnerNode();
  const auto dst_node = in_data_anchor->GetOwnerNode();
  if (new_src_node == nullptr || src_node == nullptr || dst_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "anchor's owner node is nullptr");
    return GRAPH_FAILED;
  }
  GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(old_out_anchor, in_data_anchor), "remove edge failed");
  GE_CHK_STATUS_RET(GraphUtils::AddEdge(new_out_anchor, in_data_anchor), "add edge failed");
  GELOGD(
      "relink edges before remove node, remove data edge between node: %s, "
      "type: %s and node: %s, type: %s.",
      src_node->GetName().c_str(), src_node->GetType().c_str(), dst_node->GetName().c_str(),
      dst_node->GetType().c_str());
  GELOGD(
      "relink edges before remove node, add data edge between node: %s, "
      "type: %s and node: %s, type: %s.",
      new_src_node->GetName().c_str(), new_src_node->GetType().c_str(), dst_node->GetName().c_str(),
      dst_node->GetType().c_str());
  // Nothing to add if src_node already controls dst_node.
  auto in_ctrl_nodes = dst_node->GetInControlNodes();
  const bool is_linked = std::find(in_ctrl_nodes.begin(), in_ctrl_nodes.end(), src_node) != in_ctrl_nodes.end();
  if (is_linked || src_node->GetInControlNodes().empty()) {
    return GRAPH_SUCCESS;
  }
  auto out_ctrl_anchor = src_node->GetOutControlAnchor();
  auto in_ctrl_anchor = dst_node->GetInControlAnchor();
  GE_CHK_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor), "add edge failed");
  GELOGD(
      "relink edges before remove node, add control edge between node: %s,"
      " type: %s and node: %s, type: %s.",
      src_node->GetName().c_str(), src_node->GetType().c_str(), dst_node->GetName().c_str(),
      dst_node->GetType().c_str());
  return GRAPH_SUCCESS;
}
/// Removes a trans op from graph: GraphUtils::IsolateNode(node, {0}) first, then
/// GraphUtils::RemoveNodeWithoutRelink.
/// NOTE(review): {0} presumably selects the data input whose producer is relinked to the node's
/// consumers — confirm against GraphUtils::IsolateNode.
/// @return GRAPH_FAILED if node or graph is null or either utility fails; GRAPH_SUCCESS otherwise.
graphStatus TransOpDepthFusionPass::RemoveNode(const NodePtr &node, const ge::ComputeGraphPtr &graph) {
  if (node == nullptr || graph == nullptr) {
    return GRAPH_FAILED;
  }
  const bool isolated = GraphUtils::IsolateNode(node, {0}) == GRAPH_SUCCESS;
  if (!isolated) {
    GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", node->GetName().c_str(),
           node->GetType().c_str());
    return GRAPH_FAILED;
  }
  const bool removed = GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS;
  if (!removed) {
    GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", node->GetName().c_str(),
           node->GetType().c_str());
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
} // namespace ge