|
|
@ -37,11 +37,17 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
|
|
|
|
<< " dst_place: " << dst_place;
|
|
|
|
<< " dst_place: " << dst_place;
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
|
|
|
|
|
|
|
|
// the enforced checkings have been done in GetDeviceContext, so the
|
|
|
|
|
|
|
|
// `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
|
|
|
|
|
|
|
|
// slow, especially when the number of elements is little, for example,
|
|
|
|
|
|
|
|
// the elements of learning rate are one and it's CPU side.
|
|
|
|
|
|
|
|
// One solution is to use a CUDA kernel to complete the copy operation when
|
|
|
|
|
|
|
|
// the transforming is from CPU to GPU and the number of elements is little.
|
|
|
|
|
|
|
|
// But the embarrassment is that this solution this solution makes training
|
|
|
|
|
|
|
|
// slower.
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
TensorCopy(in, dst_place, *dev_ctx, out);
|
|
|
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
if (in.place().which() != dst_place.which()) {
|
|
|
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|