Merge pull request #9892 from JiayiFeng/refine_reshape_op

Add Wait() for reshape_op
wangkuiyi-patch-2
fengjiayi 7 years ago committed by GitHub
commit 51c219c9cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -147,6 +147,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
ctx.device_context().Wait();
// TensorCopy will resize to in_dims.
out->Resize(out_dims);
} else {
@ -169,6 +170,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
ctx.device_context().Wait();
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);

Loading…
Cancel
Save