Polish reshape op

wangkuiyi-patch-2
Yu Yang 7 years ago
parent 129859e732
commit 52987902c9

@ -119,13 +119,15 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape"); auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
framework::DDim out_dims = out->dims(); framework::DDim out_dims = out->dims();
if (shape_tensor) { if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>(); auto *shape_data = shape_tensor->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor cpu_shape_tensor; framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(), TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
&cpu_shape_tensor); &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>(); shape_data = cpu_shape_tensor.data<int>();
ctx.device_context().Wait();
} }
auto shape = auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel()); std::vector<int>(shape_data, shape_data + shape_tensor->numel());

Loading…
Cancel
Save