|
|
|
@ -25,15 +25,11 @@ class DGCMomentumOp : public MomentumOp {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("current_step"), true,
|
|
|
|
|
"current_step should be set.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("nranks"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(nranks) of DGCMomentumOp is not found."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Grad_out"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output(Grad_out) of DGCMomentumOp is not found."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("current_step"), "Input", "current_step",
|
|
|
|
|
"DGCMomentumOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Grad_out"), "Output", "Grad_out",
|
|
|
|
|
"DGCMomentumOp");
|
|
|
|
|
return MomentumOp::InferShape(ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|