|
|
|
@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const {
|
|
|
|
|
auto* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
|
auto* in = ctx.Input<framework::Tensor>("X");
|
|
|
|
|
bool inplace = ctx.Attr<bool>("inplace");
|
|
|
|
|
auto out_dims = out->dims();
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
} else {
|
|
|
|
|
out->ShareDataWith(*in);
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
bool inplace = ctx.Attr<bool>("inplace");
|
|
|
|
|
|
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
} else {
|
|
|
|
|
d_x->ShareDataWith(*d_out);
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|