!5664 reset tensor device_sync_

Merge pull request !5664 from kisnwang/async-run-graph
pull/5664/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2a491b5f82

@ -40,7 +40,7 @@ void UpdateOutputTensors(VectorRef *outputs,
}
if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync();
tensor->set_sync_status(kNoNeedSync);
tensor->set_device_address(nullptr);
}
}
}
@ -112,7 +112,9 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) {
void Executor::CheckException() {
if (exception_ptr_ != nullptr) {
std::rethrow_exception(exception_ptr_);
auto exception_ptr = exception_ptr_;
exception_ptr_ = nullptr;
std::rethrow_exception(exception_ptr);
}
}

@ -550,12 +550,14 @@ std::string Tensor::ToStringRepr() const {
}
void Tensor::data_sync() const {
const_cast<Tensor *>(this)->Wait();
if (device_sync_ != nullptr) {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
Wait();
if (device_sync_ == nullptr) {
return;
}
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
sync_status_ = kNeedSyncHostToDevice;
}
TypeId Tensor::set_data_type(const TypeId data_type) {

@ -79,10 +79,10 @@ using TensorDataPtr = std::shared_ptr<TensorData>;
struct WaitEvent {
bool need_wait_{false};
std::mutex mutex_;
std::condition_variable cond_var_;
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
void Wait() {
void Wait() const {
std::unique_lock<std::mutex> lock(mutex_);
if (!need_wait_) {
return;
@ -285,7 +285,7 @@ class Tensor : public MetaTensor {
return false;
}
void Wait() {
void Wait() const {
if (event_ != nullptr) {
event_->Wait();
}
@ -307,7 +307,7 @@ class Tensor : public MetaTensor {
TensorDataPtr data_{nullptr};
std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr};
TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
};

Loading…
Cancel
Save