|
|
|
@ -107,7 +107,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
|
|
|
|
|
void ReshapeKernel::operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
|
auto *out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
|
auto *in = ctx.Input<framework::LoDTensor>("X");
|
|
|
|
|
|
|
|
|
@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void ReshapeGradKernelBase::Compute(
|
|
|
|
|
void ReshapeGradKernel::operator()(
|
|
|
|
|
const framework::ExecutionContext &ctx) const {
|
|
|
|
|
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
@ -172,10 +172,10 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
|
|
|
|
|
ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
|
|
|
|
|
ops::ReshapeGradKernel<double>,
|
|
|
|
|
ops::ReshapeGradKernel<int>,
|
|
|
|
|
ops::ReshapeGradKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|