|
|
|
|
@ -88,10 +88,9 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeDequantizeMaxAbs");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeDequantizeMaxAbs");
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
@ -125,15 +124,12 @@ class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInputs("Scales"),
|
|
|
|
|
"Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"FakeChannelWiseDequantizeMaxAbs");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs("Scales"), "Input", "Scales",
|
|
|
|
|
"FakeChannelWiseDequantizeMaxAbs");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeChannelWiseDequantizeMaxAbs");
|
|
|
|
|
|
|
|
|
|
ctx->ShareDim("X", /*->*/ "Out");
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
|