From 6c22c8a09d84f85c9960b5af68ef15d1e5e28c76 Mon Sep 17 00:00:00 2001 From: gukecai Date: Thu, 20 Aug 2020 14:33:46 +0800 Subject: [PATCH] parallel ctrl --- .../optimizer/pass/communication_op_fusion.cc | 3 + .../device/ascend/ascend_stream_assign.cc | 450 ++++++++++++++++-- .../device/ascend/ascend_stream_assign.h | 22 +- 3 files changed, 440 insertions(+), 35 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc index 7263a1c86b..3042e822a4 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -211,8 +211,11 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu start_index = end_index + 1; continue; } + auto kernel_graph = func_graph->cast(); + auto graph_id = kernel_graph->graph_id(); AnfNodePtr new_communication_op = CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index); + AnfAlgo::SetGraphId(graph_id, new_communication_op.get()); // replace old communication op with new communication op for (auto idx = start_index; idx <= end_index; ++idx) { std::vector tuple_getitem_input; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index e809f94969..6007ebc84e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -36,6 +36,7 @@ const uint32_t kCommonMaxTask = 350; void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { if (IsTaskSink()) { Reset(); + SetLoopSink(); ReorderIndependentOrders(graph_ptr); AssignAllNodesStream(graph_ptr); UpdateAtomicAddrCleanStreamId(graph_ptr); @@ -46,9 +47,9 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) InsertCtrlForIndependentParallel(graph_ptr); GetNeedActiveStreams(graph_ptr); - graph_ptr->PrintGraphExecuteOrder(); CheckResourceAssign(graph_ptr); MS_LOG(INFO) << "After finish stream assign"; + graph_ptr->PrintGraphExecuteOrder(); FindStreamRelations(graph_ptr); PrintStreamRelations(); @@ -58,6 +59,14 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) } } +void AscendStreamAssign::SetLoopSink() { + if (KernelAdjust::NeedInsertSwitch()) { + loop_sink_ = true; + } else { + loop_sink_ = false; + } +} + // section 1 void AscendStreamAssign::ReorderIndependentOrders(const NotNull &graph_ptr) { std::vector exe_orders; @@ -146,7 +155,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); if (exit_hcom) { - uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); + std::map> graph_nodes_map; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; // node has been assigned stream before @@ -155,28 +164,63 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra } if (IsHcom(cur_cnode_ptr)) { - AssignHcomStreamId(cur_cnode_ptr); + auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); + auto it = graph_nodes_map.find(hcom_graph_id); + if (it == graph_nodes_map.end()) { + graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr}; + } else { + it->second.emplace_back(cur_cnode_ptr); + } + } + } + MS_LOG(INFO) << "hcom diff graph id size:" << graph_nodes_map.size(); + for (const auto &item : graph_nodes_map) { + bool new_graph = true; + auto graph_id = item.first; + hcom_graph_map_[graph_id] = {}; + for (const auto &hcom_node_ptr : item.second) { + auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph); + hcom_graph_map_[graph_id].emplace(assigned_stream_id); + new_graph = false; } } - MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); + MS_LOG(INFO) << "hcom stream nums : " << hcom_stream_map_.size(); } if (exit_independent) { - uint32_t first_independ = resource_manager.ApplyNewStream(); + std::map> graph_nodes_map; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { continue; } if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) { - AssignIndependentStreamId(cur_cnode_ptr); + auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); + auto it = graph_nodes_map.find(independent_graph_id); + if (it == graph_nodes_map.end()) { + graph_nodes_map[independent_graph_id] = {cur_cnode_ptr}; + } else { + it->second.emplace_back(cur_cnode_ptr); + } } } - MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size(); + + MS_LOG(INFO) << "independent diff graph id size:" << graph_nodes_map.size(); + for (const auto &item : graph_nodes_map) { + bool new_graph = true; + auto graph_id = item.first; + independent_graph_map_[graph_id] = {}; + for (const auto &independent_node_ptr : item.second) { + auto assigned_stream_id = AssignIndependentStreamId(independent_node_ptr, new_graph); + independent_graph_map_[graph_id].emplace(assigned_stream_id); + new_graph = false; + } + } + MS_LOG(INFO) << "stream nums:" << independent_stream_map_.size(); } MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); -} +} // namespace ascend void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); @@ -205,10 +249,15 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { } } -void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { +uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); + uint32_t cur_hcom_stream_id; + if (new_graph) { + cur_hcom_stream_id = resource_manager.ApplyNewStream(); + } else { + cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); + } auto it = hcom_stream_map_.find(cur_hcom_stream_id); if (it == hcom_stream_map_.end()) { AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); @@ -223,26 +272,34 @@ void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); } } + return cur_hcom_stream_id; } -void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) { +uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId(); - auto it = independent_stream_map_.find(cur_independent_id); + uint32_t cur_independent_stream_id; + if (new_graph) { + cur_independent_stream_id = resource_manager.ApplyNewStream(); + } else { + cur_independent_stream_id = resource_manager.GetCurAllocStreamId(); + } + auto it = independent_stream_map_.find(cur_independent_stream_id); if (it == independent_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); + AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get()); + independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1)); } else { if (it->second < kCommonMaxTask) { AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); it->second++; } else { - cur_independent_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); + cur_independent_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get()); + independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1)); } } + + return cur_independent_stream_id; } // section 3: @@ -262,6 +319,182 @@ void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr) { + InsertStreamActiveForCommon(graph_ptr); + InsertStreamActiveForIndependent(graph_ptr); + InsertStreamActiveForParallel(graph_ptr); +} + +void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull &graph_ptr) { + if (hcom_graph_map_.empty() && independent_graph_map_.empty()) { + MS_LOG(INFO) << "Hcom and independent is empty"; + return; + } + auto root_graph_id = graph_ptr->graph_id(); + if (root_graph_id == kInvalidGraphId) { + MS_LOG(INFO) << "Root graph id is invalid"; + return; + } + + MS_LOG(DEBUG) << "Hcom grpah map size:" << hcom_graph_map_.size(); + std::map> other_graph; + for (const auto &item : hcom_graph_map_) { + MS_LOG(INFO) << "Graph id:" << item.first; + if (item.first == root_graph_id) { + if (loop_sink_) { + ActiveRootGraphHcom(graph_ptr, item.second); + } + } else { + other_graph[item.first] = item.second; + } + } + + MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size(); + for (const auto &item : independent_graph_map_) { + MS_LOG(DEBUG) << "Graph id:" << item.first; + if (item.first == root_graph_id) { + if (loop_sink_) { + ActiveRootGraphIndependent(graph_ptr, item.second); + } + } else { + auto it = other_graph.find(item.first); + if (it == other_graph.end()) { + other_graph[item.first] = item.second; + } else { + for (const auto &stream : item.second) { + it->second.emplace(stream); + } + } + } + } + + ActiveOtherGraphParallel(graph_ptr, other_graph); +} + +void AscendStreamAssign::ActiveOtherGraphParallel(const NotNull &graph_ptr, + std::map> other_graph) { + MS_LOG(INFO) << "Other graph size:" << other_graph.size(); + if (other_graph.empty()) { + return; + } + + auto root_graph_id = graph_ptr->graph_id(); + + std::vector update_stream_list; + auto exe_order = graph_ptr->execution_order(); + for (size_t i = 0; i < exe_order.size(); i++) { + auto cur_cnode_ptr = exe_order[i]; + auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); + if (cur_graph_id == root_graph_id) { + update_stream_list.emplace_back(cur_cnode_ptr); + continue; + } + + auto it = other_graph.find(cur_graph_id); + if (it == other_graph.end()) { + update_stream_list.emplace_back(cur_cnode_ptr); + continue; + } + + auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + // 1.set stream id + AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get()); + // 2.set active stream ids + std::vector active_index_list; + std::copy(it->second.begin(), it->second.end(), std::back_inserter(active_index_list)); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); + + // find position for insert streamactive + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kLabelSetOpName) { + update_stream_list.emplace_back(cur_cnode_ptr); + update_stream_list.emplace_back(active_ptr); + } else { + update_stream_list.emplace_back(active_ptr); + update_stream_list.emplace_back(cur_cnode_ptr); + } + other_graph.erase(it); + } + graph_ptr->set_execution_order(update_stream_list); +} + +void AscendStreamAssign::ActiveRootGraphHcom(const NotNull &graph_ptr, + const std::set &hcom_streams) { + MS_LOG(INFO) << "Active root graph hcom start"; + std::vector update_cnode_list; + auto exe_orders = graph_ptr->execution_order(); + for (size_t i = 0; i < exe_orders.size(); i++) { + CNodePtr cur_cnode_ptr = exe_orders[i]; + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + auto kind = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrStreamSwitchKind); + if (kind != kFpBpStreamSwitch) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + auto true_stream_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrTrueBranchStream); + MS_LOG(INFO) << "FpBpStreamswtich stream id:" << AnfAlgo::GetStreamId(cur_cnode_ptr) + << "; true branch stream id:" << true_stream_id; + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); + vector active_ids; + // active hcom stream + std::copy(hcom_streams.begin(), hcom_streams.end(), std::back_inserter(active_ids)); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); + update_cnode_list.emplace_back(cur_cnode_ptr); + update_cnode_list.emplace_back(active_ptr); + std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list)); + break; + } + + hcom_stream_activated_ = true; + graph_ptr->set_execution_order(update_cnode_list); +} + +void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull &graph_ptr, + std::set independent_streams) { + MS_LOG(DEBUG) << "Start active root graph independent"; + std::vector update_cnode_list; + auto exe_orders = graph_ptr->execution_order(); + for (size_t i = 0; i < exe_orders.size(); i++) { + CNodePtr cur_cnode_ptr = exe_orders[i]; + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + auto kind = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrStreamSwitchKind); + if (kind != kIndependentStreamSwitch) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + // first independetn stream id is minimum and order by std map; + auto first_independent_stream = *(independent_streams.begin()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_independent_stream), cur_cnode_ptr); + update_cnode_list.emplace_back(cur_cnode_ptr); + std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list)); + break; + } + + independent_stream_activated_ = true; + graph_ptr->set_execution_order(update_cnode_list); +} + +void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull &graph_ptr) { MS_LOG(INFO) << "Start"; GetProcessedStream(graph_ptr); std::vector update_cnode_list; @@ -298,7 +531,8 @@ void AscendStreamAssign::InsertStreamActive(const NotNull &graph if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; - UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); + // UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); + update_cnode_list.emplace_back(cur_cnode_ptr); } else { update_cnode_list.emplace_back(cur_cnode_ptr); } @@ -308,6 +542,70 @@ void AscendStreamAssign::InsertStreamActive(const NotNull &graph pre_cnode_ptr = cur_cnode_ptr; } graph_ptr->set_execution_order(update_cnode_list); +} + +void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull &graph_ptr) { + auto root_graph_id = graph_ptr->graph_id(); + if (root_graph_id == kInvalidGraphId) { + return; + } + std::set independent_streams; + for (const auto &item : independent_graph_map_) { + if (item.first == root_graph_id) { + independent_streams = item.second; + } + } + + if (independent_streams.size() <= 1) { + MS_LOG(INFO) << "Root graph independent stream size is not more than one, no need insert active"; + return; + } + std::vector update_cnode_list; + auto exe_orders = graph_ptr->execution_order(); + + // first independent is been actived, active other independent stream + std::vector streams; + std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams)); + std::sort(streams.begin(), streams.end()); + uint32_t node_num = 0; + uint32_t cur_stream_id = kInvalidStreamId; + for (size_t i = 0; i < exe_orders.size(); i++) { + auto cur_cnode_ptr = exe_orders[i]; + update_cnode_list.emplace_back(cur_cnode_ptr); + bool flag = AnfAlgo::IsIndependentNode(cur_cnode_ptr); + if (!flag) { + continue; + } + + auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); + if (graph_id != root_graph_id) { + continue; + } + + node_num++; + cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + auto it = std::find(streams.begin(), streams.end(), cur_stream_id); + if (it == streams.end()) { + MS_LOG(EXCEPTION) << "Can't find independent stream id:" << cur_stream_id; + } + + if (it == streams.end() - 1) { + std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list)); + break; + } else { + if (node_num == kCommonMaxTask) { + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + // 1.set stream id + AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get()); + // 2.set active stream ids + std::vector active_index_list{*(it + 1)}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); + update_cnode_list.emplace_back(active_ptr); + node_num = 0; + } + } + } + graph_ptr->set_execution_order(update_cnode_list); MS_LOG(INFO) << "End"; } @@ -373,7 +671,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph break; } } else { - MS_LOG(ERROR) << "independent stream switch exit, but independent stream is empty"; + MS_LOG(ERROR) << "Independent stream switch exit, but independent stream is empty"; } // update processed stream @@ -472,6 +770,77 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr, + const CNodePtr &cur_cnode_ptr) { + auto cnode_ptr_list = graph_ptr->execution_order(); + auto &inputs = cur_cnode_ptr->inputs(); + auto it_pos = cnode_ptr_list.begin(); + for (size_t i = 1; i < inputs.size(); i++) { + if (inputs[i]->isa()) { + auto cnode = inputs[i]->cast(); + while (opt::IsNopNode(cnode)) { + cnode = cnode->inputs()[1]->cast(); + } + + auto it = std::find(it_pos, cnode_ptr_list.end(), cnode); + if (it != cnode_ptr_list.end()) { + it_pos = it; + } + } else { + continue; + } + } + + if (it_pos == cnode_ptr_list.begin() && *it_pos != inputs[1]) { + MS_LOG(EXCEPTION) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found"; + } + + MS_LOG(INFO) << "The las input of node:" << cur_cnode_ptr->DebugString() << " is:" << (*it_pos)->fullname_with_scope() + << "; name:" << (*it_pos)->DebugString(); + return *it_pos; +} + +// after memory reuse is correct, use this function +void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes; + CNodePtr cur_cnode_ptr = nullptr; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (i == 0) { + cnodes.emplace_back(cur_cnode_ptr); + continue; + } + + if (!IsHcom(cur_cnode_ptr)) { + cnodes.emplace_back(cur_cnode_ptr); + continue; + } + + // get the input which located in the lastr exe orders + auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); + auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode); + if (it == cnodes.end()) { + MS_LOG(ERROR) << "hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) + << "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes"; + } else { + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id); + cnodes.insert(it + 1, send); + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); + cnodes.emplace_back(recv); + cnodes.emplace_back(cur_cnode_ptr); + } + } + + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); +} + void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); @@ -641,7 +1010,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNullDebugString() << "]"; + MS_LOG(DEBUG) << "Deal independent op[" << (*it)->DebugString() << "]"; CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); it = cnodes.insert(it + 1, send_cnode_ptr); @@ -690,7 +1059,7 @@ void AscendStreamAssign::GetIndependentMaxTarget(const NotNull & for (size_t k = 1; k < new_inputs.size(); k++) { auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0); if (key == new_real_input.first.get()) { - MS_LOG(INFO) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node); + MS_LOG(DEBUG) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node); independent_targets_.emplace(target_node.get()); flag = true; break; @@ -699,7 +1068,7 @@ void AscendStreamAssign::GetIndependentMaxTarget(const NotNull & } else { auto real_input = AnfAlgo::VisitKernel(input, 0); if (key == real_input.first.get()) { - MS_LOG(INFO) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node); + MS_LOG(DEBUG) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node); independent_targets_.emplace(target_node.get()); flag = true; } @@ -772,7 +1141,7 @@ void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNullexecution_order(); if (max_index >= exe_orders.size()) { - MS_LOG(EXCEPTION) << "max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size(); + MS_LOG(EXCEPTION) << "Max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size(); } auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]); @@ -813,16 +1182,19 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull &gra } // 2)independent stream:if has not been activate, push to need active vector + auto root_graph_id = graph_ptr->graph_id(); if (!independent_stream_activated_) { - for (auto &item : independent_stream_map_) { - need_first_active_streams_.emplace_back(item.first); + auto it = independent_graph_map_.find(root_graph_id); + if (it != independent_graph_map_.end()) { + need_first_active_streams_.push_back(*(it->second.begin())); } } // 3)hcom stream:if has not been activate, push to need active vector if (!hcom_stream_activated_) { - for (auto &item : hcom_stream_map_) { - need_first_active_streams_.emplace_back(item.first); + auto it = hcom_graph_map_.find(root_graph_id); + if (it != hcom_graph_map_.end()) { + std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_)); } } @@ -831,6 +1203,10 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull &gra if (it == need_first_active_streams_.end()) { need_first_active_streams_.emplace_back(0); } + MS_LOG(INFO) << "Finally, need active first stream include:"; + for (const auto &item : need_first_active_streams_) { + MS_LOG(INFO) << "stream id:" << item; + } } // section8 @@ -977,14 +1353,14 @@ vector::iterator AscendStreamAssign::FindTargetOp(vector::it for (size_t j = 1; j < new_inputs.size(); j++) { auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); if (node == new_real_input.first) { - MS_LOG(INFO) << "Nop node find target op[" << (*begin)->DebugString() << "]"; + MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; return begin; } } } else { auto real_input = AnfAlgo::VisitKernel(input, 0); if (node == real_input.first) { - MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]"; + MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]"; return begin; } } @@ -1040,6 +1416,7 @@ void AscendStreamAssign::GetHcomStreams(std::vector *streams) { void AscendStreamAssign::Reset() { independent_stream_activated_ = false; hcom_stream_activated_ = false; + loop_sink_ = false; independent_stream_map_.clear(); hcom_stream_map_.clear(); common_stream_map_.clear(); @@ -1049,6 +1426,9 @@ void AscendStreamAssign::Reset() { stream_relations_.clear(); event_map_.clear(); independent_targets_.clear(); + independent_graph_map_.clear(); + hcom_graph_map_.clear(); + middle_active_streams_.clear(); } // section 10 @@ -1101,7 +1481,12 @@ void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { } void AscendStreamAssign::GetStreamRelations() { - for (const auto &start : need_first_active_streams_) { + auto starts = middle_active_streams_; + for (const auto &stream : need_first_active_streams_) { + starts.emplace(stream); + } + + for (const auto &start : starts) { vector group{start}; DFS(start, &group); } @@ -1188,7 +1573,8 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull &graph_ptr); void AssignAllNodesStream(const NotNull &graph_ptr); void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); - void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); - void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); + uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph); + uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph); void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); void FindHcomParallelStreams(const NotNull &graph_ptr); void InsertStreamActive(const NotNull &graph_ptr); + void InsertStreamActiveForCommon(const NotNull &graph_ptr); + void InsertStreamActiveForIndependent(const NotNull &graph_ptr); + void InsertStreamActiveForParallel(const NotNull &graph_ptr); + void ActiveRootGraphHcom(const NotNull &graph_ptr, const std::set &hcom_streams); + void ActiveRootGraphIndependent(const NotNull &graph_ptr, std::set independent_streams); + void ActiveOtherGraphParallel(const NotNull &graph_ptr, + std::map> other_graph); void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, vector *orders); void InsertEventForIndependentParallel(const NotNull &graph_ptr); @@ -135,9 +142,11 @@ class AscendStreamAssign { void InsertEventForHcomParallel(const NotNull &graph_ptr); void InsertEventCommonDependHcom(const NotNull &graph_ptr); void InsertEventHcomDependCommon(const NotNull &graph_ptr); + void InsertEventHcomDependCommonBak(const NotNull &graph_ptr); void InsertEventHcomDependHcom(const NotNull &graph_ptr); void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, uint32_t first_hcom_stream, uint32_t last_hcom_stream); + CNodePtr GetLastInputCnode(const NotNull &graph_ptr, const CNodePtr &cur_cnode_ptr); bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); void GetProcessedStream(const NotNull &graph_ptr); @@ -155,6 +164,7 @@ class AscendStreamAssign { vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, const CNodePtr &node); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + void SetLoopSink(); // function for memory resue void GetStreamRelations(); @@ -172,17 +182,23 @@ class AscendStreamAssign { bool independent_stream_activated_{false}; bool hcom_stream_activated_{false}; + bool loop_sink_{false}; + // key:stream id, value:task nums; std::map independent_stream_map_{}; std::map hcom_stream_map_{}; std::map common_stream_map_{}; std::set processed_streams_{}; std::vector need_first_active_streams_{}; std::set independent_targets_; + // key:graph id, value:stream set + std::map> hcom_graph_map_; + std::map> independent_graph_map_; // attr for memory copy reuse std::map> stream_relations_{}; std::vector> stream_groups_{}; - std::map event_map_; + std::map event_map_{}; + std::set middle_active_streams_{}; // new policy end }; } // namespace ascend