diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 6d877145cc..07e1409a02 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1259,6 +1259,23 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { return false; } +bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) { + if (!IsCommunicationOp(node)) { + return false; + } + auto primitive = AnfAlgo::GetCNodePrimitive(node); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion); + if (attr_fusion == nullptr) { + return false; + } + auto fusion = GetValue(attr_fusion); + if (fusion == 0) { + return false; + } + return true; +} + bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { auto kernel_name = AnfAlgo::GetCNodeName(node); return kernel_name == kGetNextOpName; diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 9d163bf2c6..f08620994b 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -216,6 +216,7 @@ class AnfRuntimeAlgorithm { // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsFusedCommunicationOp(const AnfNodePtr &node); static bool IsInplaceNode(const AnfNodePtr &node, const string &type); static bool IsGetNext(const NotNull &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index a65c5ed11e..b68e7c267d 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -36,6 +36,7 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; constexpr size_t k5dDims = 5; const std::set kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), prim::kPrimAssignSub->name()}; + void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes) { MS_EXCEPTION_IF_NULL(node); @@ -129,7 +130,34 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vectorcast(); + if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) { + return AnfAlgo::GetNodeAttr(cnode, kAttrGroup); + } + return ""; +} + +bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map *optimized_comm_group) { + MS_EXCEPTION_IF_NULL(optimized_comm_group); + auto node_group = GetNodeGroup(node); + if (node_group.find(kSyncBnGroup) != string::npos) { + return false; + } + auto node_name = AnfAlgo::GetCNodeName(node); + auto iter = optimized_comm_group->find(node_name); + if (iter == optimized_comm_group->end()) { + (*optimized_comm_group)[node_name] = node_group; + return true; + } else if (iter->second == node_group) { + return true; + } + return false; +} } // namespace + AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { auto value_node = node->cast(); if (value_node == nullptr) { @@ -153,7 +181,7 @@ std::vector KernelGraph::outputs() const { } void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes) { + std::unordered_set *visited_nodes, bool comm_first) { MS_EXCEPTION_IF_NULL(visit_queue); MS_EXCEPTION_IF_NULL(visited_nodes); auto it = node_output_edges_.find(node); @@ -184,7 +212,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queuefind(next_node) == visited_nodes->end()) { (void)visited_nodes->insert(next_node); - if (AnfAlgo::IsCommunicationOp(next_node)) { + bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node); + if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); visit_queue->push(next_node); } else { @@ -206,18 +235,19 @@ void KernelGraph::SetExecOrderByDefault() { execution_order_.clear(); std::unordered_set visited_nodes; std::queue zero_input_nodes; - AnfNodePtr last_communication_node = nullptr; + std::stack delay_comm_stack; std::queue communication_descendants; - while (!seed_nodes.empty() || last_communication_node != nullptr) { - // seed nodes first, then visit last all reduce node descendant + std::map optimized_comm_group; + while (!seed_nodes.empty() || !delay_comm_stack.empty()) { + // seed nodes first, then delay comm nodes if (seed_nodes.empty()) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); - last_communication_node = nullptr; + VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); + delay_comm_stack.pop(); } else { zero_input_nodes.push(seed_nodes.front()); seed_nodes.pop(); } - // all reduce node descendant first, then common queue + // comm descendant first, then common queue while (!zero_input_nodes.empty() || !communication_descendants.empty()) { AnfNodePtr node = nullptr; bool is_communication_descendant = false; @@ -234,12 +264,20 @@ void KernelGraph::SetExecOrderByDefault() { if (node->isa() && AnfAlgo::IsRealKernel(node)) { execution_order_.push_back(node->cast()); } - // for all reduce node, visit last all reduce node descendant - if (AnfAlgo::IsCommunicationOp(node)) { - if (last_communication_node != nullptr) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + // delay execute comm ops that need optimize + bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node); + bool optimize_comm = is_fused_comm; + if (optimize_comm) { + optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group); + } + if (optimize_comm) { + while (!delay_comm_stack.empty()) { + VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); + delay_comm_stack.pop(); } - last_communication_node = node; + delay_comm_stack.push(node); + } else if (is_fused_comm) { + delay_comm_stack.push(node); } else if (is_communication_descendant) { VisitNodeDescendants(node, &communication_descendants, &visited_nodes); } else { @@ -540,7 +578,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { if (node->isa()) { auto parameter = node->cast(); MS_EXCEPTION_IF_NULL(parameter); - bool is_weight = AnfAlgo ::IsParameterWeight(parameter); + bool is_weight = AnfAlgo::IsParameterWeight(parameter); kernel_info->set_feature_map_flag(!is_weight); types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); } @@ -746,6 +784,7 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons // delete old kernel (void)backend_front_anf_map_.erase(old_backend_anf); } + // get kernel by anf AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index fb4b47ec48..2ee32dcd41 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -284,7 +284,7 @@ class KernelGraph : public FuncGraph { void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; AnfNodePtr MakeValueNode(const AnfNodePtr &node); void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes); + std::unordered_set *visited_nodes, bool comm_first = true); // update node edge list void UpdateNodeEdgeList(std::queue *seed_nodes); // add node depend edge by data edge or control depend diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 9b8626b7bc..5d418ac108 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -285,6 +285,7 @@ constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler"; // Communication world group constexpr auto kNcclWorldGroup = "nccl_world_group"; constexpr auto kHcclWorldGroup = "hccl_world_group"; +constexpr auto kSyncBnGroup = "sync_bn_group"; // Hcom Op Type constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";