|
|
|
@ -128,8 +128,7 @@ Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t
|
|
|
|
|
NodePtr cur_node = nullptr;
|
|
|
|
|
for (std::size_t i = 1; i < nodes.size(); i++) {
|
|
|
|
|
cur_node = nodes[i];
|
|
|
|
|
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
|
|
|
|
|
cur_node->GetName().c_str());
|
|
|
|
|
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
|
|
|
|
|
if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
|
|
|
|
|
pre_node->GetName().c_str(), cur_node->GetName().c_str());
|
|
|
|
@ -155,10 +154,8 @@ Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
|
|
|
|
|
cur_node->GetName().c_str());
|
|
|
|
|
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(),
|
|
|
|
|
cur_node->GetInControlAnchor());
|
|
|
|
|
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
|
|
|
|
|
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
|
|
|
|
@ -200,9 +197,7 @@ Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
|
|
|
|
|
|
|
|
|
|
NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
|
|
|
|
|
GE_CHECK_NOTNULL(cast_node);
|
|
|
|
|
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes,
|
|
|
|
|
cast_node, node,
|
|
|
|
|
node_2_switch_merge) != SUCCESS) {
|
|
|
|
|
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
|
|
|
|
|
graph->GetName().c_str());
|
|
|
|
@ -247,8 +242,7 @@ void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::s
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
|
|
|
|
|
const std::vector<NodePtr> &merge_nodes,
|
|
|
|
|
const NodePtr &cast_node, const NodePtr &switch_node,
|
|
|
|
|
const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
|
|
|
|
|
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
|
|
|
|
|
for (const auto &group_node : group_nodes) {
|
|
|
|
|
auto itr = node_2_switch_merge.find(group_node);
|
|
|
|
|