|
|
|
@ -36,6 +36,7 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
|
|
|
|
constexpr size_t k5dDims = 5;
|
|
|
|
|
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
|
|
|
|
|
prim::kPrimAssignSub->name()};
|
|
|
|
|
|
|
|
|
|
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
@ -129,7 +130,34 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetNodeGroup(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
|
|
|
|
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
|
|
|
|
}
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(optimized_comm_group);
|
|
|
|
|
auto node_group = GetNodeGroup(node);
|
|
|
|
|
if (node_group.find(kSyncBnGroup) != string::npos) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto node_name = AnfAlgo::GetCNodeName(node);
|
|
|
|
|
auto iter = optimized_comm_group->find(node_name);
|
|
|
|
|
if (iter == optimized_comm_group->end()) {
|
|
|
|
|
(*optimized_comm_group)[node_name] = node_group;
|
|
|
|
|
return true;
|
|
|
|
|
} else if (iter->second == node_group) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
|
|
|
|
|
auto value_node = node->cast<ValueNodePtr>();
|
|
|
|
|
if (value_node == nullptr) {
|
|
|
|
@ -153,7 +181,7 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes);
|
|
|
|
|
auto it = node_output_edges_.find(node);
|
|
|
|
@ -184,7 +212,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
|
|
|
|
|
// allreduce first
|
|
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
|
|
|
|
(void)visited_nodes->insert(next_node);
|
|
|
|
|
if (AnfAlgo::IsCommunicationOp(next_node)) {
|
|
|
|
|
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
|
|
|
|
|
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
|
|
|
|
|
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
|
|
|
|
|
visit_queue->push(next_node);
|
|
|
|
|
} else {
|
|
|
|
@ -206,18 +235,19 @@ void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
|
execution_order_.clear();
|
|
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes;
|
|
|
|
|
std::queue<AnfNodePtr> zero_input_nodes;
|
|
|
|
|
AnfNodePtr last_communication_node = nullptr;
|
|
|
|
|
std::stack<AnfNodePtr> delay_comm_stack;
|
|
|
|
|
std::queue<AnfNodePtr> communication_descendants;
|
|
|
|
|
while (!seed_nodes.empty() || last_communication_node != nullptr) {
|
|
|
|
|
// seed nodes first, then visit last all reduce node descendant
|
|
|
|
|
std::map<std::string, std::string> optimized_comm_group;
|
|
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
|
|
|
|
|
// seed nodes first, then delay comm nodes
|
|
|
|
|
if (seed_nodes.empty()) {
|
|
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
|
|
|
|
last_communication_node = nullptr;
|
|
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
|
} else {
|
|
|
|
|
zero_input_nodes.push(seed_nodes.front());
|
|
|
|
|
seed_nodes.pop();
|
|
|
|
|
}
|
|
|
|
|
// all reduce node descendant first, then common queue
|
|
|
|
|
// comm descendant first, then common queue
|
|
|
|
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
|
|
|
|
|
AnfNodePtr node = nullptr;
|
|
|
|
|
bool is_communication_descendant = false;
|
|
|
|
@ -234,12 +264,20 @@ void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
execution_order_.push_back(node->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
// for all reduce node, visit last all reduce node descendant
|
|
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) {
|
|
|
|
|
if (last_communication_node != nullptr) {
|
|
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
|
|
|
|
// delay execute comm ops that need optimize
|
|
|
|
|
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
|
|
|
|
|
bool optimize_comm = is_fused_comm;
|
|
|
|
|
if (optimize_comm) {
|
|
|
|
|
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group);
|
|
|
|
|
}
|
|
|
|
|
last_communication_node = node;
|
|
|
|
|
if (optimize_comm) {
|
|
|
|
|
while (!delay_comm_stack.empty()) {
|
|
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
|
}
|
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
|
} else if (is_fused_comm) {
|
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
|
} else if (is_communication_descendant) {
|
|
|
|
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
|
|
|
|
|
} else {
|
|
|
|
@ -746,6 +784,7 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
|
|
|
|
|
// delete old kernel
|
|
|
|
|
(void)backend_front_anf_map_.erase(old_backend_anf);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// get kernel by anf
|
|
|
|
|
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
|
|
|
|
|
if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
|
|
|
|
|