Support loading parameters from checkpoint to save quantized model (#31419)

* Support loading parameters from checkpoint to save quantized model

* Fix the unittest test_moving_average_abs_max_scale_op

* Add unittest of save_quantized_model from checkpoint

* Add comments to explain the function
guofei 4 years ago committed by GitHub
parent da9dda5c9b
commit ef0dd3efed

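In short, this change lets the parameters created during quantization be restored from a saved checkpoint before exporting an inference model. A minimal usage sketch of that workflow, mirroring the new unit test added in this commit (ImperativeLenet is the test model defined in the updated unittest; the checkpoint path is a placeholder):

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    quanter = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')

    model = ImperativeLenet()   # test model from the unittest; any paddle Layer works
    quanter.quantize(model)     # insert fake-quant / out-scale layers
    model.set_dict(paddle.load("lenet.pdparams"))  # placeholder checkpoint path

    quanter.save_quantized_model(
        layer=model,
        path="./quantized_infer_model/lenet",
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 1, 28, 28], dtype='float32')
        ])
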
@@ -503,7 +503,7 @@ class QuantizedNoweightLayer(layers.Layer):
 class MovingAverageAbsMaxScale(layers.Layer):
-    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+    def __init__(self, layer=None, name=None, moving_rate=0.9, dtype='float32'):
         r"""
         MovingAverageAbsMaxScale layer is used to calculate the output quantization scale of a Layer.
         Its computational formula is described below:
@@ -514,33 +514,48 @@ class MovingAverageAbsMaxScale(layers.Layer):
         super(MovingAverageAbsMaxScale, self).__init__()
         self._moving_rate = moving_rate
         self._dtype = dtype
+        self._layer = layer
-        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-        name = unique_name.generate(scale_prefix)
-        scale_attr = ParamAttr(
-            name=name, initializer=Constant(1), trainable=False)
-        self._scale = self.create_parameter(
-            shape=[1], attr=scale_attr, dtype=self._dtype)
-        self._scale.stop_gradient = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_scale"):
+            scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+            scale_name = unique_name.generate(scale_prefix)
+            scale_attr = ParamAttr(
+                name=scale_name, initializer=Constant(1), trainable=False)
+            self._scale = self.create_parameter(
+                shape=[1], attr=scale_attr, dtype=self._dtype)
+            self._scale.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_scale", self._scale)
+        else:
+            self._scale = self._layer._quant_out_scale
-        state_prefix = "{}.state".format(name) if name else 'outscale.state'
-        state_attr = ParamAttr(
-            name=unique_name.generate(state_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._state = self.create_parameter(
-            shape=[1], attr=state_attr, dtype=self._dtype)
-        self._state.stop_gradient = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_state"):
+            state_prefix = "{}.state".format(name) if name else 'outscale.state'
+            state_attr = ParamAttr(
+                name=unique_name.generate(state_prefix),
+                initializer=Constant(1),
+                trainable=False)
+            self._state = self.create_parameter(
+                shape=[1], attr=state_attr, dtype=self._dtype)
+            self._state.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_state", self._state)
+        else:
+            self._state = self._layer._quant_out_state
-        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-        accum_attr = ParamAttr(
-            name=unique_name.generate(accum_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._accum = self.create_parameter(
-            shape=[1], attr=accum_attr, dtype=self._dtype)
-        self._accum.stop_gradient = True
-        MovingAverageAbsMaxScale._has_create = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_accum"):
+            accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+            accum_attr = ParamAttr(
+                name=unique_name.generate(accum_prefix),
+                initializer=Constant(1),
+                trainable=False)
+            self._accum = self.create_parameter(
+                shape=[1], attr=accum_attr, dtype=self._dtype)
+            self._accum.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_accum", self._accum)
+        else:
+            self._accum = self._layer._quant_out_accum
 
     def forward(self, input):
         if in_dygraph_mode():
@@ -549,18 +564,17 @@ class MovingAverageAbsMaxScale(layers.Layer):
             state = self._state if self.training else None
             accum = self._accum if self.training else None
-            out_scale, _, _ = core.ops.moving_average_abs_max_scale(
+            self._scale, _, _ = core.ops.moving_average_abs_max_scale(
                 input, accum, state, self._scale, state, accum, *attrs)
-            return out_scale
+            return self._scale
 
         check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                  'MovingAverageAbsMaxScale')
 
-        scale_out = self._scale
         attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
         inputs = {"X": [input]}
-        outputs = {"OutScale": [scale_out]}
+        outputs = {"OutScale": [self._scale]}
 
         if self.training:
             inputs['InState'] = [self._state]
@@ -574,4 +588,4 @@ class MovingAverageAbsMaxScale(layers.Layer):
             outputs=outputs,
             attrs=attrs)
 
-        return scale_out
+        return self._scale

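The __init__ change above follows a reuse-or-create pattern: each of the scale, state and accum parameters is first looked up on the wrapped layer (as _quant_out_scale, _quant_out_state, _quant_out_accum) and is only created, and then attached to the layer, when it is missing, so wrapping the same layer again reuses the existing parameters instead of generating newly named ones. A simplified, self-contained sketch of the pattern, using an illustrative OutScaleWrapper class rather than the Paddle implementation:

    class OutScaleWrapper:
        """Simplified sketch of the reuse-or-create pattern (not the Paddle code)."""

        def __init__(self, layer=None, attr_name="_quant_out_scale"):
            self._layer = layer
            if layer is None or not hasattr(layer, attr_name):
                # First wrap: create the parameter and attach it to the layer.
                self._scale = [1.0]  # stand-in for a 1-element Paddle parameter
                if layer is not None:
                    setattr(layer, attr_name, self._scale)
            else:
                # The layer already carries the parameter: reuse it, do not recreate.
                self._scale = getattr(layer, attr_name)

    class DummyLayer:
        pass

    layer = DummyLayer()
    # Wrapping the same layer twice reuses the very same parameter object.
    assert OutScaleWrapper(layer)._scale is OutScaleWrapper(layer)._scale
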
@@ -19,6 +19,8 @@ import numpy as np
 import random
 import unittest
 import logging
+import warnings
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -29,7 +31,7 @@ from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutScaleForInferencePass, QuantizationTransformPass
 from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
+from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
 from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
 from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.log_helper import get_logger
@@ -45,6 +47,14 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+def get_vaild_warning_num(warning, w):
+    num = 0
+    for i in range(len(w)):
+        if warning in str(w[i].message):
+            num += 1
+    return num
 
 
 def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
     conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
     conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
@@ -76,9 +86,9 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
         param_attr=conv2d_w2_attr,
         bias_attr=conv2d_b2_attr)
     batch_norm2 = layers.batch_norm(conv2)
-    relu6_1 = layers.relu6(batch_norm2)
+    prelu1 = layers.prelu(batch_norm2, mode='all')
     pool2 = fluid.layers.pool2d(
-        relu6_1, pool_size=2, pool_type='max', pool_stride=2)
+        prelu1, pool_size=2, pool_type='max', pool_stride=2)
     fc1 = fluid.layers.fc(input=pool2,
                           size=120,
@@ -132,7 +142,7 @@ class ImperativeLenet(fluid.dygraph.Layer):
                 weight_attr=conv2d_w2_attr,
                 bias_attr=conv2d_b2_attr),
             BatchNorm2D(16),
-            ReLU6(),
+            PReLU(),
             MaxPool2D(
                 kernel_size=2, stride=2))
@@ -246,6 +256,10 @@ class TestImperativeOutSclae(unittest.TestCase):
         lenet.eval()
 
+        param_save_path = "test_save_quantized_model/lenet.pdparams"
+        save_dict = lenet.state_dict()
+        paddle.save(save_dict, param_save_path)
+
         path = "./dynamic_outscale_infer_model/lenet"
         dynamic_save_dir = "./dynamic_outscale_infer_model"
@@ -285,6 +299,8 @@
         for param in main.all_parameters():
             if "batch_norm" in param.name:
                 param_name = param.name.replace("norm", "norm2d")
+            elif 'prelu' in param.name:
+                param_name = param.name.replace("prelu", 'p_re_lu')
             else:
                 param_name = param.name
             param_tensor = scope.var(param.name).get_tensor()
@@ -384,5 +400,94 @@
                                 static_ops[i].attr("out_threshold"))
+
+
+class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
+    def test_save_quantized_model(self):
+        weight_quantize_type = 'abs_max'
+        activation_quantize_type = 'moving_average_abs_max'
+        load_param_path = "test_save_quantized_model/lenet.pdparams"
+        path = "./dynamic_outscale_infer_model_from_checkpoint/lenet"
+        dynamic_model_save_dir = "./dynamic_outscale_infer_model_from_checkpoint"
+        static_model_save_dir = "./static_outscale_infer_model"
+
+        imperative_out_scale = ImperativeQuantAware(
+            weight_quantize_type=weight_quantize_type,
+            activation_quantize_type=activation_quantize_type)
+
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+            load_dict = paddle.load(load_param_path)
+            imperative_out_scale.quantize(lenet)
+            lenet.set_dict(load_dict)
+
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+
+        # load dynamic model
+        [dynamic_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=dynamic_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+        # load static model
+        [static_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=static_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+
+        dynamic_ops = dynamic_inference_program.global_block().ops
+        static_ops = static_inference_program.global_block().ops
+
+        for op in dynamic_ops[:]:
+            if op.type == "flatten2" or 'fake' in op.type:
+                dynamic_ops.remove(op)
+        for op in static_ops[:]:
+            if 'fake' in op.type:
+                static_ops.remove(op)
+
+        for i in range(len(dynamic_ops)):
+            if dynamic_ops[i].has_attr("out_threshold"):
+                self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
+                self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
+                                static_ops[i].attr("out_threshold"))
+
+
+class TestSaveQuantizedModel_Warning(unittest.TestCase):
+    def test_warning(self):
+        path = "./dynamic_outscale_infer_model_with_warnings/lenet"
+        imperative_out_scale = ImperativeQuantAware()
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        warning_message = "Warning: No Layer of the model while to be saved contains the out_threshold attribute, " \
+                          "so the generated inference model would not contain the out_threshold."
+        num = get_vaild_warning_num(warning_message, w)
+        assert num == 1
+
+
 if __name__ == '__main__':
     unittest.main()
