You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/metadef/graph/utils/tuning_utils.cc

684 lines
28 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/tuning_utils.h"
#include "../debug/ge_util.h"
#include "../debug/ge_op_types.h"
namespace ge {
// Attribute keys used to carry partition information through the PLACEHOLDER/END -> DATA/NETOUTPUT conversion
// and back again when the subgraphs are merged.
// Copied from a pld op desc onto the data op desc that replaces it; read again in MergeSubGraph.
const std::string peer_node_name_attr = "_peerNodeName";
// Copied from a pld op desc onto its data op desc; GetInAndOutAnchorPair matches it against node names.
const std::string parent_node_name_attr = "_parentNodeName";
// List-of-string attribute on a NetOutput op desc that accumulates the names of the END nodes it replaces.
const std::string alias_name_attr = "_aliasName";
// Not referenced in the part of the file visible here.
const std::string parent_node_attr = "parentNode";
// Set on the data op desc from the pld attribute "anchorIndex"; read in GetInAndOutAnchorPair.
const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex";
// File-name prefixes used by DumpGraphToPath for tuning / non-tuning subgraphs.
const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
const std::string non_tuning_subgraph_prefix = "/subgraph_";
// Node types that MergeSubGraph rejects in the subgraphs passed to it.
const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
// Not referenced in the part of the file visible here.
const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
// Definitions of the static members of TuningUtils.
// data op name -> value of its "_peerNodeName" attribute (filled in MergeSubGraph, printed by PrintCheckLog).
NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
// data node -> value of its "_peerNodeName" attribute (filled in MergeSubGraph, walked in RemoveDataNetoutputEdge).
NodetoNodeNameMap TuningUtils::data_node_2_netoutput_;
// data node -> netoutput node found by name in the merged graph (filled in RemoveDataNetoutputEdge).
NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
// NETOUTPUT nodes carrying a non-empty "_aliasName" list; removed again by RemoveDataNetoutputEdge.
NodeSet TuningUtils::netoutput_nodes_;
// Nodes collected by MergeSubGraph that MergeAllSubGraph adds to the merged graph.
NodeSet TuningUtils::merged_graph_nodes_;
// subgraph -> NetOutput node created for it (nullptr until CreateNetOutput creates one).
SubgraphCreateOutNode TuningUtils::create_output_;
// Held for the whole of ConvertGraphToFile and around each insertion into the static containers in MergeSubGraph.
std::mutex TuningUtils::mutex_;
// Renders the merge bookkeeping as a single string for debug logging:
// every (data name, peer name) pair of data_2_netoutput_ followed by the name of
// every node in netoutput_nodes_. Each entry is terminated by " | ".
std::string TuningUtils::PrintCheckLog() {
  std::stringstream summary;

  // data -> netoutput name pairs
  summary << "d2n:{";
  for (const auto &entry : data_2_netoutput_) {
    summary << "data:" << entry.first << "-" << "netoutput:" << entry.second << " | ";
  }
  summary << "}";

  // recorded netoutput nodes
  summary << "netoutputs:{";
  for (const auto &netoutput : netoutput_nodes_) {
    summary << "netoutput:" << netoutput->GetName() << " | ";
  }
  summary << "}";

  return summary.str();
}
// Returns the name of the node owning `anchor`.
// Yields the literal "Null" when `anchor` is nullptr (an error is logged in that case)
// or when the anchor has no owner node.
std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) {
  if (anchor != nullptr) {
    auto owner = anchor->GetOwnerNode();
    if (owner == nullptr) {
      return "Null";
    }
    return owner->GetName();
  }
  GELOGE(GRAPH_FAILED, "Anchor is nullptr");
  return "Null";
}
// part 1
// Converts every subgraph via MakeExeGraph (which also dumps it) while holding mutex_.
//   tuning_subgraphs     - processed with the caller supplied exe_flag and is_tuning_graph = true
//   non_tuning_subgraphs - always processed with exe_flag = true and is_tuning_graph = false
//   path / user_path     - forwarded to MakeExeGraph through HelpInfo
// Each group is numbered independently starting from 0.
// Returns SUCCESS when every subgraph was handled, GRAPH_FAILED on the first failure.
//
// create_output_ is a static map that is only meaningful for the duration of one call, so it is
// cleared on every exit path; otherwise a failed call would leave stale subgraph/NetOutput entries
// behind that a later call could pick up in CreateNetOutput.
graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
                                            std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag,
                                            const std::string &path, const std::string &user_path) {
  std::lock_guard<std::mutex> lock(mutex_);

  int64_t tuning_index = 0;
  for (auto &subgraph : tuning_subgraphs) {
    // register the subgraph so that CreateNetOutput can find (and fill) its slot
    create_output_.emplace(subgraph, nullptr);
    auto help_info = HelpInfo{tuning_index, exe_flag, true, path, user_path};
    if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
      // the counter is an int64_t: print it with a matching conversion instead of %zu
      GELOGE(GRAPH_FAILED, "TUU:subgraph %lld generate exe graph failed", static_cast<long long>(tuning_index));
      create_output_.clear();
      return GRAPH_FAILED;
    }
    ++tuning_index;
  }

  int64_t non_tuning_index = 0;
  for (auto &subgraph : non_tuning_subgraphs) {
    create_output_.emplace(subgraph, nullptr);
    auto help_info = HelpInfo{non_tuning_index, true, false, path, user_path};
    if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
      GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %lld generate exe graph failed",
             static_cast<long long>(non_tuning_index));
      create_output_.clear();
      return GRAPH_FAILED;
    }
    ++non_tuning_index;
  }

  create_output_.clear();
  return SUCCESS;
}
// +---------------+
// | pld pld |
// | \ / |
// | relu relu |
// | \ / |
// | add |
// | | |
// | end |
// +---------------+
// |
// |
// V
// +---------------+
// | data data |
// | \ / |
// | relu relu |
// | \ / |
// | add |
// | | |
// | netoutput |
// +---------------+
// Turns a partitioned subgraph into an executable graph and dumps it.
//   exe_graph - subgraph to convert; modified in place
//   help_info - index / exe_flag / is_tuning_graph / path / user_path of this subgraph
// When help_info.exe_flag is false the graph is only dumped, unmodified.
// Otherwise every PLACEHOLDER node goes through HandlePld and every END node through HandleEnd,
// the graph is topologically sorted, and the result is dumped either to the generated file name
// (user_path empty) or to user_path.
// Returns SUCCESS, FAILED when a node could not be handled, or the topological sort status.
graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) {
  GE_CHECK_NOTNULL(exe_graph);
  // if not make exe, just dump and return
  if (!help_info.exe_flag) {
    DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
    // index is forwarded as int64_t to DumpGraphToPath, so it must not be printed with %d
    GELOGI("TUU:just return, dump original sub_graph[%s]index[%lld]", exe_graph->GetName().c_str(),
           static_cast<long long>(help_info.index));
    return SUCCESS;
  }
  // modify sub graph
  for (NodePtr &node : exe_graph->GetDirectNode()) {
    // guard the dereferences below
    GE_CHECK_NOTNULL(node);
    // 1.handle pld
    if (node->GetType() == PLACEHOLDER) {
      if (HandlePld(node) != SUCCESS) {
        GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
               exe_graph->GetName().c_str());
        return FAILED;
      }
    }
    // 2.handle end
    if (node->GetType() == END) {
      if (HandleEnd(node) != SUCCESS) {
        GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
               exe_graph->GetName().c_str());
        return FAILED;
      }
    }
  }
  graphStatus ret = exe_graph->TopologicalSorting();
  if (ret != SUCCESS) {
    GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret);
    return ret;
  }
  // dump subgraphs which modified by us
  if (help_info.user_path.empty()) {
    DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
  } else {
    GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path);
  }
  return SUCCESS;
}
// Dumps `exe_graph` through GraphUtils::DumpGEGraph to the file
//   <path><prefix><index>.txt
// where <prefix> is tuning_subgraph_prefix for tuning graphs and non_tuning_subgraph_prefix otherwise.
// An empty `path` is replaced by "./".
void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) {
  if (path.empty()) {
    path = "./";
  }
  const std::string &prefix = is_tuning_graph ? tuning_subgraph_prefix : non_tuning_subgraph_prefix;
  const std::string dump_file = path + prefix + std::to_string(index) + ".txt";
  GraphUtils::DumpGEGraph(exe_graph, "", true, dump_file);
}
// Creates a DATA node named after `node` (a pld node) in the graph owning `node`.
//   node      - source node; its output desc 0 is used as both input and output desc of the new node
//   data_node - receives the newly added node
// Returns SUCCESS, or FAILED / the GE_CHECK_NOTNULL status when any step fails.
graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) {
  // node is dereferenced right below, so validate it first
  GE_CHECK_NOTNULL(node);
  auto graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(graph);
  auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA);
  GE_CHECK_NOTNULL(data_op_desc);
  auto pld_op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(pld_op_desc);
  auto output_desc = pld_op_desc->GetOutputDesc(0);  // only one output for pld and data
  // data inputdesc & outputdesc set as same
  if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) {
    // this branch reports the input desc; the previous text wrongly named AddOutputDesc
    GELOGE(FAILED, "TUU:data node %s AddInputDesc failed", data_op_desc->GetName().c_str());
    return FAILED;
  }
  if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) {
    GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
    return FAILED;
  }
  data_node = graph->AddNode(data_op_desc);
  GE_CHECK_NOTNULL(data_node);
  if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
    return FAILED;
  }
  return SUCCESS;
}
// Copies the attributes needed to merge the graphs back together from a pld node to the data node
// that replaces it:
//   "parentOpType"    -> "parentOpType"
//   "_parentNodeName" -> "_parentNodeName"
//   "anchorIndex"     -> "_parentNodeAnchorIndex"
//   "_peerNodeName"   -> "_peerNodeName"
// Returns SUCCESS when all four attributes were read from `pld`; FAILED as soon as one is missing.
// Results of the Set* calls are intentionally ignored, as before.
graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) {
  // both nodes are dereferenced below, so validate them first
  GE_CHECK_NOTNULL(pld);
  GE_CHECK_NOTNULL(data_node);
  auto op_desc = data_node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  auto pld_desc = pld->GetOpDesc();
  GE_CHECK_NOTNULL(pld_desc);
  // inherit
  // a. parent op type
  std::string parent_op_type;
  if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) {
    GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str());
    return FAILED;
  }
  (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type);
  // b. parent node name
  std::string parent_op_name;
  if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) {
    GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str());
    return FAILED;
  }
  (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name);
  // c. out anchor index of the parent node
  int parent_node_anchor_index = 0;
  if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) {
    GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str());
    return FAILED;
  }
  (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index);
  // d. peer node name
  std::string peer_end_name;
  if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) {
    GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str());
    return FAILED;
  }
  (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name);
  // success is reported once, after the last attribute has been copied
  GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
         data_node->GetName().c_str(), data_node->GetType().c_str());
  return SUCCESS;
}
// Replaces the PLACEHOLDER node `node` by the DATA node `data_node`.
// All out data anchors of `node` are remapped one-to-one onto `data_node`
// (GraphUtils::ReplaceNodeAnchors with an identity output map and no input map), then `node` is
// unlinked and removed from its owner graph.
// Returns the removal status on success, FAILED when the node types do not match or a step fails.
graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) {
  // both nodes are dereferenced below, so validate them first
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(data_node);
  auto type_pld = node->GetType();
  auto type_data = data_node->GetType();
  if (type_pld != PLACEHOLDER || type_data != DATA) {
    GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(),
           type_data.c_str());
    return FAILED;
  }
  auto graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(graph);
  // identity mapping: output i of the pld node becomes output i of the data node
  std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
  for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
    output_map[i] = static_cast<int>(i);
  }
  auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map);
  if (ret != GRAPH_SUCCESS) {
    // the trailing value is the status returned by ReplaceNodeAnchors
    GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error code %u", node->GetName().c_str(),
           data_node->GetName().c_str(), ret);
    return FAILED;
  }
  NodeUtils::UnlinkAll(*node);
  ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
    return FAILED;
  }
  GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(),
         node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
  return ret;
}
// Converts one pld node into a data node in three ordered steps:
//   1. CreateDataNode                 - create the data node
//   2. AddAttrToDataNodeForMergeGraph - copy the attributes needed to recover the whole graph
//   3. ChangePld2Data                 - replace the pld node by the data node
// The chain stops at the first step that does not return SUCCESS; FAILED is returned in that case.
graphStatus TuningUtils::HandlePld(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);

  NodePtr data_node = nullptr;
  // && short-circuits, so a later step never runs after an earlier one failed
  const bool handled = (CreateDataNode(node, data_node) == SUCCESS) &&
                       (AddAttrToDataNodeForMergeGraph(node, data_node) == SUCCESS) &&
                       (ChangePld2Data(node, data_node) == SUCCESS);
  if (!handled) {
    GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
           owner_graph->GetName().c_str());
    return FAILED;
  }

  GELOGD("TUU:pld[%s] handle success", node->GetName().c_str());
  return SUCCESS;
}
// Provides the single NetOutput node of the subgraph owning `node`.
// The owner graph must already be registered in create_output_. If a NetOutput node is recorded for
// it, that node is handed back; otherwise a NETOUTPUT node named after `node` is added to the graph
// and recorded.
//   node     - node whose owner graph needs the NetOutput node
//   out_node - receives the NetOutput node
// Returns SUCCESS, or FAILED / the GE_CHECK_NOTNULL status on error.
graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) {
  GE_CHECK_NOTNULL(node);
  auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);

  auto entry = create_output_.find(owner_graph);
  if (entry == create_output_.end()) {
    GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(),
           owner_graph->GetName().c_str());
    return FAILED;
  }

  if (entry->second == nullptr) {
    // first request for this subgraph: build the node and remember it
    auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT);
    GE_CHECK_NOTNULL(out_op_desc);
    out_node = owner_graph->AddNode(out_op_desc);
    GE_CHECK_NOTNULL(out_node);
    if (out_node->SetOwnerComputeGraph(owner_graph) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
      return FAILED;
    }
    create_output_[owner_graph] = out_node;
    return SUCCESS;
  }

  // already created earlier: reuse it
  out_node = entry->second;
  GELOGD("TUU:sub graph %s has created output node, just return", owner_graph->GetName().c_str());
  return SUCCESS;
}
// Appends the name of `end` to the "_aliasName" string list stored on the op desc of `out_node`.
// The list is read first (a failed read leaves it empty), extended, and written back;
// the results of both attribute calls are ignored.
// Returns SUCCESS unless one of the pointers is null.
graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) {
  GE_CHECK_NOTNULL(end);
  GE_CHECK_NOTNULL(out_node);
  auto out_desc = out_node->GetOpDesc();
  GE_CHECK_NOTNULL(out_desc);

  std::vector<std::string> aliases;
  (void)AttrUtils::GetListStr(out_desc, alias_name_attr, aliases);
  aliases.push_back(end->GetName());
  (void)AttrUtils::SetListStr(out_desc, alias_name_attr, aliases);
  return SUCCESS;
}
// Moves the single input edge of `end_node` onto `out_node`.
// The input is taken from in data anchor 0 of the end node when that anchor has a peer, otherwise
// from its in control anchor. The edge to the end node is removed and then:
//   - data source:    a new in data anchor is appended to out_node, linked to the source, and the
//                     end node's input desc 0 is added to out_node's op desc;
//   - control source: the source is linked to out_node's in control anchor.
// Returns SUCCESS, or FAILED / the GE_CHECK_NOTNULL status on error.
// Every anchor and the owner graph are validated before use, because they are dereferenced both on
// the normal path and inside the error logs.
graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  GE_CHECK_NOTNULL(end_node);
  GE_CHECK_NOTNULL(out_node);
  auto owner_graph = end_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  // get end in node is control node or normal node
  auto end_in_data_anchor = end_node->GetInDataAnchor(0);
  GE_CHECK_NOTNULL(end_in_data_anchor);
  AnchorPtr end_in_anchor = (end_in_data_anchor->GetFirstPeerAnchor() == nullptr)
                              ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
                              : Anchor::DynamicAnchorCast<Anchor>(end_in_data_anchor);
  GE_CHECK_NOTNULL(end_in_anchor);
  auto src_anchor = end_in_anchor->GetFirstPeerAnchor();  // src_anchor should be only 1
  GE_CHECK_NOTNULL(src_anchor);
  if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
           GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
           GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(),
           owner_graph->GetName().c_str());
    return FAILED;
  }
  // add edge between `end in node` and `out_node`
  if (src_anchor->IsTypeOf<OutDataAnchor>()) {
    // append a fresh in data anchor whose index is the current number of in data anchors
    std::shared_ptr<InDataAnchor> anchor =
      ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size());
    GE_CHECK_NOTNULL(anchor);
    out_node->in_data_anchors_.push_back(anchor);
    if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
             GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
             GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
             owner_graph->GetName().c_str());
      return FAILED;
    }
    auto end_op_desc = end_node->GetOpDesc();
    GE_CHECK_NOTNULL(end_op_desc);
    auto out_node_op_desc = out_node->GetOpDesc();
    GE_CHECK_NOTNULL(out_node_op_desc);
    // end node always has one input
    if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str());
      return FAILED;
    }
  } else if (src_anchor->IsTypeOf<OutControlAnchor>()) {
    auto anchor = out_node->GetInControlAnchor();
    GE_CHECK_NOTNULL(anchor);
    if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
             GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
             GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
             owner_graph->GetName().c_str());
      return FAILED;
    }
  } else {
    GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(),
           owner_graph->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
// Folds the END node `end_node` into the NETOUTPUT node `out_node`:
// the input of the end node is re-linked to out_node (LinkEnd2NetOutput), then the end node is
// unlinked and removed from its owner graph.
// Returns SUCCESS, or FAILED when the node types do not match or a step fails.
graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
  GE_CHECK_NOTNULL(end_node);
  GE_CHECK_NOTNULL(out_node);

  const auto end_type = end_node->GetType();
  const auto out_type = out_node->GetType();
  const bool types_match = (end_type == END) && (out_type == NETOUTPUT);
  if (!types_match) {
    GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(),
           end_type.c_str(), out_type.c_str());
    return FAILED;
  }

  // hand the producer of the end node over to out_node
  if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) {
    GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str());
    return FAILED;
  }

  // the end node is no longer needed: detach and drop it
  NodeUtils::UnlinkAll(*end_node);
  auto owner_graph = end_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  if (GraphUtils::RemoveNodeWithoutRelink(owner_graph, end_node) != SUCCESS) {
    GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
// Converts one end node in three ordered steps:
//   1. CreateNetOutput                 - get (or create) the one NetOutput node of the subgraph
//   2. AddAttrToNetOutputForMergeGraph - record the info needed to recover the whole graph
//   3. ChangeEnd2NetOutput             - replace the end node by that NetOutput node
// The chain stops at the first step that does not return SUCCESS; FAILED is returned in that case.
graphStatus TuningUtils::HandleEnd(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);

  NodePtr out_node = nullptr;
  // && short-circuits, so a later step never runs after an earlier one failed
  const bool handled = (CreateNetOutput(node, out_node) == SUCCESS) &&
                       (AddAttrToNetOutputForMergeGraph(node, out_node) == SUCCESS) &&
                       (ChangeEnd2NetOutput(node, out_node) == SUCCESS);
  if (!handled) {
    GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
           owner_graph->GetName().c_str());
    return FAILED;
  }

  GELOGD("TUU:end[%s] handle success", node->GetName().c_str());
  return SUCCESS;
}
// part 2
// Loads the subgraph files listed in `options`, merges them into one graph and stores it in `graph`.
//   options - {index : subgraph file path}; the index becomes the name of the loaded compute graph
//   graph   - receives the merged graph ("whole_graph_after_tune")
// Returns SUCCESS, FAILED when merging or re-parenting fails, or the GE_CHECK_NOTNULL status when
// an allocation fails.
graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
  // 1. get all subgraph object
  std::vector<ComputeGraphPtr> graphs;
  // options format like {index:"subgraph_path"}
  for (const auto &pair : options) {
    ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first));
    // compute_graph is dereferenced right below, so an allocation failure must be caught here
    GE_CHECK_NOTNULL(compute_graph);
    if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) {
      // NOTE(review): a failed load is only logged and the graph is still merged below -
      // confirm whether this best-effort behaviour is intended before turning it into an error.
      GELOGE(FAILED, "TUU:load graph from file failed");
    }
    graphs.push_back(compute_graph);
  }
  // 2. merge graph
  ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune");
  GE_CHECK_NOTNULL(merged_graph);
  if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) {
    GELOGE(FAILED, "TUU:MergeGraph failed");
    return FAILED;
  }
  // 3. set parent graph
  for (const auto &node : merged_graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str());
      return FAILED;
    }
  }
  graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph);
  return SUCCESS;
}
// +----------------------------------+
// | const const |
// | \ / |
// | netoutput(end,end) |
// +----------------------------------+
// +
// +----------------------------------+
// | data(pld) data(pld) |
// | \ / |
// | relu relu |
// | \ / |
// | \ / |
// | add |
// | | |
// | netoutput(end) |
// +----------------------------------+
// +
// +----------------------------------+
// | data(pld) |
// | / |
// | netoutput |
// +----------------------------------+
// |
// |
// V
// +----------------------------------+
// | const const |
// | \ / |
// | relu relu |
// | \ / |
// | \ / |
// | add |
// | | |
// | netoutput |
// +----------------------------------+
// Merges all `subgraphs` into `output_merged_compute_graph`:
//   1. MergeSubGraph collects the nodes of every subgraph into the static bookkeeping containers;
//   2. every node of merged_graph_nodes_ is added to the output graph;
//   3. RemoveDataNetoutputEdge removes the data / netoutput nodes that were added by the conversion;
//   4. the output graph is topologically sorted.
// Returns SUCCESS, the failing MergeSubGraph / sort status, or FAILED when step 3 fails.
graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs,
                                          ComputeGraphPtr &output_merged_compute_graph) {
  GE_CHECK_NOTNULL(output_merged_compute_graph);

  // 1. handle all subgraphs
  for (auto &part : subgraphs) {
    Status merge_status = MergeSubGraph(part);
    if (merge_status == SUCCESS) {
      continue;
    }
    GELOGE(merge_status, "TUU:subgraph %s merge failed", part->GetName().c_str());
    return merge_status;
  }

  for (const auto &collected : merged_graph_nodes_) {
    (void)output_merged_compute_graph->AddNode(collected);
    GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(),
           collected->GetName().c_str());
  }

  // 2. remove data and output node added by us
  if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) {
    GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str());
    return FAILED;
  }

  graphStatus sort_status = output_merged_compute_graph->TopologicalSorting();
  if (sort_status != SUCCESS) {
    GELOGE(sort_status, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(),
           sort_status);
    return sort_status;
  }

  GELOGD("TUU:Print-%s", PrintCheckLog().c_str());
  GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str());
  return SUCCESS;
}
// Sorts the direct nodes of `subgraph` into the static bookkeeping containers used for merging:
//   - a node of a type in kPartitionOpTypes makes the call fail;
//   - a DATA node carrying a non-empty "_peerNodeName" attribute is recorded in data_2_netoutput_ and
//     data_node_2_netoutput_ and is NOT added to merged_graph_nodes_;
//   - a NETOUTPUT node carrying a non-empty "_aliasName" list is recorded in netoutput_nodes_;
//   - every other node, and the NETOUTPUT nodes above, are added to merged_graph_nodes_.
// Each insertion into a static container is done under mutex_.
// Returns SUCCESS, FAILED for a forbidden node type, or the GE_CHECK_NOTNULL status for a null pointer.
graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) {
  // subgraph and its nodes are dereferenced below, so validate them first
  GE_CHECK_NOTNULL(subgraph);
  for (auto &node : subgraph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (kPartitionOpTypes.count(node->GetType()) > 0) {
      GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type");
      return FAILED;
    }
    // handle data converted from pld node
    if (node->GetType() == DATA) {
      auto op_desc = node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      std::string peer_out_name;
      bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty());
      if (has_valid_str) {
        std::lock_guard<std::mutex> lock(mutex_);
        data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name);
        data_node_2_netoutput_.emplace(node, peer_out_name);
        // such data nodes are not part of the merged graph
        continue;
      }
    }
    // handle netoutput converted from end node
    if (node->GetType() == NETOUTPUT) {
      auto op_desc = node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      std::vector<string> out_alias_name;
      bool has_valid_str =
        (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty());
      if (has_valid_str) {
        std::lock_guard<std::mutex> lock(mutex_);
        netoutput_nodes_.insert(node);
      }
    }
    {
      std::lock_guard<std::mutex> lock(mutex_);
      merged_graph_nodes_.emplace(node);
    }
    GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str());
  }
  GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str());
  return SUCCESS;
}
// Reconnects the merged graph across the data / netoutput nodes inserted by the conversion.
// For every entry of data_node_2_netoutput_:
//   - the netoutput node is looked up by name in `graph` and recorded in data_node_2_netoutput_node_;
//   - GetInAndOutAnchorPair finds the netoutput in anchor and the producer out anchor that belong to
//     the data node;
//   - the producer -> netoutput edge is removed and every consumer of the data node is re-linked
//     directly to the producer.
// Afterwards every node of netoutput_nodes_ is unlinked and removed from `graph`.
// Returns SUCCESS, or FAILED / the GE_CHECK_NOTNULL status on error.
graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  // 1. traverse
  for (auto &pair : data_node_2_netoutput_) {
    auto data_node = pair.first;
    GE_CHECK_NOTNULL(data_node);
    auto netoutput_name = pair.second;
    auto netoutput_node = graph->FindNode(netoutput_name);
    GE_CHECK_NOTNULL(netoutput_node);
    data_node_2_netoutput_node_.emplace(data_node, netoutput_node);
    // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor`
    // out data anchor 0 is dereferenced to pick data vs. control output, so validate it first
    auto data_out_data_anchor = data_node->GetOutDataAnchor(0);
    GE_CHECK_NOTNULL(data_out_data_anchor);
    AnchorPtr data_out_anchor = (data_out_data_anchor->GetFirstPeerAnchor() == nullptr)
                                  ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor())
                                  : Anchor::DynamicAnchorCast<Anchor>(data_out_data_anchor);
    // checked before any edge is touched, so a failure here leaves the graph unmodified
    GE_CHECK_NOTNULL(data_out_anchor);
    AnchorPtr net_output_in_anchor = nullptr;
    AnchorPtr src_out_anchor = nullptr;
    if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed",
             netoutput_node->GetName().c_str(), data_node->GetName().c_str());
      return FAILED;
    }
    // 3. relink
    if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
             GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
             GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(),
             data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
      return FAILED;
    }
    for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) {
      if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
               GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(),
               GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
               data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
        return FAILED;
      }
      if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
               GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
               GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
               data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
        return FAILED;
      }
    }
  }
  // 4. remove out nodes added by us
  for (auto &node : netoutput_nodes_) {
    NodeUtils::UnlinkAll(*node);
    if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
      return FAILED;
    }
    GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str());
  }
  return SUCCESS;
}
// Finds, among the in anchors of `out_node`, the edge that comes from the node and out anchor index
// recorded on `data_node` ("_parentNodeName" / "_parentNodeAnchorIndex").
//   data_node      - node carrying the two attributes
//   out_node       - node whose in anchors are searched
//   dest_in_anchor - receives the matching in anchor of out_node
//   src_out_anchor - receives the matching peer (source) anchor
// Returns SUCCESS when a pair was found, FAILED when an attribute is missing, or the
// GE_CHECK_NOTNULL status when a pointer is null or no pair was found.
graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor,
                                               AnchorPtr &src_out_anchor) {
  auto data_desc = data_node->GetOpDesc();
  GE_CHECK_NOTNULL(data_desc);

  // 1. name of the node that feeds the netoutput
  std::string producer_name;
  if (!AttrUtils::GetStr(data_desc, parent_node_name_attr, producer_name)) {
    GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", data_desc->GetName().c_str());
    return FAILED;
  }

  // 2. out anchor index on that node
  int producer_anchor_index;
  if (!AttrUtils::GetInt(data_desc, parent_node_anchor_index_attr, producer_anchor_index)) {
    GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", data_desc->GetName().c_str());
    return FAILED;
  }

  // 3. scan every in anchor (data and ctrl) of out_node for a peer matching 1 & 2
  for (auto &candidate_in : out_node->GetAllInAnchors()) {
    GE_CHECK_NOTNULL(candidate_in);
    for (auto &peer_out : candidate_in->GetPeerAnchors()) {  // get all peer anchors for ctrl
      GE_CHECK_NOTNULL(peer_out);
      auto peer_owner = peer_out->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_owner);
      const bool is_match = (peer_owner->GetName() == producer_name) && (peer_out->GetIdx() == producer_anchor_index);
      if (!is_match) {
        continue;
      }
      dest_in_anchor = candidate_in;
      src_out_anchor = peer_out;
      GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",
             out_node->GetName().c_str(), dest_in_anchor->GetIdx(), producer_name.c_str(), producer_anchor_index,
             data_node->GetName().c_str());
      // leaves only the peer loop; the remaining in anchors are still scanned, so the last match wins
      break;
    }
  }

  GE_CHECK_NOTNULL(dest_in_anchor);
  GE_CHECK_NOTNULL(src_out_anchor);
  return SUCCESS;
}
} // namespace ge