@@ -65,6 +65,56 @@ def _all_persistable_var_names(program):
    return persistable_var_names


def _remove_unused_var_nodes(graph):
    all_used_vars = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            all_used_vars.add(input_node)
        for output_node in op_node.outputs:
            all_used_vars.add(output_node)

    all_used_vars = {n.node for n in all_used_vars}
    all_unused_vars = {
        n
        for n in filter(lambda node: node.node not in all_used_vars,
                        graph.all_var_nodes())
    }
    graph.safe_remove_nodes(all_unused_vars)
    return graph
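
# Note (editorial): all_op_nodes() and node.inputs/outputs return Python
# IrNode wrappers, and two distinct wrappers may refer to the same underlying
# C++ node. Rebuilding the set from `n.node` above makes the membership test
# in the filter compare underlying nodes rather than wrapper identities.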


def _remove_ctrl_vars(graph):
    remove_ctr_vars = set()
    for node in graph.all_var_nodes():
        if node.is_ctrl_var():
            remove_ctr_vars.add(node)
    graph.safe_remove_nodes(remove_ctr_vars)
    return graph
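
# Note (editorial): control variables encode execution-order dependencies
# only and carry no data, so they can be dropped before running fuse passes.
# This rationale is inferred; the code itself only checks node.is_ctrl_var().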


def _apply_pass(scope,
                graph,
                pass_name,
                attrs=None,
                attr_values=None,
                debug=False):
    ir_pass = core.get_pass(pass_name)
    cpp_graph = graph.graph
    if not cpp_graph.has('__param_scope__'):
        cpp_graph.set_not_owned('__param_scope__', scope)
    if attrs:
        assert attr_values and len(attrs) == len(
            attr_values), "Different number of pass attributes and their values."
        for attr, value in zip(attrs, attr_values):
            ir_pass.set(attr, value)
    ir_pass.apply(cpp_graph)
    if debug:
        graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.all_op_nodes())
    _remove_unused_var_nodes(graph)
    return graph
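
# Usage sketch (editorial, illustrative only): a pass without attributes
# needs just a name, while attribute-taking passes receive attrs/attr_values
# pairwise via ir_pass.set() before apply() runs. Whether a pass accepts a
# given attribute depends on its C++ registration, so treat the second call
# below as hypothetical:
#
#     graph = _apply_pass(scope, graph, 'conv_bn_fuse_pass')
#     graph = _apply_pass(scope, graph, 'fc_fuse_pass', ['use_gpu'], [False])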


class PostTrainingQuantization(object):
    """
    Utilizing the post training quantization method to quantize the FP32 model,
@@ -89,6 +139,7 @@ class PostTrainingQuantization(object):
                 weight_bits=8,
                 activation_quantize_type='range_abs_max',
                 weight_quantize_type='channel_wise_abs_max',
                 optimize_model=False,
                 is_use_cache_file=False,
                 cache_dir="./temp_post_training"):
        '''
@@ -145,6 +196,14 @@ class PostTrainingQuantization(object):
            the fake ops when saving the quantized model, and we save the scale
            obtained by post training quantization in those fake ops. Compared
            to 'abs_max', the model accuracy is usually higher when it is
            'channel_wise_abs_max'.
        optimize_model(bool, optional): If optimize_model is True, some passes
            are applied to optimize the model before quantization; only the
            `conv2d/depthwise_conv2d + bn` fuse pass is supported so far. Some
            targets require that weights be quantized tensor-wise, i.e. with a
            single scale shared by all channels, but fusing
            `conv2d/depthwise_conv2d + bn` makes the weight scales differ
            across channels. To address this, the pattern is fused before
            quantization. Default False.
        is_use_cache_file(bool, optional): If is_use_cache_file is False, all
            temp data is kept in memory. If it is True, temp data is saved to
            disk. When the fp32 model is complex or
@@ -240,6 +299,7 @@ class PostTrainingQuantization(object):
        for op_type in self._quantizable_op_type:
            assert op_type in self._support_quantize_op_type, \
                op_type + " is not supported for quantization."
        self._optimize_model = optimize_model
        self._is_use_cache_file = is_use_cache_file
        self._cache_dir = cache_dir
        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
@@ -344,6 +404,10 @@ class PostTrainingQuantization(object):
            executor=self._executor,
            model_filename=self._model_filename,
            params_filename=self._params_filename)

        if self._optimize_model:
            self._optimize_fp32_model()

        feed_vars = [framework._get_var(str(var_name), self._program) \
            for var_name in self._feed_list]
        self._data_loader = io.DataLoader.from_generator(
@@ -358,6 +422,16 @@ class PostTrainingQuantization(object):
        self._data_loader.set_batch_generator(
            self._batch_generator, places=self._place)

    def _optimize_fp32_model(self):
        '''
        Fuse the `conv2d/depthwise_conv2d + bn` pattern in the FP32 model.
        '''
        _logger.info("Optimize FP32 model ...")
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)
        graph = _remove_ctrl_vars(graph)
        graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
        self._program = graph.to_program()
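        # Note (editorial): Program -> IrGraph -> pass -> Program is a full
        # round trip, so self._program is rebuilt with fresh descs; callers
        # resolve feed_vars from self._program only after this method returns
        # (see the ordering in the hunk above).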

    def _collect_target_varnames(self):
        '''
        Collect the variable names for sampling, and set activation