|
|
|
@ -246,6 +246,88 @@ class ReshapeGradKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// FIXME(zcd): reshape2 adds an intermediate output (XShape) on top of reshape.
// XShape carries the shape and LoD of X, which will be used in reshape_grad;
// this way the framework can reuse the memory of X as soon as the reshape_op
// finishes.
// For compatibility reasons we cannot change reshape_op itself.
class Reshape2Op : public ReshapeOp {
|
|
|
|
|
public:
|
|
|
|
|
Reshape2Op(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: ReshapeOp(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
ReshapeOp::InferShape(ctx);
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
|
|
|
|
|
"Output(XShape) of ReshapeOp should not be null.");
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
|
|
|
|
|
xshape_dims[0] = 0;
|
|
|
|
|
for (int i = 0; i < x_dims.size(); ++i) {
|
|
|
|
|
xshape_dims[i + 1] = x_dims[i];
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "XShape");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2OpMaker : public ReshapeOpMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
ReshapeOpMaker::Make();
|
|
|
|
|
AddOutput("XShape",
|
|
|
|
|
"XShape is just used to store the shape and lod of X, which will "
|
|
|
|
|
"be used in FlattenGradOp.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto *grad_op = new framework::OpDesc();
|
|
|
|
|
grad_op->SetType("reshape2_grad");
|
|
|
|
|
grad_op->SetInput("XShape", Output("XShape"));
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
grad_op->SetAttrMap(Attrs());
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(grad_op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Reshape2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
Reshape2GradOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
auto xshape_dims = ctx->GetInputDim("XShape");
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
ctx->ShareLoD("XShape", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(
|
|
|
|
|
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
|
|
|
|
|
->type()),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
|
|
|
|
|
// Register reshape2 with its maker and grad-desc maker, plus the grad op.
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                  ops::Reshape2GradMaker);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
// CPU kernels: the type-agnostic Reshape(Grad)Kernel is reused for every
// registered data type (float, double, int, int64_t).
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                               ops::ReshapeKernel, int, ops::ReshapeKernel,
                               int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                               double, ops::ReshapeGradKernel, int,
                               ops::ReshapeGradKernel, int64_t,
                               ops::ReshapeGradKernel);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
// CUDA kernels for reshape2/reshape2_grad: same type-agnostic kernels as the
// CPU registrations above, instantiated for float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
                                int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel);
|
|
|
|
|
#endif
|
|
|
|
|