diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 311abb634e..69b9e00457 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -260,6 +260,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { auto anf_node_list = graph->execution_order(); TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); + // Store the task_info_list + auto iter = task_map_.find(graph); + if (iter != task_map_.end()) { + MS_LOG(EXCEPTION) << "graph TaskInfo list already exist"; + } + task_map_[graph] = task_info_list; + + // Graph may have no compute node, such TensorAddGrad. + if (task_info_list.empty()) { + MS_LOG(WARNING) << "graph " << graph->graph_id() << " have no compute node"; + return true; + } + AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list = assign_instance.GetWaitStreams(); @@ -278,10 +291,6 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { graph_model_map_[graph] = model; graph_model_id_map_[graph] = graph->graph_id(); MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; - - // Store the task_info_list - task_map_.insert(std::make_pair(graph, task_info_list)); - return true; } @@ -305,6 +314,11 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { return true; } + if (GraphWithEmptyTaskList(graph)) { + MS_LOG(WARNING) << "LoadTask end, task list is empty"; + return true; + } + auto task_iter = graph_model_map_.find(graph); if (task_iter == graph_model_map_.end()) { MS_LOG(ERROR) << "task not exist"; @@ -333,6 +347,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(context_ptr); ge::InputData input_tensors = ge::InputData(); ge::OutputData *output_tensors = nullptr; + if (GraphWithEmptyTaskList(graph)) { + MS_LOG(WARNING) << "RunTask end, no task info found"; + return true; + } + auto model_id = GetGraphModelId(graph); bool status = ge::model_runner::ModelRunner::Instance().RunModel(model_id, input_tensors, output_tensors); if (!status) { @@ -468,6 +487,14 @@ bool AscendKernelRuntime::DestroyHccl() { context_ptr->set_enable_hccl(false); return true; } + +bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { + auto iter = task_map_.find(graph); + if (iter == task_map_.end()) { + MS_LOG(EXCEPTION) << "Unknown graph ptr"; + } + return iter->second.empty(); +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h index 0eedad3d2b..547228d32f 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h @@ -55,6 +55,8 @@ class AscendKernelRuntime : public KernelRuntime { void ClearGraphModelMap(); void ReleaseDeviceRes() override; uint32_t GetGraphModelId(const session::KernelGraph *kernel_graph); + bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; + rtContext_t rt_context_{nullptr}; bool initialized_{false}; unordered_map>> task_map_;