|
|
|
@ -101,12 +101,22 @@ void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind)
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
|
|
|
|
auto execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
|
|
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance);
|
|
|
|
|
runtime_instance->SetContext();
|
|
|
|
|
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind);
|
|
|
|
|
if (ret_rt_memcpy != RT_ERROR_NONE) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed";
|
|
|
|
|
|
|
|
|
|
// Only apply asynchronous copy in Pynative && RT_MEMCPY_HOST_TO_DEVICE mode
|
|
|
|
|
if (execution_mode != kPynativeMode || kind != RT_MEMCPY_HOST_TO_DEVICE) {
|
|
|
|
|
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind);
|
|
|
|
|
if (ret_rt_memcpy != RT_ERROR_NONE) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ret = runtime_instance->MemcpyAsync(dst, src, size, static_cast<int32_t>(kind));
|
|
|
|
|
if (!ret) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "MemcpyAsync failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -527,7 +537,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
|
|
|
|
|
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
SyncStream();
|
|
|
|
|
|
|
|
|
|
bool sync_ok = false;
|
|
|
|
|
std::vector<size_t> host_shape;
|
|
|
|
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), LongToSize);
|
|
|
|
|