[dygraph qat] Use layer to calculate output scale (#31861)

* Use layer to calculate output scale
* Add backward for moving_average_abs_max_scale and save output scales to the op's attr
branch: develop
parent: c3974d0e2a
commit: b47478efc2
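This commit moves output-scale calculation in dygraph quantization-aware training into a dedicated MovingAverageAbsMaxScale layer, adds a backward (straight-through) op for it, and lets the op emit an Out tensor equal to its input. As orientation, here is an editor's NumPy sketch (not code from this commit) of the moving-average update the op performs; it follows the docstring formula and the expectations added to the unit test below.

# Editor's sketch (NumPy) of the per-step update done by the
# moving_average_abs_max_scale op:
#   scale = (moving_rate*accum + max(abs(x))) / (moving_rate*state + 1)
import numpy as np

def moving_average_abs_max_scale(x, accum, state, moving_rate=0.9):
    # accum and state are 1-element running buffers carried across steps.
    accum = moving_rate * accum + np.max(np.abs(x))
    state = moving_rate * state + 1.0
    scale = accum / state
    # With the new "Out" output, the op also forwards x unchanged (Out = X).
    return x, scale, accum, state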

@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
"MovingAverageAbsMaxScale");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"MovingAverageAbsMaxScale");
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->SetOutputDim("OutScale", {1});
if (ctx->HasOutput("Out")) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InAccum", "Last accum.").AsDispensable();
AddInput("InState", "Last state.").AsDispensable();
AddOutput("Out",
"(Tensor) Output tensor is just equivalent to the input tensor.")
.AsDispensable();
AddOutput("OutScale", " Current scale");
AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
@ -693,7 +701,7 @@ $$Out = X$$
}
};
class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@ -701,9 +709,9 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
auto out_grad_name = framework::GradVarName("Out");
auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"FakeQuantDequantGradOp");
"StrightThroughEstimatorGradOp");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"FakeQuantDequantGradOp");
"StrightThroughEstimatorGradOp");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
};
template <typename T>
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fake_quantize_dequantize_grad");
grad_op->SetType("stright_throuth_estimator_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
@ -744,11 +752,11 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
@ -769,11 +777,12 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
REGISTER_OPERATOR(
moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CPU, float>);
REGISTER_OPERATOR(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradOp);
REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
"Out",
"Delete output in order to make the inference model not "
"save moving_average_abs_max_scale operator. This will "
"make the quantitative model be correctly applied in inference."));
"make the quantitative model be correctly applied in inference."))
.AddCheckpoint(
R"ROC(Incompatible upgrade of output [Out])ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Out", "In order to support dygraph qat, add output again."));

@ -543,8 +543,8 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
REGISTER_OP_CUDA_KERNEL(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);

@ -314,6 +314,12 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (context.HasOutput("Out")) {
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
}
bool is_test = context.Attr<bool>("is_test");
// testing
if (is_test) {
@ -344,17 +350,17 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
};
template <typename DeviceContext, typename T>
class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
PADDLE_ENFORCE_NOT_NULL(
d_x, platform::errors::PreconditionNotMet(
"FakeQuantDequantGradOp doesn't have the output named %s.",
x_grad_name));
PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet(
"StrightThroughEstimatorGradKernel "
"doesn't have the output named %s.",
x_grad_name));
// Initialize dx as same as d_out
d_x->mutable_data<T>(context.GetPlace());

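The renamed gradient op acts as a straight-through estimator: in the backward pass the kernel above initializes dX directly from dOut, so the quantize/dequantize step (or the Out = X copy) is treated as identity when gradients flow through it. An editor's one-function sketch of that rule, for illustration only:

# Editor's sketch of the straight-through estimator backward rule used by
# stright_throuth_estimator_grad: the incoming gradient passes through unchanged.
import numpy as np

def straight_through_estimator_grad(d_out):
    # dX = dOut, element for element.
    return np.array(d_out, copy=True)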
@ -84,7 +84,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"matrix_nms", {"Out", "Index", "RoisNum"}},
{"distribute_fpn_proposals",
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
@ -137,7 +138,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},

@ -507,59 +507,42 @@ class QuantizedNoweightLayer(layers.Layer):
class MovingAverageAbsMaxScale(layers.Layer):
def __init__(self, layer=None, name=None, moving_rate=0.9, dtype='float32'):
def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
r"""
MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer.
Its computational formula is described as below:
MovingAverageAbsMaxScale layer is used to calculate the output quantization
scale of Layer. Its computational formula is described as below:
:math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
:math:`Out = X`
"""
super(MovingAverageAbsMaxScale, self).__init__()
self._moving_rate = moving_rate
self._dtype = dtype
self._layer = layer
if self._layer is None or not hasattr(self._layer, "_quant_out_scale"):
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=scale_name, initializer=Constant(1), trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
if self._layer is not None:
setattr(self._layer, "_quant_out_scale", self._scale)
else:
self._scale = self._layer._quant_out_scale
scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
scale_name = unique_name.generate(scale_prefix)
scale_attr = ParamAttr(
name=scale_name, initializer=Constant(1), trainable=False)
self._scale = self.create_parameter(
shape=[1], attr=scale_attr, dtype=dtype)
self._scale.stop_gradient = True
if self._layer is None or not hasattr(self._layer, "_quant_out_state"):
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=self._dtype)
self._state.stop_gradient = True
if self._layer is not None:
setattr(self._layer, "_quant_out_state", self._state)
else:
self._state = self._layer._quant_out_state
state_prefix = "{}.state".format(name) if name else 'outscale.state'
state_attr = ParamAttr(
name=unique_name.generate(state_prefix),
initializer=Constant(1),
trainable=False)
self._state = self.create_parameter(
shape=[1], attr=state_attr, dtype=dtype)
self._state.stop_gradient = True
if self._layer is None or not hasattr(self._layer, "_quant_out_accum"):
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=self._dtype)
self._accum.stop_gradient = True
if self._layer is not None:
setattr(self._layer, "_quant_out_accum", self._accum)
else:
self._accum = self._layer._quant_out_accum
accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
accum_attr = ParamAttr(
name=unique_name.generate(accum_prefix),
initializer=Constant(1),
trainable=False)
self._accum = self.create_parameter(
shape=[1], attr=accum_attr, dtype=dtype)
self._accum.stop_gradient = True
def forward(self, input):
if in_dygraph_mode():
@ -567,18 +550,30 @@ class MovingAverageAbsMaxScale(layers.Layer):
not self.training)
state = self._state if self.training else None
accum = self._accum if self.training else None
quant_out = _varbase_creator(
type=input.type,
name="{}.tmp".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
self._scale, _, _ = core.ops.moving_average_abs_max_scale(
input, accum, state, self._scale, state, accum, *attrs)
return self._scale
out, _, _, _ = core.ops.moving_average_abs_max_scale(
input, accum, state, quant_out, self._scale, state, accum,
*attrs)
return out
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'MovingAverageAbsMaxScale')
attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
inputs = {"X": [input]}
outputs = {"OutScale": [self._scale]}
quant_out = self._helper.create_variable(
name="{}.tmp".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
outputs = {"Out": [quant_out], "OutScale": [self._scale]}
if self.training:
inputs['InState'] = [self._state]
@ -592,4 +587,22 @@ class MovingAverageAbsMaxScale(layers.Layer):
outputs=outputs,
attrs=attrs)
return self._scale
return quant_out
class QuantizedOutputLayer(layers.Layer):
def __init__(self, layer=None, moving_rate=0.9, dtype='float32'):
r"""
Add a MovingAverageAbsMaxScale layer behind the input layer.
"""
super(QuantizedOutputLayer, self).__init__()
self._layer = layer
self._moving_average_abs_max_scale = \
MovingAverageAbsMaxScale(layer.full_name(), moving_rate, dtype)
def forward(self, input):
if isinstance(input, list):
assert len(input) == 1, \
"The QuantizedOutputLayer should only have one input."
out = self._layer(input)
return self._moving_average_abs_max_scale(out)

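For context, an editor's usage sketch of the layers changed above: QuantizedOutputLayer wraps an ordinary layer and appends a MovingAverageAbsMaxScale, so every forward pass updates the running output scale. The sketch assumes both classes are importable from the imperative quantization module this file defines; the exact import path is an assumption, not part of this diff.

# Editor's sketch: track a layer's output scale during dygraph training.
import paddle

conv = paddle.nn.Conv2D(1, 6, 3)
quant_conv = QuantizedOutputLayer(conv, moving_rate=0.9)  # assumed importable

x = paddle.rand([4, 1, 28, 28])
y = quant_conv(x)  # runs conv, then moving_average_abs_max_scale on its output
# The running scale lives in the wrapped MovingAverageAbsMaxScale layer:
scale = quant_conv._moving_average_abs_max_scale._scale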
@ -13,22 +13,7 @@
# limitations under the License.
import paddle
op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"softmax": [["X"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
}
import numpy as np
quant_input_layers_map = {
'Conv2D': paddle.nn.Conv2D,
@ -85,3 +70,33 @@ weight_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
"depthwise_conv2d_transpose"
]
def load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
var_node = scope.find_var(var_name)
assert var_node is not None, \
"Can not find " + var_name + " in the scope."
return np.array(var_node.get_tensor())
def find_previous_op(block, var_name):
"""
Find the previous op for the input variable.
"""
for op in block.ops:
if var_name in op.output_arg_names:
return op
def find_next_ops(block, var_name):
"""
Find all the ops that consume the input variable.
"""
res_ops = []
for op in block.ops:
if var_name in op.input_arg_names:
res_ops.append(op)
return res_ops

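The helpers added above operate on a static-graph block and a scope. A hedged sketch of typical use (the tiny program below is an editor's illustration; it assumes the helpers are importable from this utils module):

# Editor's sketch: locating producer/consumer ops with the new helpers.
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    y = paddle.mean(x)

block = main_prog.global_block()
producer = find_previous_op(block, y.name)  # the op whose outputs include y
consumers = find_next_ops(block, x.name)    # every op that reads x
# load_variable_data(scope, var_name) additionally needs a scope populated by
# an executor run.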
@ -478,30 +478,5 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
self.assertTrue(op_count == 14)
class TestSaveQuantizedModel_Warning(unittest.TestCase):
def test_warning(self):
path = "./dynamic_outscale_infer_model_with_warnings/lenet"
imperative_out_scale = ImperativeQuantAware()
with fluid.dygraph.guard():
lenet = ImperativeLenet()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
imperative_out_scale.save_quantized_model(
layer=lenet,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
warning_message = "Warning: No Layer of the model while to be " \
"saved contains the out_threshold attribute, so the " \
"generated inference model would not contain the " \
"out_threshold."
num = get_vaild_warning_num(warning_message, w)
assert num == 1
if __name__ == '__main__':
unittest.main()

@ -166,12 +166,14 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
accum[0] = 1
state = np.zeros(1).astype("float32")
state[0] = 1
x = np.random.random((8, 16, 7, 7)).astype("float32")
self.inputs = {
'X': np.random.random((8, 16, 7, 7)).astype("float32"),
'X': x,
'InAccum': accum,
'InState': state,
}
out = x
out_accum = np.zeros(1).astype("float32")
out_state = np.zeros(1).astype("float32")
out_scale = np.zeros(1).astype("float32")
@ -180,6 +182,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
out_state[0] = self.attrs['moving_rate'] * state[0] + 1
out_scale = out_accum / out_state
self.outputs = {
'Out': out,
'OutAccum': out_accum,
'OutState': out_state,
'OutScale': out_scale,
