diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
index 20e03b682f..a2fe9c9878 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
@@ -25,8 +25,8 @@ namespace kernel {
 bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
   MS_LOG(INFO) << "HcclAllReduce launch";
-  if (inputs.size() != 1 || outputs.size() != 1) {
-    MS_LOG(ERROR) << "AllReduce input output size must be 1";
+  if (inputs.empty() || outputs.empty()) {
+    MS_LOG(ERROR) << "Invalid AllReduce input output size(" << inputs.size() << ", " << outputs.size() << ").";
     return false;
   }
   MS_EXCEPTION_IF_NULL(inputs[0]);
diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc
index c0a35223f2..dcc4b785c9 100644
--- a/mindspore/ccsrc/pipeline/jit/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/action.cc
@@ -475,7 +475,8 @@ bool TaskEmitAction(const ResourcePtr &res) {
   auto context_ptr = MsContext::GetInstance();
   std::string backend = MsContext::GetInstance()->backend_policy();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (func_graph->ContainMultiTarget()) {
+  auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
+  if (func_graph->ContainMultiTarget() || !task_sink) {
     bc_ptr->set_is_multi_graph_sink(false);
     context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
     context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index bf4f90e778..4fc0db50df 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -35,6 +35,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
+#include "runtime/device/ascend/ascend_event.h"
 #include "debug/data_dump/dump_json_parser.h"
 #include "toolchain/adx_datadump_server.h"
 #include "utils/trace_base.h"
@@ -154,7 +155,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
   DumpJsonParser::GetInstance().PrintUnusedKernel();
 
   graph_dynamic_kernel_map_.clear();
-
+  graph_kernel_events_map_.clear();
   for (auto &iter : graph_model_map_) {
     MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
     auto ret = ModelRunner::Instance().UnloadModel(iter.first);
@@ -186,7 +187,10 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
     MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel";
     graph_dynamic_kernel_map_.erase(dynamic_kernel_iter);
   }
-
+  auto events_iter = graph_kernel_events_map_.find(graph_id);
+  if (events_iter != graph_kernel_events_map_.end()) {
+    graph_kernel_events_map_.erase(events_iter);
+  }
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
   if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
     MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
@@ -340,9 +344,9 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
 
 bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
   if (!is_task_sink) {
+    GenKernelEvents(graph);
     return true;
   }
-
   // Do HcomExecutorInitialize
   if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) {
     MS_LOG(ERROR) << "Init Hccl Executor Failed";
@@ -357,6 +361,58 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
   return true;
 }
 
+void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto &kernels = graph->execution_order();
+  if (kernels.empty()) {
+    return;
+  }
+  auto kernel_events =
+    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
+  auto &kernel_pre_run_events = kernel_events.first;
+  auto &kernel_post_run_events = kernel_events.second;
+  kernel_pre_run_events.resize(kernels.size());
+  kernel_post_run_events.resize(kernels.size());
+  for (size_t i = 0; i < kernels.size(); ++i) {
+    auto &kernel = kernels[i];
+    if (!AnfAlgo::IsCommunicationOp(kernel)) {
+      continue;
+    }
+    auto pre_event = std::make_shared<AscendEvent>();
+    auto post_event = std::make_shared<AscendEvent>();
+    pre_event->set_wait_stream(communication_stream_);
+    pre_event->set_record_stream(stream_);
+    post_event->set_wait_stream(stream_);
+    post_event->set_record_stream(communication_stream_);
+    kernel_pre_run_events[i].emplace_back([pre_event]() {
+      pre_event->RecordEvent();
+      pre_event->WaitEvent();
+    });
+    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
+    bool found_nearest_child = false;
+    for (size_t j = i + 1; j < kernels.size(); ++j) {
+      auto &child = kernels[j];
+      MS_EXCEPTION_IF_NULL(child);
+      auto input_size = child->inputs().size() - 1;
+      for (size_t k = 0; k < input_size; ++k) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0);
+        if (kernel_index.first == kernel) {
+          found_nearest_child = true;
+          break;
+        }
+      }
+      if (found_nearest_child) {
+        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
+        break;
+      }
+    }
+    if (!found_nearest_child) {
+      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
+    }
+  }
+  graph_kernel_events_map_[graph->graph_id()] = std::move(kernel_events);
+}
+
 bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "GenDynamicKernel start";
@@ -374,7 +430,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
     dynamic_kernel->Initialize();
     dynamic_kernels.emplace_back(dynamic_kernel);
   }
-  graph_dynamic_kernel_map_[graph->graph_id()] = dynamic_kernels;
+  graph_dynamic_kernel_map_[graph->graph_id()] = std::move(dynamic_kernels);
   MS_LOG(INFO) << "GenDynamicKernel end";
   return true;
 }
@@ -852,8 +908,9 @@ bool AscendKernelRuntime::HcclInit() {
     MS_LOG(ERROR) << "Hcom init failed.";
     return false;
   }
-  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
-    MS_LOG(INFO) << "PyNative hccl init";
+  auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || !task_sink) {
+    MS_LOG(INFO) << "Hccl comm init.";
     return kernel::HcclContext::GetInstance().InitHccl();
   }
   return true;
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
index 96ddd30c16..48cde66bdc 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
@@ -67,6 +67,7 @@ class AscendKernelRuntime : public KernelRuntime {
   bool KernelMemNotReuse(const AnfNodePtr &node) override;
 
   void KernelLaunchProfiling(const std::string &kernel_name) override;
+  void GenKernelEvents(const session::KernelGraph *graph) override;
 
  private:
   bool InitDevice();
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 9aecae33e7..c4e2bf52d5 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -31,6 +31,7 @@
 #include "utils/shape_utils.h"
 #include "utils/utils.h"
 #include "frontend/parallel/context.h"
+
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/ps_cache/ps_cache_manager.h"
 #endif
@@ -920,6 +921,16 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
   }
 }
 
+void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events,
+                                      size_t index) {
+  if (index >= kernel_events.size()) {
+    return;
+  }
+  for (auto &event : kernel_events[index]) {
+    event();
+  }
+}
+
 bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
   const auto &kernels = graph.execution_order();
   std::vector<DynamicKernelPtr> dynamic_kernel_list;
@@ -931,12 +942,21 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
     MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
                       << " should be equal to the size of kernels " << kernels.size();
   }
+  std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
+  std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
+  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
+  if (events_iter != graph_kernel_events_map_.end()) {
+    kernel_pre_run_events = events_iter->second.first;
+    kernel_post_run_events = events_iter->second.second;
+  }
   for (size_t i = 0; i < kernels.size(); ++i) {
+    LaunchKernelEvent(kernel_pre_run_events, i);
     if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
         dynamic_kernel_list[i]->is_dynamic_shape()) {
       dynamic_kernel_list[i]->InferShape();
       dynamic_kernel_list[i]->UpdateArgs();
       dynamic_kernel_list[i]->Execute();
+
       if (!SyncStream()) {
         MS_LOG(ERROR) << "SyncStream failed";
         return false;
@@ -958,20 +978,23 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
        }
        continue;
      }
-
      AddressPtrList kernel_inputs;
      AddressPtrList kernel_workspaces;
      AddressPtrList kernel_outputs;
      GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
-
-      auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+      bool ret;
+      if (AnfAlgo::IsCommunicationOp(kernel)) {
+        ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
+      } else {
+        ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+      }
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
-
      KernelLaunchProfiling(kernels[i]->fullname_with_scope());
    }
+    LaunchKernelEvent(kernel_post_run_events, i);
   }
   return true;
 }
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index f4f73e3f15..071df2bae7 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <memory>
 #include <string>
+#include <utility>
 #include <map>
 #include "runtime/device/device_address.h"
 #include "ir/tensor.h"
@@ -132,10 +133,12 @@ class KernelRuntime {
   void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
 
   virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
+  virtual void GenKernelEvents(const session::KernelGraph *graph) {}
 
  private:
   void AssignStaticMemoryOutput(const session::KernelGraph *graph);
   bool LaunchKernelMod(const session::KernelGraph &graph);
+  void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index);
   static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
   size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
   void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
@@ -160,6 +163,9 @@ class KernelRuntime {
   void *communication_stream_{nullptr};
   std::shared_ptr<MemoryManager> mem_manager_{nullptr};
   std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;
+  std::map<uint32_t, std::pair<std::vector<std::vector<std::function<void()>>>,
+                               std::vector<std::vector<std::function<void()>>>>>
+    graph_kernel_events_map_;
   std::vector> buffer_ptrs_ = {};
 };
 using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
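
Note on the scheduling pattern introduced above: without task sink, GenKernelEvents() pairs a pre-run event (recorded on the compute stream, waited on by the communication stream) and a post-run event (recorded on the communication stream, waited on by the nearest downstream consumer, or immediately after the kernel if no consumer is found) around each communication op, and LaunchKernelMod() replays those callbacks before and after each launch. The sketch below is a minimal, self-contained model of that control flow only; the Event struct and the "compute"/"comm" stream labels are illustrative stand-ins for AscendEvent and the runtime's stream_/communication_stream_ handles, not the actual device APIs.

// Minimal sketch (not part of the patch): toy model of the pre/post event callbacks.
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

struct Event {
  std::string record_stream;  // stream that records the event
  std::string wait_stream;    // stream that blocks until the event is recorded
  void Record() const { std::printf("  record event on %s\n", record_stream.c_str()); }
  void Wait() const { std::printf("  wait event on %s\n", wait_stream.c_str()); }
};

using EventLists = std::vector<std::vector<std::function<void()>>>;

int main() {
  // Toy execution order: kernels[1] is the communication op, kernels[2] consumes its output.
  std::vector<std::string> kernels = {"MatMul", "AllReduce", "Add"};
  EventLists pre(kernels.size());
  EventLists post(kernels.size());

  // What GenKernelEvents() would build for index 1.
  Event pre_event{"compute", "comm"};   // recorded on compute stream, waited on comm stream
  Event post_event{"comm", "compute"};  // recorded on comm stream, waited on compute stream
  pre[1].emplace_back([pre_event]() {
    pre_event.Record();  // comm stream must not start AllReduce before prior compute work
    pre_event.Wait();
  });
  post[1].emplace_back([post_event]() { post_event.Record(); });
  pre[2].emplace_back([post_event]() { post_event.Wait(); });  // nearest child waits for AllReduce

  // What LaunchKernelMod() does per kernel: pre-run events, launch, post-run events.
  for (size_t i = 0; i < kernels.size(); ++i) {
    for (auto &event : pre[i]) event();
    std::printf("launch %s on %s stream\n", kernels[i].c_str(), i == 1 ? "comm" : "compute");
    for (auto &event : post[i]) event();
  }
  return 0;
}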