@@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
                    "MovingAverageAbsMaxScale");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                    "MovingAverageAbsMaxScale");
 
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
     if (ctx->HasOutput("OutAccum")) {
       ctx->SetOutputDim("OutAccum", {1});
     }
+    if (ctx->HasOutput("Out")) {
+      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+      ctx->SetOutputDim("OutScale", {1});
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }
 
  protected:
@@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
     AddInput("X", "(Tensor) Input is float data type.");
     AddInput("InAccum", "Last accum.").AsDispensable();
     AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor is just equivalent to the input tensor.")
+        .AsDispensable();
     AddOutput("OutScale", " Current scale");
     AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
     AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
@@ -693,7 +701,7 @@ $$Out = X$$
   }
 };
 
-class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -701,9 +709,9 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
     auto out_grad_name = framework::GradVarName("Out");
     auto x_grad_name = framework::GradVarName("X");
     OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
 
     ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
   }
@@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 };
 
 template <typename T>
-class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 
  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetType("stright_throuth_estimator_grad");
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -744,11 +752,11 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
-                  ops::FakeQuantOrWithDequantAbsMaxOp,
-                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
 
@@ -769,11 +777,12 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
 REGISTER_OPERATOR(
     moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
     ops::MovingAverageAbsMaxScaleOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
-REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
-                       ops::FakeQuantDequantGradKernel<CPU, float>);
+REGISTER_OPERATOR(stright_throuth_estimator_grad,
+                  ops::StrightThroughEstimatorGradOp);
+REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
+                       ops::StrightThroughEstimatorGradKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_channel_wise_quantize_dequantize_abs_max,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
             "Out",
             "Delete output in order to make the inference model not "
             "save moving_average_abs_max_scale operator. This will "
-            "make the quantitative model be correctly applied in inference."));
+            "make the quantitative model be correctly applied in inference."))
+    .AddCheckpoint(
+        R"ROC(Incompatible upgrade of output [Out])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewOutput(
+            "Out", "In order to support dygraph qat, add output again."));
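// ---------------------------------------------------------------------------
// Context note (not part of the diff above): the renamed grad op implements a
// straight-through estimator, i.e. the fake quantize-dequantize forward pass
// is treated as identity during backprop, so dX is simply a copy of dOut.
// This matches the InferShape above, which gives GradVarName("X") the same
// dims as GradVarName("Out"). Below is a minimal standalone sketch of that
// gradient rule; the function name and std::vector-based signature are
// illustrative assumptions, not Paddle's actual kernel API.
#include <algorithm>
#include <vector>

template <typename T>
void StraightThroughEstimatorGrad(const std::vector<T>& dout,
                                  std::vector<T>* dx) {
  // STE approximates the gradient of the non-differentiable round()/clip()
  // steps as 1, so the upstream gradient flows through unchanged.
  dx->resize(dout.size());
  std::copy(dout.begin(), dout.end(), dx->begin());
}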