fix pynative grad error

pull/7941/head
kswang 4 years ago
parent 2ba20dc8e7
commit 5c9f738bfc

@ -988,16 +988,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
}
tensor->set_sync_status(kNoNeedSync);
}

@ -580,7 +580,8 @@ void Tensor::data_sync(bool need_wait) const {
if (device_sync_ == nullptr) {
return;
}
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
auto address = device_sync_;
if (!address->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
sync_status_ = kNeedSyncHostToDevice;

Loading…
Cancel
Save