Merge pull request #9840 from reyoung/feature/polish_reshape_op

Polish reshape op
wangkuiyi-patch-2
Yu Yang 8 years ago committed by GitHub
commit 06ddaa73f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -119,13 +119,15 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
framework::DDim out_dims = out->dims();
if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
&cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
ctx.device_context().Wait();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());

Loading…
Cancel
Save