|
|
|
@ -27,29 +27,19 @@ class SqueezeOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of Squeeze operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of Squeeze operator should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze");
|
|
|
|
|
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// Check input tensor dims (<6) Eigen limit.
|
|
|
|
|
PADDLE_ENFORCE_LE(x_dims.size(), 6,
|
|
|
|
|
"ShapeError: the dimensions of Input(X) "
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimensions of Input(X) "
|
|
|
|
|
"should be in the range of [1, 6] (Eigen limit)."
|
|
|
|
|
"But received X's dimensions = %d, X's shape=[%s].",
|
|
|
|
|
x_dims.size(), x_dims);
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
for (int a : axes) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
a, x_dims.size(),
|
|
|
|
|
"ShapeError: The squeeze axis should be less than input "
|
|
|
|
|
"tensor's dimensions. But received axis = %d, input "
|
|
|
|
|
"tensor's dimensions = %d, input tensor's shape = [%s].",
|
|
|
|
|
a, x_dims.size(), x_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_dims = GetOutputShape(axes, x_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
|
if (x_dims[0] == out_dims[0]) {
|
|
|
|
@ -78,10 +68,18 @@ class SqueezeOp : public framework::OperatorWithKernel {
|
|
|
|
|
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
|
|
|
|
|
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
|
|
|
|
|
: squeeze_dims[idx];
|
|
|
|
|
PADDLE_ENFORCE_GE(current, 0,
|
|
|
|
|
"Invalid axis, the axis should >= 0."
|
|
|
|
|
"Current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
current, in_dims);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
current, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
current, in_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each axis in Attr(axes) should be in the range of [%d, %d]"
|
|
|
|
|
"But current axis is:%d, input tensor's shape = [%s].",
|
|
|
|
|
-in_dims.size(), in_dims.size() - 1, current, in_dims));
|
|
|
|
|
|
|
|
|
|
if (!(should_squeeze[current])) {
|
|
|
|
|
++cnt_squeezed_dims;
|
|
|
|
@ -171,28 +169,19 @@ class Squeeze2Op : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of Squeeze operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of Squeeze operator should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze2");
|
|
|
|
|
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
|
// Check input tensor dims (<6) Eigen limit.
|
|
|
|
|
PADDLE_ENFORCE_LE(x_dims.size(), 6,
|
|
|
|
|
"ShapeError: the dimensions of Input(X) "
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimensions of Input(X) "
|
|
|
|
|
"should be in the range of [1, 6] (Eigen limit)."
|
|
|
|
|
"But received X's dimensions = %d, X's shape = [%s].",
|
|
|
|
|
x_dims.size(), x_dims);
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
|
|
|
|
|
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
|
|
|
|
|
for (int a : axes) {
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
a, x_dims.size(),
|
|
|
|
|
"ShapeError: The squeeze axis should be less than input "
|
|
|
|
|
"tensor's dimensions. But received axis = %d, input "
|
|
|
|
|
"tensor's dimensions = %d, input tensor's shape = [%s].",
|
|
|
|
|
a, x_dims.size(), x_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
@ -202,8 +191,8 @@ class Squeeze2Op : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
|
|
|
|
|
"Output(XShape) of Squeeze operator should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Squeeze2");
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
|
|
|
|
|
xshape_dims[0] = 0;
|
|
|
|
|
for (int i = 0; i < x_dims.size(); ++i) {
|
|
|
|
@ -233,10 +222,10 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
|
|
|
|
|
"Input(XShape) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape",
|
|
|
|
|
"Squeeze2Grad");
|
|
|
|
|
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "Squeeze2Grad");
|
|
|
|
|
auto xshape_dims = context->GetInputDim("XShape");
|
|
|
|
|
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
|
|
|
|
|
context->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|