|
|
|
@ -447,8 +447,6 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
if (ctx->HasOutput("OutState")) {
|
|
|
|
@ -457,9 +455,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (ctx->HasOutput("OutAccum")) {
|
|
|
|
|
ctx->SetOutputDim("OutAccum", {1});
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->SetOutputDim("OutScale", {1});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -477,8 +473,6 @@ class MovingAverageAbsMaxScaleOpMaker
|
|
|
|
|
AddInput("X", "(Tensor) Input is float data type.");
|
|
|
|
|
AddInput("InAccum", "Last accum.").AsDispensable();
|
|
|
|
|
AddInput("InState", "Last state.").AsDispensable();
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) Output tensor is just equivalent to the input tensor.");
|
|
|
|
|
AddOutput("OutScale", " Current scale");
|
|
|
|
|
AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
|
|
|
|
|
AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
|
|
|
|
|