|
|
@ -36,9 +36,11 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
|
|
|
|
VLOG(3) << "DeviceTransform in, src_place " << in.place()
|
|
|
|
VLOG(3) << "DeviceTransform in, src_place " << in.place()
|
|
|
|
<< " dst_place: " << dst_place;
|
|
|
|
<< " dst_place: " << dst_place;
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
dev_ctx->Wait();
|
|
|
|
if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
|
|
|
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|