|
|
|
|
@ -82,7 +82,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
|
|
|
|
|
auto out_e = framework::EigenVector<T>::Flatten(*out);
|
|
|
|
|
out_e.device(*ctx.eigen_device()) =
|
|
|
|
|
(s / bin_cnt) * (bin_cnt * inv_s * out_e).round();
|
|
|
|
|
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
|
|
|
|
|
@ -171,20 +171,21 @@ struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
|
|
|
|
|
float>;
|
|
|
|
|
|
|
|
|
|
class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
FakeQuantizeAbsMaxOp(const std::string& type,
|
|
|
|
|
const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
|
|
|
|
|
const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
|
|
|
|
|
"FakeQuantOrWithDequantAbsMaxOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FakeQuantizeAbsMax");
|
|
|
|
|
"FakeQuantOrWithDequantAbsMaxOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
|
|
|
|
|
"FakeQuantizeAbsMax");
|
|
|
|
|
"FakeQuantOrWithDequantAbsMaxOp");
|
|
|
|
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
|
|
|
|
ctx->SetOutputDim("OutScale", {1});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
@ -199,7 +200,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class FakeQuantOrWithDequantAbsMaxOpMaker
|
|
|
|
|
: public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X", "(Tensor) Input is float data type.");
|
|
|
|
|
@ -217,12 +219,19 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
bit_length));
|
|
|
|
|
});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
FakeQuantize operator
|
|
|
|
|
This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
|
|
|
|
|
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
|
|
|
|
|
|
|
|
|
|
$$scale = max(abs(X))$$
|
|
|
|
|
$$range = 2^{bit_length - 1} - 1$$
|
|
|
|
|
$$Out = round(X/scale * range)$$
|
|
|
|
|
|
|
|
|
|
FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant.
|
|
|
|
|
|
|
|
|
|
$$scale = max(abs(X))$$
|
|
|
|
|
$$range = 2^{bit\_length - 1} - 1$$
|
|
|
|
|
$$Out = round(X/scale * range) * scale / range$$
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@ -414,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
|
|
|
|
|
"for training. Some layers may run faster when this is true.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp
|
|
|
|
|
FakeQuantMovingAverageAbsMaxOp operator is used in static quantization.
|
|
|
|
|
This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
|
|
|
|
|
FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization.
|
|
|
|
|
|
|
|
|
|
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
|
|
|
|
|
$$range = 2^{bit\_length - 1} - 1$$
|
|
|
|
|
$$Out = round(X/scale * range)$$
|
|
|
|
|
|
|
|
|
|
FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max op quant and then dequant.
|
|
|
|
|
FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max quant and then dequant.
|
|
|
|
|
|
|
|
|
|
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
|
|
|
|
|
$$range = 2^{bit\_length - 1} - 1$$
|
|
|
|
|
@ -490,6 +499,46 @@ $$Out = X$$
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
auto out_grad_name = framework::GradVarName("Out");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
|
|
|
|
|
"FakeQuantDequantGradOp");
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput(x_grad_name), true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"FakeQuantDequantGradOp doesn't have the output named %s.",
|
|
|
|
|
x_grad_name));
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
|
|
|
|
|
ctx, framework::GradVarName("Out"));
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void Apply(GradOpPtr<T> grad_op) const override {
|
|
|
|
|
grad_op->SetType("fake_quantize_dequantize_grad");
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
|
|
|
|
grad_op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
@ -497,13 +546,21 @@ namespace ops = paddle::operators;
|
|
|
|
|
using CPU = paddle::platform::CPUDeviceContext;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
|
|
|
|
|
ops::FakeQuantizeAbsMaxOpMaker,
|
|
|
|
|
fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
|
|
|
|
|
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
|
|
|
|
|
ops::FakeQuantizeAbsMaxKernel<CPU, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
|
|
|
|
|
ops::FakeQuantOrWithDequantAbsMaxOp,
|
|
|
|
|
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
|
|
|
|
|
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
|
|
|
|
|
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
|
|
|
|
|
ops::FakeQuantizeRangeAbsMaxOpMaker,
|
|
|
|
|
@ -518,16 +575,14 @@ REGISTER_OPERATOR(
|
|
|
|
|
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
|
|
|
|
|
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
fake_quantize_dequantize_moving_average_abs_max,
|
|
|
|
|
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
|
|
|
|
|
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
|
|
|
|
|
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
|
|
|
|
|
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
|
|
|
|
|
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
fake_quantize_dequantize_moving_average_abs_max,
|
|
|
|
|
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
|
|
|
|
|
@ -547,3 +602,7 @@ REGISTER_OPERATOR(
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
|
|
|
|
|
ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
|
|
|
|
|
ops::FakeQuantDequantGradKernel<CPU, float>);
|
|
|
|
|
|