|
|
|
@ -462,7 +462,7 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt
|
|
|
|
|
set<NodePtr> all_reduce_succs;
|
|
|
|
|
|
|
|
|
|
for (const NodePtr &node : graph->GetDirectNode()) {
|
|
|
|
|
if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE) ||
|
|
|
|
|
if (!IsHcomNode(node->GetType()) ||
|
|
|
|
|
node->GetInDataNodes().size() <= 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -507,7 +507,7 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt
|
|
|
|
|
old_stream_to_new.emplace(old_stream, new_stream);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE)) {
|
|
|
|
|
if (!IsHcomNode(node->GetType())) {
|
|
|
|
|
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream);
|
|
|
|
|
node->GetOpDesc()->SetStreamId(new_stream);
|
|
|
|
|
}
|
|
|
|
@ -517,6 +517,11 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt
|
|
|
|
|
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AllReduceParallelPass::IsHcomNode(const std::string& node_type) {
|
|
|
|
|
return (node_type == HCOMALLREDUCE || node_type == HVDCALLBACKALLREDUCE);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Constructs the allocator by initialising its two members from the supplied
// arguments; the constructor body itself does no further work.
//
// @param scheduler_confs   String-keyed table of SchedulerConf entries, used to
//                          initialise scheduler_confs_.
// @param max_parallel_num  String-keyed table of integers, used to initialise
//                          max_parallel_num_.
LogicalStreamAllocator::LogicalStreamAllocator(
    const map<string, SchedulerConf> &scheduler_confs,
    const map<string, int> &max_parallel_num)
    : scheduler_confs_(scheduler_confs),
      max_parallel_num_(max_parallel_num) {
}
|
|
|
|
|