@@ -25,6 +25,10 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
       in.place().which(), dst_place.which(),
       "Currently, model parallelism is only supported between CPU and CUDA");
 
   // NOTE(yy): TransDataDevice should wait for computation of input.
   platform::DeviceContextPool::Instance().Get(in.place())->Wait();
   platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+
+  // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
+  // the enforced checkings have been done in GetDeviceContext, so the
+  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
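
For context on why the two Wait() calls in the hunk matter: the data behind `in` may still be in flight on the source device's stream when the copy is issued, so TransDataDevice blocks on both device contexts before copying. The standalone CUDA sketch below shows the same synchronize-then-copy pattern; it is not PaddlePaddle code, and the `fill` kernel and `compute_stream` names are illustrative assumptions.

// sync_then_copy.cu -- illustrative only; names are assumptions, not Paddle APIs.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Stand-in for whatever computation produces the source tensor on the GPU.
__global__ void fill(float *dst, int n, float v) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = v;
}

int main() {
  const int n = 1 << 20;
  float *d_data = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_data), n * sizeof(float));

  // The producer runs asynchronously on its own, non-blocking stream.
  cudaStream_t compute_stream;
  cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
  fill<<<(n + 255) / 256, 256, 0, compute_stream>>>(d_data, n, 1.0f);

  // Counterpart of the Wait() calls in the diff: block until the producer
  // finishes, so the copy below cannot read partially written data.
  cudaStreamSynchronize(compute_stream);

  // Synchronous copy to the destination "place" (host memory here).
  std::vector<float> h_data(n);
  cudaMemcpy(h_data.data(), d_data, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("h_data[0] = %f\n", h_data[0]);

  cudaStreamDestroy(compute_stream);
  cudaFree(d_data);
  return 0;
}

Dropping the synchronization in this sketch would let the device-to-host copy race with the producing kernel, which is the failure mode the NOTE(yy) comment guards against.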