!13724 optimize execute order for commops

From: @kisnwang
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
pull/13724/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit efe95ebbce

@ -1259,6 +1259,23 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
return false; return false;
} }
bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
if (!IsCommunicationOp(node)) {
return false;
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
if (attr_fusion == nullptr) {
return false;
}
auto fusion = GetValue<int64_t>(attr_fusion);
if (fusion == 0) {
return false;
}
return true;
}
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) { bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node); auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName; return kernel_name == kGetNextOpName;

@ -216,6 +216,7 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl // get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsFusedCommunicationOp(const AnfNodePtr &node);
static bool IsInplaceNode(const AnfNodePtr &node, const string &type); static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
static bool IsGetNext(const NotNull<AnfNodePtr> &node); static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);

@ -36,6 +36,7 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr size_t k5dDims = 5; constexpr size_t k5dDims = 5;
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
prim::kPrimAssignSub->name()}; prim::kPrimAssignSub->name()};
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) { std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -129,7 +130,34 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
} }
} }
} }
// Returns the value of the kAttrGroup attribute on `node`, or an empty string when the attribute is absent.
// NOTE(review): the result of cast<CNodePtr>() is passed on without a null check; callers presumably only
// hand in CNodes — confirm.
std::string GetNodeGroup(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel = node->cast<CNodePtr>();
  // No group attribute: report an empty group name.
  if (!AnfAlgo::HasNodeAttr(kAttrGroup, kernel)) {
    return "";
  }
  return AnfAlgo::GetNodeAttr<std::string>(kernel, kAttrGroup);
}
// Decides whether the communication op `node` qualifies for execution-order optimization.
//   node                 - the node to inspect; its group and op name are looked up.
//   optimized_comm_group - in/out map from op name to the group recorded for that op name; must not be null.
// Returns false when the node's group contains kSyncBnGroup. Otherwise the first group seen for an op name
// is recorded and accepted, and later nodes with that op name are accepted only if their group matches.
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) {
  MS_EXCEPTION_IF_NULL(optimized_comm_group);
  const std::string group = GetNodeGroup(node);
  // Nodes belonging to a sync-bn group never qualify.
  if (group.find(kSyncBnGroup) != std::string::npos) {
    return false;
  }
  const std::string op_name = AnfAlgo::GetCNodeName(node);
  // emplace() inserts only when op_name is not yet recorded; otherwise it returns the existing entry untouched.
  const auto entry = optimized_comm_group->emplace(op_name, group);
  const bool newly_recorded = entry.second;
  return newly_recorded || entry.first->second == group;
}
} // namespace } // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>(); auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) { if (value_node == nullptr) {
@ -153,7 +181,7 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
} }
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes) { std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
MS_EXCEPTION_IF_NULL(visit_queue); MS_EXCEPTION_IF_NULL(visit_queue);
MS_EXCEPTION_IF_NULL(visited_nodes); MS_EXCEPTION_IF_NULL(visited_nodes);
auto it = node_output_edges_.find(node); auto it = node_output_edges_.find(node);
@ -184,7 +212,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
// allreduce first // allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node); (void)visited_nodes->insert(next_node);
if (AnfAlgo::IsCommunicationOp(next_node)) { bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
visit_queue->push(next_node); visit_queue->push(next_node);
} else { } else {
@ -206,18 +235,19 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_.clear(); execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes; std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes; std::queue<AnfNodePtr> zero_input_nodes;
AnfNodePtr last_communication_node = nullptr; std::stack<AnfNodePtr> delay_comm_stack;
std::queue<AnfNodePtr> communication_descendants; std::queue<AnfNodePtr> communication_descendants;
while (!seed_nodes.empty() || last_communication_node != nullptr) { std::map<std::string, std::string> optimized_comm_group;
// seed nodes first, then visit last all reduce node descendant while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
// seed nodes first, then delay comm nodes
if (seed_nodes.empty()) { if (seed_nodes.empty()) {
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
last_communication_node = nullptr; delay_comm_stack.pop();
} else { } else {
zero_input_nodes.push(seed_nodes.front()); zero_input_nodes.push(seed_nodes.front());
seed_nodes.pop(); seed_nodes.pop();
} }
// all reduce node descendant first, then common queue // comm descendant first, then common queue
while (!zero_input_nodes.empty() || !communication_descendants.empty()) { while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
AnfNodePtr node = nullptr; AnfNodePtr node = nullptr;
bool is_communication_descendant = false; bool is_communication_descendant = false;
@ -234,12 +264,20 @@ void KernelGraph::SetExecOrderByDefault() {
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) { if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
execution_order_.push_back(node->cast<CNodePtr>()); execution_order_.push_back(node->cast<CNodePtr>());
} }
// for all reduce node, visit last all reduce node descendant // delay execute comm ops that need optimize
if (AnfAlgo::IsCommunicationOp(node)) { bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
if (last_communication_node != nullptr) { bool optimize_comm = is_fused_comm;
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); if (optimize_comm) {
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group);
} }
last_communication_node = node; if (optimize_comm) {
while (!delay_comm_stack.empty()) {
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
}
delay_comm_stack.push(node);
} else if (is_fused_comm) {
delay_comm_stack.push(node);
} else if (is_communication_descendant) { } else if (is_communication_descendant) {
VisitNodeDescendants(node, &communication_descendants, &visited_nodes); VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
} else { } else {
@ -746,6 +784,7 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
// delete old kernel // delete old kernel
(void)backend_front_anf_map_.erase(old_backend_anf); (void)backend_front_anf_map_.erase(old_backend_anf);
} }
// get kernel by anf // get kernel by anf
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {

@ -284,7 +284,7 @@ class KernelGraph : public FuncGraph {
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
AnfNodePtr MakeValueNode(const AnfNodePtr &node); AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes); std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
// update node edge list // update node edge list
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend // add node depend edge by data edge or control depend

@ -285,6 +285,7 @@ constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
// Communication world group // Communication world group
constexpr auto kNcclWorldGroup = "nccl_world_group"; constexpr auto kNcclWorldGroup = "nccl_world_group";
constexpr auto kHcclWorldGroup = "hccl_world_group"; constexpr auto kHcclWorldGroup = "hccl_world_group";
constexpr auto kSyncBnGroup = "sync_bn_group";
// Hcom Op Type // Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";

Loading…
Cancel
Save