|
|
|
@ -999,16 +999,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
|
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
|
|
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
|
|
|
|
tensor->set_device_address(device_address);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
|
|
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
|
|
|
|
tensor->data_c())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
|
|
|
|
}
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
|
|
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
|
|
|
|
tensor->set_device_address(device_address);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync);
|
|
|
|
|
}
|
|
|
|
|