refine reshape grad and double grad kernel, use tensor copy async (#29128)

revert-31562-mean
Leo Chen 4 years ago committed by GitHub
parent f7b45fd694
commit 4e19ce1df5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -405,7 +405,9 @@ class ReshapeGradKernel {
auto in_dims = d_x->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(in_dims);
}
};
@ -419,7 +421,9 @@ class ReshapeDoubleGradKernel {
auto out_dims = dd_out->dims();
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out);
framework::TensorCopy(
*dd_x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dd_out);
dd_out->Resize(out_dims);
}
};

Loading…
Cancel
Save