|
|
|
|
@ -56,9 +56,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of ReshapeOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) of ReshapeOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of ReshapeOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(Out) of ReshapeOp should not be null."));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInputs("ShapeTensor")) {
|
|
|
|
|
// top prority shape
|
|
|
|
|
@ -304,9 +306,12 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) shouldn't be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null."));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -403,7 +408,8 @@ class Reshape2Op : public ReshapeOp {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
|
|
|
|
|
"Output(XShape) of ReshapeOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(XShape) of ReshapeOp should not be null."));
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
|
|
|
|
|
xshape_dims[0] = 0;
|
|
|
|
|
@ -472,10 +478,12 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("XShape"), true,
|
|
|
|
|
"Input(XShape) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("XShape"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null."));
|
|
|
|
|
auto xshape_dims = ctx->GetInputDim("XShape");
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
@ -511,8 +519,8 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
|
|
|
|
|
"Input(X@GRAD_GRAD) shouldn't be null.");
|
|
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X@GRAD_GRAD) shouldn't be null."));
|
|
|
|
|
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
|
|
|
|
|
ctx->ShareDim("DOut", "DDOut");
|
|
|
|
|
}
|
|
|
|
|
|