|
|
|
@ -405,7 +405,9 @@ class ReshapeGradKernel {
|
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
|
|
|
|
|
d_x->mutable_data(ctx.GetPlace(), d_out->type());
|
|
|
|
|
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
|
|
|
|
|
framework::TensorCopy(
|
|
|
|
|
*d_out, ctx.GetPlace(),
|
|
|
|
|
ctx.template device_context<platform::DeviceContext>(), d_x);
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -419,7 +421,9 @@ class ReshapeDoubleGradKernel {
|
|
|
|
|
auto out_dims = dd_out->dims();
|
|
|
|
|
|
|
|
|
|
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
|
|
|
|
|
framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out);
|
|
|
|
|
framework::TensorCopy(
|
|
|
|
|
*dd_x, ctx.GetPlace(),
|
|
|
|
|
ctx.template device_context<platform::DeviceContext>(), dd_out);
|
|
|
|
|
dd_out->Resize(out_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|