|
|
|
@ -180,12 +180,11 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeQuantizeOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeQuantizeOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
|
|
|
|
|
"Output(Scale) of FakeQuantizeOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeQuantizeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"FakeQuantizeAbsMax");
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->SetOutputDim("OutScale", {1});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
@ -211,8 +210,11 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddAttr<int>("bit_length", "(int, default 8)")
|
|
|
|
|
.SetDefault(8)
|
|
|
|
|
.AddCustomChecker([](const int& bit_length) {
|
|
|
|
|
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
|
|
|
|
|
"'bit_length' should be between 1 and 16.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'bit_length' should be between 1 and 16, but "
|
|
|
|
|
"the received is %d",
|
|
|
|
|
bit_length));
|
|
|
|
|
});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
FakeQuantize operator
|
|
|
|
@ -230,14 +232,12 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeChannelWiseQuantizeOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("OutScale"),
|
|
|
|
|
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"FakeChannelWiseQuantizeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeChannelWiseQuantizeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"FakeChannelWiseQuantizeAbsMax");
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
@ -263,8 +263,11 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
|
|
|
|
|
AddAttr<int>("bit_length", "(int, default 8)")
|
|
|
|
|
.SetDefault(8)
|
|
|
|
|
.AddCustomChecker([](const int& bit_length) {
|
|
|
|
|
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
|
|
|
|
|
"'bit_length' should be between 1 and 16.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'bit_length' should be between 1 and 16, but "
|
|
|
|
|
"the received is %d",
|
|
|
|
|
bit_length));
|
|
|
|
|
});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
The scale of FakeChannelWiseQuantize operator is a vector.
|
|
|
|
@ -288,14 +291,11 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("OutScale"),
|
|
|
|
|
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeQuantizeRangeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"FakeQuantizeRangeAbsMax");
|
|
|
|
|
if (ctx->HasOutput("OutScales")) {
|
|
|
|
|
int window_size = ctx->Attrs().Get<int>("window_size");
|
|
|
|
|
ctx->SetOutputDim("OutScales", {window_size});
|
|
|
|
@ -329,8 +329,11 @@ class FakeQuantizeRangeAbsMaxOpMaker
|
|
|
|
|
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
|
|
|
|
|
.SetDefault(8)
|
|
|
|
|
.AddCustomChecker([](const int& bit_length) {
|
|
|
|
|
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
|
|
|
|
|
"'bit_length' should be between 1 and 16.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'bit_length' should be between 1 and 16, but "
|
|
|
|
|
"the received is %d",
|
|
|
|
|
bit_length));
|
|
|
|
|
});
|
|
|
|
|
AddAttr<bool>("is_test",
|
|
|
|
|
"(bool, default false) Set to true for inference only, false "
|
|
|
|
@ -357,16 +360,12 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("OutScale"),
|
|
|
|
|
"Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
|
|
|
|
|
"should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"FakeQuantOrWithDequantMovingAverageAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeQuantOrWithDequantMovingAverageAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"FakeQuantOrWithDequantMovingAverageAbsMax");
|
|
|
|
|
if (ctx->HasOutput("OutState")) {
|
|
|
|
|
ctx->SetOutputDim("OutState", {1});
|
|
|
|
|
}
|
|
|
|
@ -404,8 +403,11 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
|
|
|
|
|
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
|
|
|
|
|
.SetDefault(8)
|
|
|
|
|
.AddCustomChecker([](const int& bit_length) {
|
|
|
|
|
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
|
|
|
|
|
"'bit_length' should be between 1 and 16.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'bit_length' should be between 1 and 16, but "
|
|
|
|
|
"the received is %d",
|
|
|
|
|
bit_length));
|
|
|
|
|
});
|
|
|
|
|
AddAttr<bool>("is_test",
|
|
|
|
|
"(bool, default false) Set to true for inference only, false "
|
|
|
|
@ -434,15 +436,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of MovingAverageAbsMaxScaleOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of MovingAverageAbsMaxScaleOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
|
|
|
|
|
"Output(OutScale) of MovingAverageAbsMaxScaleOp"
|
|
|
|
|
"should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"MovingAverageAbsMaxScale");
|
|
|
|
|
if (ctx->HasOutput("OutState")) {
|
|
|
|
|
ctx->SetOutputDim("OutState", {1});
|
|
|
|
|
}
|
|
|
|
|