diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 83c27ae779..b000b647f4 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -25,6 +25,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_adjust.h" #include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/oplib/oplib.h" #include "utils/utils.h" namespace mindspore { @@ -38,6 +39,7 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) Reset(); SetLoopSink(); ReorderIndependentOrders(graph_ptr); + TrailingTimeOptimizationByReorder(graph_ptr); AssignAllNodesStream(graph_ptr); UpdateAtomicAddrCleanStreamId(graph_ptr); @@ -128,6 +130,305 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull graph_ptr->set_execution_order(exe_orders); } +void AscendStreamAssign::CheckScenario(const NotNull &graph_ptr, + vector *last_grad_and_status) { + auto cnode_ptr_list = graph_ptr->execution_order(); + vector hcom_nodes; + CNodePtr cur_cnode_ptr = nullptr; + CNodePtr overflow_marker = nullptr; + std::string kNPUGetFloatStatusOpName = "NPUGetFloatStatus"; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kNPUGetFloatStatusOpName) { + overflow_marker = cur_cnode_ptr; + } else if (IsHcom(cur_cnode_ptr)) { + hcom_nodes.emplace_back(cur_cnode_ptr); + } else if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { + auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); + AnfAlgo::SetGraphId(graph_id, cnode_ptr_list[i - 1].get()); + } + } + + if (hcom_nodes.size() < 2 || overflow_marker == nullptr) { + MS_LOG(INFO) << "Current model isn't in distribute or mix-precision mode, no optimization needed"; + last_grad_and_status->clear(); + return; + } + + auto overflow_marker_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), overflow_marker); + auto last_hcom_ptr = hcom_nodes[hcom_nodes.size() - 1]; + auto last_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_hcom_ptr); + auto last_grad_hcom_ptr = hcom_nodes[hcom_nodes.size() - 2]; + auto last_grad_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_hcom_ptr); + if (last_grad_hcom_pos > overflow_marker_pos || last_hcom_pos < overflow_marker_pos) { + MS_LOG(INFO) << "Grads average done after overflow judgement or status aren't allgathered, no optimization needed"; + last_grad_and_status->clear(); + return; + } + + auto last_inputs = GetLastInputCnode(graph_ptr, last_grad_hcom_ptr); + if (last_inputs.empty() || last_inputs.size() > 1 || IsHcom(last_inputs[0])) { + MS_LOG(INFO) << "Inputs of last gradients allreduce is empty or include other allreduce, no optimization needed"; + last_grad_and_status->clear(); + return; + } + auto last_grad_ptr = last_inputs[0]; + MS_LOG(DEBUG) << "Last Hcom: " << last_grad_hcom_ptr->fullname_with_scope() + << "; last input: " << last_grad_ptr->fullname_with_scope(); + auto last_grad_hcom_graph_id = AnfAlgo::GetGraphId(last_grad_hcom_ptr.get()); + auto last_grad_graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get()); + auto overflow_marker_graph_id = AnfAlgo::GetGraphId(overflow_marker.get()); + if (last_grad_graph_id != last_grad_hcom_graph_id || last_grad_graph_id != overflow_marker_graph_id) { + MS_LOG(INFO) << "The grads and grad_hcom or overflow marker were not on the same subgraph, no optimization needed"; + last_grad_and_status->clear(); + return; + } + + auto label_switch_pos = find_if(last_grad_hcom_pos, cnode_ptr_list.end(), + [](CNodePtr &node) -> bool { return AnfAlgo::GetCNodeName(node) == "LabelSwitch"; }); + if (label_switch_pos == cnode_ptr_list.end()) { + MS_LOG(INFO) << "No branches after getting overflow status, no optimization needed"; + last_grad_and_status->clear(); + return; + } + last_grad_and_status->emplace_back(last_grad_ptr); + last_grad_and_status->emplace_back(overflow_marker); + return; +} + +CNodePtr AscendStreamAssign::GetCNodesNeededMoved(vector *moved_backward_cnodes, + vector *moved_forward_cnodes, + const vector &last_grad_and_status, + const NotNull &graph_ptr) { + auto cnode_ptr_list = graph_ptr->execution_order(); + if (last_grad_and_status.size() != 2) { + return nullptr; + } + auto last_grad_ptr = last_grad_and_status[0]; + auto float_status_ptr = last_grad_and_status[1]; + auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr); + auto float_status_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), float_status_ptr); + if (last_grad_pos == cnode_ptr_list.end() || float_status_pos == cnode_ptr_list.end()) { + moved_backward_cnodes->clear(); + moved_forward_cnodes->clear(); + return nullptr; + } + auto graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get()); + moved_backward_cnodes->insert(moved_backward_cnodes->end(), last_grad_pos + 1, float_status_pos); + + auto it = float_status_pos; + while (AnfAlgo::GetGraphId((*it).get()) == graph_id && it < cnode_ptr_list.end()) { + if (AnfAlgo::GetCNodeName(*it) == kAtomicAddrCleanOpName) { + it++; + continue; + } + auto inputs = GetInputKernels(*it); + bool is_independent = true; + for (auto &input : inputs) { + if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) { + is_independent = false; + break; + } + } + if (is_independent) { + if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) { + moved_forward_cnodes->emplace_back(*(it - 1)); + } + moved_forward_cnodes->emplace_back(*it); + } else { + if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) { + moved_backward_cnodes->emplace_back(*(it - 1)); + } + moved_backward_cnodes->emplace_back(*it); + } + it++; + } + // check ref nodes + for (auto &cnode : *moved_backward_cnodes) { + std::string op_name = AnfAlgo::GetCNodeName(cnode); + auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); + if (op_info != nullptr && op_info->is_ref()) { + MS_LOG(INFO) << "Find RefNode: " << op_name << ", full name: " << cnode->fullname_with_scope(); + moved_backward_cnodes->clear(); + moved_forward_cnodes->clear(); + return nullptr; + } + } + + size_t total_moved_size = it - last_grad_pos - 1; + if (moved_backward_cnodes->size() + moved_forward_cnodes->size() != total_moved_size) { + MS_LOG(DEBUG) << "Total number inconsistent, total cnode number: " << total_moved_size + << ", while move forward size: " << moved_forward_cnodes->size() + << ", moved backward size: " << moved_backward_cnodes->size(); + moved_forward_cnodes->clear(); + moved_backward_cnodes->clear(); + return nullptr; + } + + uint32_t subgraph_id = 0; + bool get_subgraph_id = false; + CNodePtr first_output_node_ptr = nullptr; + while (!get_subgraph_id && it < cnode_ptr_list.end()) { + auto inputs = GetInputKernels(*it); + for (auto &input : inputs) { + if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) { + MS_LOG(DEBUG) << "get subgraph id: " << AnfAlgo::GetGraphId((*it).get()); + get_subgraph_id = true; + subgraph_id = AnfAlgo::GetGraphId((*it).get()); + first_output_node_ptr = *it; + break; + } + } + it++; + } + if (subgraph_id == 0) { + MS_LOG(INFO) << "The nodes moved backward were not used by any other nodes, no need moved"; + moved_forward_cnodes->clear(); + moved_backward_cnodes->clear(); + return nullptr; + } + + for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) { + auto inputs = GetInputKernels(*it); + for (auto &input : inputs) { + if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) { + MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved"; + moved_forward_cnodes->clear(); + moved_backward_cnodes->clear(); + return nullptr; + } + } + } + return first_output_node_ptr; +} + +void AscendStreamAssign::FinetuneSubgraphExecOrder(vector *cnodes) { + MS_EXCEPTION_IF_NULL(cnodes); + auto hcom_pos = find_if(cnodes->begin(), cnodes->end(), + [](CNodePtr &node_ptr) -> bool { return AnfAlgo::GetCNodeName(node_ptr) == "AllReduce"; }); + if (hcom_pos == cnodes->end()) { + cnodes->clear(); + return; + } + CNodePtr hcom_ptr = *hcom_pos; + + vector ori_cnodes(cnodes->begin(), cnodes->end()); + cnodes->clear(); + vector atomic_addr_clean; + for (auto iter = ori_cnodes.begin(); iter < ori_cnodes.end(); iter++) { + if (AnfAlgo::GetCNodeName(*iter) == kAtomicAddrCleanOpName) { + atomic_addr_clean.emplace_back(*iter); + continue; + } + auto inputs = GetInputKernels(*iter); + auto last_input_pos = cnodes->end(); + for (auto &input : inputs) { + auto pos = find(cnodes->begin(), cnodes->end(), input); + if (pos != cnodes->end()) { + last_input_pos = (last_input_pos == cnodes->end() || last_input_pos < pos) ? pos : last_input_pos; + } + } + if (last_input_pos == cnodes->end()) { + auto hcom_it = find(cnodes->begin(), cnodes->end(), hcom_ptr); + if (hcom_it == cnodes->end() || AnfAlgo::GetCNodeName(*iter) == kLabelGotoOpName || + AnfAlgo::GetCNodeName(*iter) == kLabelSetOpName || AnfAlgo::GetCNodeName(*iter) == kLabelSwitchOpName) { + cnodes->emplace_back(*iter); + } else { + cnodes->insert(hcom_it, *iter); + } + } else { + cnodes->insert(last_input_pos + 1, *iter); + } + } + + for (auto &node : atomic_addr_clean) { + auto inputs = GetInputKernels(node); + auto first_input_pos = cnodes->end(); + for (auto &input : inputs) { + auto pos = find(cnodes->begin(), cnodes->end(), input); + first_input_pos = (first_input_pos == cnodes->end() || first_input_pos > pos) ? pos : first_input_pos; + } + if (first_input_pos == cnodes->end()) { + MS_LOG(DEBUG) << "node: " << node->fullname_with_scope() << " 's input was not found"; + cnodes->clear(); + return; + } else { + cnodes->insert(first_input_pos, node); + } + } + if (cnodes->size() != ori_cnodes.size()) { + MS_LOG(DEBUG) << "Total number inconsistent, original node size: " << ori_cnodes.size() + << ", while the new size after finetune order is: " << cnodes->size(); + cnodes->clear(); + return; + } +} + +// performance optimization for trailing time in distribute mode +// allreduce of the last batch of gradients and the optimizer can be done parallel +void AscendStreamAssign::TrailingTimeOptimizationByReorder(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Trailing time optimization begin"; + vector last_grad_and_status; + CheckScenario(graph_ptr, &last_grad_and_status); + if (last_grad_and_status.empty()) { + MS_LOG(INFO) << "Unsuitable scenario, no optimization needed"; + return; + } + + auto cnode_ptr_list = graph_ptr->execution_order(); + vector moved_forward_cnodes; + vector moved_backward_cnodes; + CNodePtr first_output_ptr = + GetCNodesNeededMoved(&moved_backward_cnodes, &moved_forward_cnodes, last_grad_and_status, graph_ptr); + if (moved_backward_cnodes.empty() || first_output_ptr == nullptr) { + MS_LOG(INFO) << "Unsuitable scenario, no optimization needed"; + return; + } + + uint32_t subgraph_id = AnfAlgo::GetGraphId(first_output_ptr.get()); + auto last_grad_ptr = last_grad_and_status[0]; + auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr); + vector cnodes(cnode_ptr_list.begin(), last_grad_pos + 1); + cnodes.insert(cnodes.end(), moved_forward_cnodes.begin(), moved_forward_cnodes.end()); + auto pos = last_grad_pos + moved_forward_cnodes.size() + moved_backward_cnodes.size() + 1; + while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) != subgraph_id) { + cnodes.emplace_back(*pos); + pos++; + } + + vector subgraph_cnodes; + while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) == subgraph_id) { + if (*pos != first_output_ptr) { + subgraph_cnodes.emplace_back(*pos); + } else { + subgraph_cnodes.insert(subgraph_cnodes.end(), moved_backward_cnodes.begin(), moved_backward_cnodes.end()); + subgraph_cnodes.emplace_back(*pos); + } + pos++; + } + + FinetuneSubgraphExecOrder(&subgraph_cnodes); + if (subgraph_cnodes.empty()) { + MS_LOG(INFO) << "Finetune subgraph execute order failed, no optimization needed"; + return; + } + + cnodes.insert(cnodes.end(), subgraph_cnodes.begin(), subgraph_cnodes.end()); + cnodes.insert(cnodes.end(), pos, cnode_ptr_list.end()); + if (cnodes.size() != cnode_ptr_list.size()) { + MS_LOG(INFO) << "Inconsistent cnodes number. Original size: " << cnode_ptr_list.size() + << ", while new order cnodes size: " << cnodes.size(); + return; + } + for (auto &node : subgraph_cnodes) { + AnfAlgo::SetGraphId(subgraph_id, node.get()); + } + + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "Trailing time optimization end"; +} + // section 2 void AscendStreamAssign::AssignAllNodesStream(const NotNull &graph_ptr) { auto cnode_ptr_list = graph_ptr->execution_order(); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index 8730934c33..4924d725ca 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -161,6 +161,13 @@ class AscendStreamAssign { void GetProcessedStream(const NotNull &graph_ptr); void GetNeedActiveStreams(const NotNull &graph_ptr); void ReorderIndependentOrders(const NotNull &graph_ptr); + + void CheckScenario(const NotNull &graph_ptr, vector *last_grad_and_status); + CNodePtr GetCNodesNeededMoved(vector *moved_backward_cnodes, vector *moved_forward_cnodes, + const vector &last_grad_and_status, const NotNull &graph_ptr); + void FinetuneSubgraphExecOrder(vector *cnodes); + void TrailingTimeOptimizationByReorder(const NotNull &graph_ptr); + uint32_t GetMaxIndexTarget(const NotNull &graph_ptr); uint32_t GetIndexByKey(const NotNull &graph_ptr, const CNodeKey &key); uint32_t GetIndependentStreamSwitchStreamId(const NotNull &graph_ptr);