|
|
@ -38,7 +38,8 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
|
|
|
|
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
|
|
|
|
|
|
|
|
|
|
|
|
if (in.place().which() != dst_place.which()) {
|
|
|
|
dev_ctx->Wait();
|
|
|
|
dev_ctx->Wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|