diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index 2f73ced061..9a3846c4f9 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -161,5 +161,12 @@ class CompilerPool: ret = task_id, "Exception: Not support return type:" + str(ret_type) return ret + def reset_task_info(self): + """ + reset task info when task compile error + """ + if self.__running_tasks: + self.__running_tasks.clear() + compile_pool = CompilerPool() diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc index 16ddec1b4a..3e7452031b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc @@ -40,6 +40,7 @@ constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; constexpr auto kStartCompileOp = "start_compile_op"; constexpr auto kWaitOne = "wait_one"; +constexpr auto kResetTaskInfo = "reset_task_info"; bool TbeOpParallelBuild(std::vector anf_nodes) { auto build_manger = std::make_shared(); @@ -96,6 +97,8 @@ bool TbeOpParallelBuild(std::vector anf_nodes) { ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); } +ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); } + int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const { PyObject *pRes = nullptr; PyObject *pArgs = PyTuple_New(1); @@ -234,5 +237,16 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces); return kernel_mod_ptr; } + +void ParallelBuildManager::ResetTaskInfo() { + if (task_map_.empty()) { + MS_LOG(INFO) << "All tasks are compiled success."; + return; + } + task_map_.clear(); + same_op_list_.clear(); + PyObject *pArg = Py_BuildValue("()"); + (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h index 5066e9457f..776aa0b1fc 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h @@ -40,7 +40,7 @@ struct KernelBuildTaskInfo { class ParallelBuildManager { public: ParallelBuildManager(); - ~ParallelBuildManager() = default; + ~ParallelBuildManager(); int32_t StartCompileOp(const nlohmann::json &kernel_json) const; void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name, const std::vector &input_size_list, const std::vector &output_size_list, @@ -58,6 +58,7 @@ class ParallelBuildManager { KernelModPtr GenKernelMod(const string &json_name, const string &processor, const std::vector &input_size_list, const std::vector &output_size_list, const KernelPackPtr &kernel_pack) const; + void ResetTaskInfo(); private: PyObject *tbe_parallel_compiler_;