|
|
@ -119,13 +119,15 @@ class ReshapeKernel : public framework::OpKernel<T> {
|
|
|
|
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
|
|
|
|
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
|
|
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims = out->dims();
|
|
|
|
framework::DDim out_dims = out->dims();
|
|
|
|
|
|
|
|
|
|
|
|
if (shape_tensor) {
|
|
|
|
if (shape_tensor) {
|
|
|
|
auto *shape_data = shape_tensor->data<int>();
|
|
|
|
auto *shape_data = shape_tensor->data<int>();
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|
|
|
|
framework::Tensor cpu_shape_tensor;
|
|
|
|
framework::Tensor cpu_shape_tensor;
|
|
|
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
|
|
|
|
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
|
|
|
|
&cpu_shape_tensor);
|
|
|
|
&cpu_shape_tensor);
|
|
|
|
shape_data = cpu_shape_tensor.data<int>();
|
|
|
|
shape_data = cpu_shape_tensor.data<int>();
|
|
|
|
|
|
|
|
ctx.device_context().Wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto shape =
|
|
|
|
auto shape =
|
|
|
|
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
|
|
|
|
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
|
|
|
|