!13724 optimize execute order for commops

From: @kisnwang
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
pull/13724/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit efe95ebbce

@ -1259,6 +1259,23 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
return false;
}
bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
if (!IsCommunicationOp(node)) {
return false;
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
if (attr_fusion == nullptr) {
return false;
}
auto fusion = GetValue<int64_t>(attr_fusion);
if (fusion == 0) {
return false;
}
return true;
}
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName;

@ -216,6 +216,7 @@ class AnfRuntimeAlgorithm {
// get the real input index for some tbe ops whose input order differs between ME and the tbe implementation
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsFusedCommunicationOp(const AnfNodePtr &node);
static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);

@ -36,6 +36,7 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr size_t k5dDims = 5;
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
prim::kPrimAssignSub->name()};
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node);
@ -129,7 +130,34 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
}
}
}
// Returns the value of the kAttrGroup attribute of `node`, or an empty string when the attribute is absent.
// `node` must be non-null and must be a CNode; both conditions are enforced with MS_EXCEPTION_IF_NULL.
std::string GetNodeGroup(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  // cast<> yields nullptr when the node is not a CNode; reject it here instead of handing a null pointer
  // to the attribute helpers below.
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
    return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
  }
  return "";
}
// Decides whether the execution order of communication node `node` should be optimized.
// A node whose group name contains kSyncBnGroup is never optimized. Otherwise the group of the first node seen
// for each op name is recorded in `optimized_comm_group`, and a node qualifies only when its group equals the
// group recorded for its op name.
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) {
  MS_EXCEPTION_IF_NULL(optimized_comm_group);
  const auto group = GetNodeGroup(node);
  if (group.find(kSyncBnGroup) != std::string::npos) {
    return false;
  }
  const auto op_name = AnfAlgo::GetCNodeName(node);
  // emplace leaves an existing entry untouched, so `.second` is true only for the first node with this op name.
  const auto emplaced = optimized_comm_group->emplace(op_name, group);
  const bool first_of_its_name = emplaced.second;
  return first_of_its_name || emplaced.first->second == group;
}
} // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
@ -153,7 +181,7 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
}
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes) {
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
MS_EXCEPTION_IF_NULL(visit_queue);
MS_EXCEPTION_IF_NULL(visited_nodes);
auto it = node_output_edges_.find(node);
@ -184,7 +212,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node);
if (AnfAlgo::IsCommunicationOp(next_node)) {
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
@ -206,18 +235,19 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes;
AnfNodePtr last_communication_node = nullptr;
std::stack<AnfNodePtr> delay_comm_stack;
std::queue<AnfNodePtr> communication_descendants;
while (!seed_nodes.empty() || last_communication_node != nullptr) {
// seed nodes first, then visit last all reduce node descendant
std::map<std::string, std::string> optimized_comm_group;
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
// seed nodes first, then delay comm nodes
if (seed_nodes.empty()) {
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
last_communication_node = nullptr;
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
} else {
zero_input_nodes.push(seed_nodes.front());
seed_nodes.pop();
}
// all reduce node descendant first, then common queue
// comm descendant first, then common queue
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
AnfNodePtr node = nullptr;
bool is_communication_descendant = false;
@ -234,12 +264,20 @@ void KernelGraph::SetExecOrderByDefault() {
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
execution_order_.push_back(node->cast<CNodePtr>());
}
// for all reduce node, visit last all reduce node descendant
if (AnfAlgo::IsCommunicationOp(node)) {
if (last_communication_node != nullptr) {
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
// delay the execution of comm ops that need to be optimized
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
bool optimize_comm = is_fused_comm;
if (optimize_comm) {
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group);
}
last_communication_node = node;
if (optimize_comm) {
while (!delay_comm_stack.empty()) {
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
delay_comm_stack.pop();
}
delay_comm_stack.push(node);
} else if (is_fused_comm) {
delay_comm_stack.push(node);
} else if (is_communication_descendant) {
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
} else {
@ -746,6 +784,7 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
// delete old kernel
(void)backend_front_anf_map_.erase(old_backend_anf);
}
// get kernel by anf
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {

@ -284,7 +284,7 @@ class KernelGraph : public FuncGraph {
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes);
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
// update node edge list
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend

@ -285,6 +285,7 @@ constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
// Communication world group
constexpr auto kNcclWorldGroup = "nccl_world_group";
constexpr auto kHcclWorldGroup = "hccl_world_group";
constexpr auto kSyncBnGroup = "sync_bn_group";
// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";

Loading…
Cancel
Save