!7941 fix pynative grad error

From: @kisnwang
Reviewed-by: @jjfeing,@chujinjin
Signed-off-by: @jjfeing
pull/7941/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit beb86391fe

@ -999,16 +999,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) { tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
} }
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
} }
tensor->set_sync_status(kNoNeedSync); tensor->set_sync_status(kNoNeedSync);
} }

@ -581,7 +581,8 @@ void Tensor::data_sync(bool need_wait) const {
if (device_sync_ == nullptr) { if (device_sync_ == nullptr) {
return; return;
} }
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { auto address = device_sync_;
if (!address->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
} }
sync_status_ = kNeedSyncHostToDevice; sync_status_ = kNeedSyncHostToDevice;

Loading…
Cancel
Save