From db985ab909a372641bdfd256b323bf6e81bf8ad5 Mon Sep 17 00:00:00 2001 From: kswang Date: Mon, 12 Oct 2020 11:29:13 +0800 Subject: [PATCH] add ms exception --- mindspore/ccsrc/backend/session/executor.cc | 91 ++++++------------- mindspore/ccsrc/backend/session/executor.h | 3 +- .../ccsrc/backend/session/session_basic.cc | 5 +- .../runtime/device/cpu/cpu_kernel_runtime.cc | 3 +- mindspore/core/ir/tensor.h | 2 + mindspore/core/utils/ms_exception.h | 48 ++++++++++ 6 files changed, 79 insertions(+), 73 deletions(-) create mode 100644 mindspore/core/utils/ms_exception.h diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 3e5c70f133..f0661e5338 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "backend/session/executor.h" +#include #include "runtime/device/kernel_runtime_manager.h" #include "backend/session/executor_manager.h" #include "utils/comm_manager.h" @@ -40,10 +41,7 @@ void UpdateOutputTensors(const VectorRef *outputs, tensor->set_device_address(address); } if (tensor->NeedSyncDeviceToHostImmediately()) { - auto tensor_address = tensor->device_address(); - MS_EXCEPTION_IF_NULL(tensor_address); - tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c()); + tensor->data_sync(false); tensor->set_device_address(nullptr); tensor->set_sync_status(kNeedSyncHostToDevice); } @@ -85,7 +83,11 @@ void BuildGraphTask::Run() { void RunGraphTask::Run() { MS_EXCEPTION_IF_NULL(session_); - session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); + try { + session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); + } catch (const std::exception &e) { + MsException::GetInstance().SetException(); + } UpdateOutputTensors(&outputs_, tensor_to_node_); for (auto &tensor : input_need_lock_tensors_) { tensor->SetNeedWait(false); @@ 
-115,14 +117,6 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) { Executor::~Executor() { WorkerJoin(); } -void Executor::CheckException() { - if (exception_ptr_ != nullptr) { - auto exception_ptr = exception_ptr_; - exception_ptr_ = nullptr; - std::rethrow_exception(exception_ptr); - } -} - void Executor::WorkerJoin() { // Avoid worker thread join itself which will cause deadlock if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) { @@ -152,7 +146,7 @@ void Executor::WorkerLoop() { try { task->Run(); } catch (const std::exception &e) { - exception_ptr_ = std::current_exception(); + MsException::GetInstance().SetException(); } if (task->type_ != kRunGraph || task->sync_run_) { task = nullptr; @@ -200,48 +194,40 @@ bool Executor::IsTaskReady(const std::shared_ptr &task) { return true; } -GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - CheckException(); +void Executor::SyncRunTask(const std::shared_ptr &task) { std::unique_lock lock(task_mutex_); + ready_tasks_.push(task); + task_cond_var_.notify_all(); + sync_cond_var_.wait(lock); + MsException::GetInstance().CheckException(); +} + +GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { auto task = std::make_shared(); task->session_ = session; task->nodes_ = lst; task->output_nodes_ = outputs; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); return task->graph_id_; } GraphId Executor::CompileGraph(const SessionPtr &session, NotNull func_graph) { - CheckException(); - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->session_ = session; task->func_graph_ = func_graph; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); return task->graph_id_; } void 
Executor::BuildGraph(const SessionPtr &session, GraphId graphId) { - CheckException(); - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->session_ = session; task->graph_id_ = graphId; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); } void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - CheckException(); MS_EXCEPTION_IF_NULL(session); MS_EXCEPTION_IF_NULL(outputs); auto task = std::make_shared(); @@ -251,30 +237,25 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); task->outputs_ = *outputs; task->sync_run_ = true; - std::unique_lock lock(task_mutex_); - ready_tasks_.push(task); - task_cond_var_.notify_all(); mindspore::ScopedLongRunning long_running; - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); } void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - CheckException(); MS_EXCEPTION_IF_NULL(session); MS_EXCEPTION_IF_NULL(outputs); auto task = std::make_shared(); task->session_ = session; task->graph_id_ = graph_id; task->input_tensors_ = inputs; + task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs); // lock inputs for (auto &tensor : inputs) { if (tensor->NeedWait()) { task->input_need_wait_tensors_.emplace_back(tensor); } } - task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs); for (auto &tensor : task->input_need_lock_tensors_) { tensor->SetNeedWait(true); } @@ -285,12 +266,8 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, // sync run graph without output tensor(int dataset graph) if (!TensorInVector(outputs)) { task->sync_run_ = true; - std::unique_lock lock(task_mutex_); - 
ready_tasks_.push(task); - task_cond_var_.notify_all(); mindspore::ScopedLongRunning long_running; - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); return; } @@ -307,54 +284,38 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) { - CheckException(); - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->session_ = session; task->op_run_info_ = op_run_info; task->graph_info_ = graph_info; task->input_tensors_ = input_tensors; task->tensors_mask_ = tensors_mask; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); } void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs) { - CheckException(); - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->session_ = session; task->op_run_info_ = op_run_info; task->graph_info_ = graph_info; task->input_tensors_ = input_tensors; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); - CheckException(); + SyncRunTask(task); *outputs = task->outputs_; } bool Executor::CreateCommGroup(const std::string &group_name, std::vector ranks) { - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->group_name_ = group_name; task->ranks_ = ranks; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); + SyncRunTask(task); return task->result_; } bool Executor::DestroyCommGroup(const std::string &group_name) { - std::unique_lock lock(task_mutex_); auto task = std::make_shared(); task->group_name_ = group_name; - ready_tasks_.push(task); - task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); + SyncRunTask(task); return 
task->result_; } diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index b2ee2e7486..abc0f9f684 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -26,7 +26,6 @@ #include #include #include -#include #include "backend/session/session_basic.h" #include "ir/anf.h" #include "ir/tensor.h" @@ -168,6 +167,7 @@ class Executor { bool DestroyCommGroup(const std::string &group_name); private: + void SyncRunTask(const std::shared_ptr &task); void UpdateOutputTensors(VectorRef *outputs, const std::map &tensor_to_node); std::vector> GetNewReadyTasks(); @@ -184,7 +184,6 @@ class Executor { std::queue> ready_tasks_; std::list> pending_tasks_; std::shared_ptr worker_; - std::exception_ptr exception_ptr_{nullptr}; }; } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 5bc577a1f3..67eac35de5 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -954,10 +954,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor } auto tensor_address = tensor->device_address(); if (tensor_address != device_address) { - if (tensor_address != nullptr) { - tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c()); - } + tensor->data_sync(false); return true; } return false; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index f50e68603c..e9ab2628e8 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -251,8 +251,7 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(tensor); if
(tensor_address != nullptr && tensor_address != address) { - tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c()); + tensor->data_sync(false); } if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32 || tensor->data_type() == kNumberTypeInt64) { diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 7f65480e70..dc7c0ec89c 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -29,6 +29,7 @@ #include "utils/log_adapter.h" #include "base/float16.h" #include "utils/shape_utils.h" +#include "utils/ms_exception.h" // brief mindspore namespace. // @@ -88,6 +89,7 @@ struct WaitEvent { return; } cond_var_.wait(lock, [this] { return !need_wait_; }); + MsException::GetInstance().CheckException(); } void set_need_wait(bool need_wait) { diff --git a/mindspore/core/utils/ms_exception.h b/mindspore/core/utils/ms_exception.h new file mode 100644 index 0000000000..e4454878c3 --- /dev/null +++ b/mindspore/core/utils/ms_exception.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
+#define MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
+#include <exception>
+#include "utils/ms_utils.h"
+namespace mindspore {
+// Singleton that holds at most one captured exception so that it can be
+// recorded at one point in the program and rethrown at another.
+//
+// NOTE(review): exception_ptr_ is read and written without any locking inside
+// this class; if SetException() and CheckException() can run on different
+// threads, callers must provide the synchronization themselves -- confirm.
+class MsException {
+ public:
+  // Returns the single shared instance (a function-local static).
+  static MsException &GetInstance() {
+    static MsException instance;
+    return instance;
+  }
+
+  // Records the exception currently being handled. Meant to be called from
+  // inside a catch block; std::current_exception() yields a null pointer when
+  // no exception is in flight. Any previously stored exception is overwritten.
+  void SetException() { exception_ptr_ = std::current_exception(); }
+
+  // Rethrows the stored exception, if there is one. The stored pointer is
+  // cleared before rethrowing, so each recorded exception is raised only once.
+  // Does nothing when no exception is stored.
+  void CheckException() {
+    if (exception_ptr_ != nullptr) {
+      // Copy first: the member must be reset before control leaves via throw.
+      auto exception_ptr = exception_ptr_;
+      exception_ptr_ = nullptr;
+      std::rethrow_exception(exception_ptr);
+    }
+  }
+
+ private:
+  // Construction and destruction are private; the only instance is the one
+  // created in GetInstance().
+  MsException() = default;
+  ~MsException() = default;
+  DISABLE_COPY_AND_ASSIGN(MsException)
+
+  // Pending exception, or nullptr when nothing is waiting to be rethrown.
+  std::exception_ptr exception_ptr_{nullptr};
+};
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_