|
|
|
@ -29,17 +29,17 @@ class FlattenOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input (X) of Flatten op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output (Output) of Flatten op should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Flatten");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Flatten");
|
|
|
|
|
const auto &axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
const auto &in_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0,
|
|
|
|
|
"The axis should be greater than or equal to 0.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The axis should be greater than or equal to 0."));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
axis, in_dims.size(),
|
|
|
|
|
"The axis should be less than or equal to input tensor's rank.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The axis should be less than or equal to input tensor's rank."));
|
|
|
|
|
|
|
|
|
|
const auto &out_dims = GetOutputShape(axis, in_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
|
|
|
|
@ -161,17 +161,17 @@ class Flatten2Op : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input (X) of Flatten op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output (Output) of Flatten op should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Flatten2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Flatten2");
|
|
|
|
|
const auto &axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
|
const auto &in_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0,
|
|
|
|
|
"The axis should be greater than or equal to 0.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The axis should be greater than or equal to 0."));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
axis, in_dims.size(),
|
|
|
|
|
"The axis should be less than or equal to input tensor's rank.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The axis should be less than or equal to input tensor's rank"));
|
|
|
|
|
|
|
|
|
|
const auto &out_dims = FlattenOp::GetOutputShape(axis, in_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
|
|
|
|
@ -181,8 +181,7 @@ class Flatten2Op : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
|
|
|
|
|
"Output (XShape) of Flatten op should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
|
|
|
|
|
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
|
|
|
|
|
xshape_dims[0] = 0;
|
|
|
|
|
for (int i = 0; i < in_dims.size(); ++i) {
|
|
|
|
@ -223,10 +222,10 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
|
|
|
|
|
"Input(XShape) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape",
|
|
|
|
|
"Flatten2Grad");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "Flatten2Grad");
|
|
|
|
|
auto xshape_dims = context->GetInputDim("XShape");
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
context->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|