|
|
|
@ -48,48 +48,33 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("param"),
|
|
|
|
|
"Input (param) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("in_sum_1"),
|
|
|
|
|
"Input (sum_1) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("in_sum_2"),
|
|
|
|
|
"Input (sum_2) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("in_sum_3"),
|
|
|
|
|
"Input (sum_3) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("in_num_accumulates"),
|
|
|
|
|
"Input (in_num_accumulates) of average_accumulates op should "
|
|
|
|
|
"not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"),
|
|
|
|
|
"Input (old_num_accumulates) of average_accumulates op "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("in_num_updates"),
|
|
|
|
|
"Input (num_updates) of average_accumulates op should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("out_sum_1"),
|
|
|
|
|
"Output (sum_1) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("out_sum_2"),
|
|
|
|
|
"Output (sum_2) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("out_sum_3"),
|
|
|
|
|
"Output (sum_3) of average_accumulates op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"),
|
|
|
|
|
"Output (num_accumulates) of average_accumulates op should "
|
|
|
|
|
"not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"),
|
|
|
|
|
"Output (old_num_accumulates) of average_accumulates op "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("out_num_updates"),
|
|
|
|
|
"Output (num_updates) of average_accumulates op should not be null.");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("param"), "Input", "param",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_sum_1"), "Input", "in_sum_1",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_sum_2"), "Input", "in_sum_2",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_sum_3"), "Input", "in_sum_3",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_num_accumulates"), "Input",
|
|
|
|
|
"in_num_accumulates", "AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_old_num_accumulates"), "Input",
|
|
|
|
|
"in_old_num_accumulates", "AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("in_num_updates"), "Input", "in_num_updates",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_sum_1"), "Output", "out_sum_1",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_sum_2"), "Output", "out_sum_2",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_sum_3"), "Output", "out_sum_3",
|
|
|
|
|
"AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_num_accumulates"), "Output",
|
|
|
|
|
"out_num_accumulates", "AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_old_num_accumulates"), "Output",
|
|
|
|
|
"out_old_num_accumulates", "AverageAccumulates");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("out_num_updates"), "Output",
|
|
|
|
|
"out_num_updates", "AverageAccumulates");
|
|
|
|
|
auto in_dim = ctx->GetInputDim("param");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("out_sum_1", in_dim);
|
|
|
|
|