diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index ca9041e133..bff938496e 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -40,7 +40,7 @@ void UpdateOutputTensors(VectorRef *outputs, } if (tensor->NeedSyncDeviceToHostImmediately()) { tensor->data_sync(); - tensor->set_sync_status(kNoNeedSync); + tensor->set_device_address(nullptr); } } } @@ -112,7 +112,9 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) { void Executor::CheckException() { if (exception_ptr_ != nullptr) { - std::rethrow_exception(exception_ptr_); + auto exception_ptr = exception_ptr_; + exception_ptr_ = nullptr; + std::rethrow_exception(exception_ptr); } } diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 7ea71f6868..c4c380c193 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -550,12 +550,14 @@ std::string Tensor::ToStringRepr() const { } void Tensor::data_sync() const { - const_cast(this)->Wait(); - if (device_sync_ != nullptr) { - if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { - MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; - } + Wait(); + if (device_sync_ == nullptr) { + return; + } + if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { + MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; } + sync_status_ = kNeedSyncHostToDevice; } TypeId Tensor::set_data_type(const TypeId data_type) { diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 3e20d78cad..1c4631b7f2 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -79,10 +79,10 @@ using TensorDataPtr = std::shared_ptr; struct WaitEvent { bool need_wait_{false}; - std::mutex mutex_; - std::condition_variable cond_var_; + mutable std::mutex mutex_; + mutable std::condition_variable cond_var_; - void Wait() { + void Wait() const { std::unique_lock lock(mutex_); if (!need_wait_) { return; @@ -285,7 +285,7 @@ class Tensor : public MetaTensor { return false; } - void Wait() { + void Wait() const { if (event_ != nullptr) { event_->Wait(); } @@ -307,7 +307,7 @@ class Tensor : public MetaTensor { TensorDataPtr data_{nullptr}; std::string id_{""}; std::shared_ptr event_{nullptr}; - TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; + mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; DeviceSyncPtr device_sync_{nullptr}; std::vector padding_type_; };