|
|
|
@ -363,6 +363,20 @@ class ReshapeGradKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ReshapeDoubleGradKernel {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
|
auto *dd_x = ctx.Input<framework::Tensor>("DDX");
|
|
|
|
|
auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
|
|
|
|
|
|
|
|
|
|
auto out_dims = dd_out->dims();
|
|
|
|
|
|
|
|
|
|
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
|
|
|
|
|
framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out);
|
|
|
|
|
dd_out->Resize(out_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape,
|
|
|
|
|
// the XShape is used to carry the shape and lod of X which will be used in
|
|
|
|
|
// reshape_grad, in this way, the framework can reuse the memory of X
|
|
|
|
@ -409,6 +423,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto *grad_op = new framework::OpDesc();
|
|
|
|
|
grad_op->SetType("reshape2_grad");
|
|
|
|
|
grad_op->SetInput("X", Input("X"));
|
|
|
|
|
grad_op->SetInput("XShape", Output("XShape"));
|
|
|
|
|
grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
@ -418,6 +433,27 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2DoubleGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto *grad_op = new framework::OpDesc();
|
|
|
|
|
grad_op->SetType("reshape2_grad_grad");
|
|
|
|
|
|
|
|
|
|
grad_op->SetInput("X", Input("X"));
|
|
|
|
|
grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
|
|
|
|
|
grad_op->SetInput("DOut", Input(framework::GradVarName("Out")));
|
|
|
|
|
grad_op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
|
|
|
|
|
auto ddx = OutputGrad(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
grad_op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
grad_op->SetAttrMap(Attrs());
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(grad_op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
Reshape2GradOp(const std::string &type,
|
|
|
|
@ -456,10 +492,47 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
Reshape2DoubleGradOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
|
|
|
|
|
"Input(X@GRAD_GRAD) shouldn't be null.");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
|
|
|
|
|
ctx->ShareDim("DOut", "DDOut");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<framework::Tensor>("DDX")->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
|
const std::string &var_name, const Tensor &tensor,
|
|
|
|
|
const framework::OpKernelType &expected_kernel_type) const override {
|
|
|
|
|
if (var_name == "ShapeTensor") {
|
|
|
|
|
return expected_kernel_type;
|
|
|
|
|
}
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
|
|
|
|
|
{framework::GradVarName("Out"),
|
|
|
|
|
framework::GradVarName("X")});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
@ -471,6 +544,7 @@ REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
|
ops::ReshapeOpInplaceInToOut);
|
|
|
|
|
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
|
|
|
|
|
ops::ReshapeGradInplaceInToOut);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
@ -478,11 +552,13 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
|
|
|
|
|
ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut);
|
|
|
|
|
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
|
|
|
|
|
ops::ReshapeGradInplaceInToOut);
|
|
|
|
|
ops::Reshape2DoubleGradMaker, ops::ReshapeGradInplaceInToOut);
|
|
|
|
|
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
|
|
|
|
|
ops::ReshapeDoubleGradInplaceInToOut);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
@ -490,6 +566,11 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, double,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeDoubleGradKernel);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
@ -510,4 +591,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, double,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, plat::float16,
|
|
|
|
|
ops::ReshapeDoubleGradKernel);
|
|
|
|
|
#endif
|
|
|
|
|