|
|
|
@ -264,31 +264,31 @@ class ReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of ReduceOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of ReduceOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6,
|
|
|
|
|
"ShapeError: The input tensor X's dimensions of Reduce "
|
|
|
|
|
"should be less equal than 6. But received X's "
|
|
|
|
|
"dimensions = %d, X's shape = [%s].",
|
|
|
|
|
x_rank, x_dims);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input tensor X's dimensions of ReduceOp "
|
|
|
|
|
"should be less equal than 6. But received X's "
|
|
|
|
|
"dimensions = %d, X's shape = [%s].",
|
|
|
|
|
x_rank, x_dims));
|
|
|
|
|
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
dims.size(), 0,
|
|
|
|
|
"ShapeError: The input dim dimensions of Reduce "
|
|
|
|
|
"should be greater than 0. But received the dim dimesions of Reduce "
|
|
|
|
|
" = %d",
|
|
|
|
|
dims.size());
|
|
|
|
|
PADDLE_ENFORCE_GT(dims.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input dim dimensions of ReduceOp "
|
|
|
|
|
"should be greater than 0. But received the dim "
|
|
|
|
|
"dimesions of Reduce = %d.",
|
|
|
|
|
dims.size()));
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < dims.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(dims[i], x_rank,
|
|
|
|
|
"ShapeError: The reduce dim index %d should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)]."
|
|
|
|
|
"which dimesion = %d, But received dim index = %d",
|
|
|
|
|
i, x_rank, dims[i]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The reduce dim index %d should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)] "
|
|
|
|
|
"which dimesion = %d. But received dim index = %d.",
|
|
|
|
|
i, x_rank, dims[i]));
|
|
|
|
|
if (dims[i] < 0) dims[i] = x_rank + dims[i];
|
|
|
|
|
}
|
|
|
|
|
sort(dims.begin(), dims.end());
|
|
|
|
@ -346,19 +346,24 @@ class ReduceGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "ReduceOp");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
|
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Tensors with rank at most 6 are supported by "
|
|
|
|
|
"ReduceOp. Received tensor with rank %d.",
|
|
|
|
|
x_rank));
|
|
|
|
|
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
|
|
|
|
|
for (size_t i = 0; i < dims.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(dims[i], x_rank,
|
|
|
|
|
"ShapeError: The reduce dim index %d should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)]."
|
|
|
|
|
"which dimesion = %d, But received dim index = %d",
|
|
|
|
|
i, x_rank, dims[i]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The reduce dim index %d should be in the "
|
|
|
|
|
"range [-dimension(X), dimension(X)], "
|
|
|
|
|
"which dimesion = %d. But received dim index = %d.",
|
|
|
|
|
i, x_rank, dims[i]));
|
|
|
|
|
if (dims[i] < 0) dims[i] = x_rank + dims[i];
|
|
|
|
|
}
|
|
|
|
|
sort(dims.begin(), dims.end());
|
|
|
|
|