|
|
|
@ -147,6 +147,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
|
|
|
|
|
ctx.device_context().Wait();
|
|
|
|
|
// TensorCopy will resize to in_dims.
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
} else {
|
|
|
|
@ -169,6 +170,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
|
|
|
|
|
ctx.device_context().Wait();
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
} else {
|
|
|
|
|
d_x->ShareDataWith(*d_out);
|
|
|
|
|