|
|
@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
"Input(X) of ReduceOp should not be null.");
|
|
|
|
"Input(X) of ReduceOp should not be null.");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
"Output(Out) of ReduceOp should not be null.");
|
|
|
|
"Output(Out) of ReduceOp should not be null.");
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
|
|
|
|
int dim = ctx.Attr<int>("dim");
|
|
|
|
int dim = ctx->Attrs().Get<int>("dim");
|
|
|
|
if (dim < 0) dim = x_rank + dim;
|
|
|
|
if (dim < 0) dim = x_rank + dim;
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
dim, x_rank,
|
|
|
|
dim, x_rank,
|
|
|
|
"The dim should be in the range [-rank(input), rank(input)).");
|
|
|
|
"The dim should be in the range [-rank(input), rank(input)).");
|
|
|
|
bool keep_dim = ctx.Attr<bool>("keep_dim");
|
|
|
|
bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
|
|
|
|
auto dims_vector = vectorize(x_dims);
|
|
|
|
auto dims_vector = vectorize(x_dims);
|
|
|
|
if (keep_dim || x_rank == 1) {
|
|
|
|
if (keep_dim || x_rank == 1) {
|
|
|
|
dims_vector[dim] = 1;
|
|
|
|
dims_vector[dim] = 1;
|
|
|
@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel {
|
|
|
|
dims_vector.erase(dims_vector.begin() + dim);
|
|
|
|
dims_vector.erase(dims_vector.begin() + dim);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto out_dims = framework::make_ddim(dims_vector);
|
|
|
|
auto out_dims = framework::make_ddim(dims_vector);
|
|
|
|
ctx.Output<framework::Tensor>("Out")->Resize(out_dims);
|
|
|
|
ctx->SetOutputDim("Out", out_dims);
|
|
|
|
if (dim != 0) {
|
|
|
|
if (dim != 0) {
|
|
|
|
// Only pass LoD when not reducing on the first dim
|
|
|
|
// Only pass LoD when not reducing on the first dim.
|
|
|
|
ctx.ShareLoD("X", /*->*/ "Out");
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
auto x_rank = x_dims.size();
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
|
|
|
|
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
|
|
|
|
int dim = ctx.Attr<int>("dim");
|
|
|
|
int dim = ctx->Attrs().Get<int>("dim");
|
|
|
|
if (dim < 0) dim = x_rank + dim;
|
|
|
|
if (dim < 0) dim = x_rank + dim;
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
dim, x_rank,
|
|
|
|
dim, x_rank,
|
|
|
|
"The dim should be in the range [-rank(input), rank(input)).");
|
|
|
|
"The dim should be in the range [-rank(input), rank(input)).");
|
|
|
|
auto *x_grad =
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
if (x_grad) x_grad->Resize(x_dims);
|
|
|
|
ctx->SetOutputDim(x_grad_name, x_dims);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|