|
|
|
|
@ -15,10 +15,11 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
from .... import core
|
|
|
|
|
from ....framework import IrGraph
|
|
|
|
|
from ....framework import IrNode
|
|
|
|
|
|
|
|
|
|
__all__ = ['Qat2Int8MkldnnPass']
|
|
|
|
|
|
|
|
|
|
OpRole = core.op_proto_and_checker_maker.OpRole
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Qat2Int8MkldnnPass(object):
|
|
|
|
|
"""
|
|
|
|
|
@ -62,6 +63,7 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
self._pool_ops = ['pool2d']
|
|
|
|
|
self._mul_ops = ['mul']
|
|
|
|
|
self._fc_ops = ['fc']
|
|
|
|
|
self._matmul_ops = ['matmul']
|
|
|
|
|
self._weight_scales = {}
|
|
|
|
|
# Collect the Input and Output scales from Fake QAT models
|
|
|
|
|
self._var_quant_scales = {}
|
|
|
|
|
@ -79,9 +81,9 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
graph = self._compute_weight_scales(graph)
|
|
|
|
|
graph = self._update_relu_output_scales(graph)
|
|
|
|
|
graph = self._propagate_scales(graph)
|
|
|
|
|
graph = self._set_dummy_fc_out_scales(graph)
|
|
|
|
|
graph = self._set_dummy_out_scales(graph)
|
|
|
|
|
graph = self._quantize_fp32_graph(graph)
|
|
|
|
|
graph = self._remove_unused_var_nodes(graph)
|
|
|
|
|
graph = self._cleanup(graph)
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def apply_fp32(self, graph):
|
|
|
|
|
@ -92,7 +94,7 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
graph = self._remove_fake_ops(graph)
|
|
|
|
|
graph = self._dequantize_weights(graph)
|
|
|
|
|
graph = self._optimize_fp32_graph(graph)
|
|
|
|
|
graph = self._remove_unused_var_nodes(graph)
|
|
|
|
|
graph = self._cleanup(graph)
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _convert_scale2tensor(self, scale):
|
|
|
|
|
@ -176,23 +178,29 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _set_dummy_fc_out_scales(self, graph):
|
|
|
|
|
def _set_dummy_out_scales(self, graph):
|
|
|
|
|
'''
|
|
|
|
|
For the output tensors of FC that do not have an assigned scale,
|
|
|
|
|
For the output tensors of fc, conv2d and matmul ops that do not have an assigned scale,
|
|
|
|
|
assign a dummy scale (same scale as input), so that the quantize pass
|
|
|
|
|
won't fail. In the end these scales aren't used, since FCs that
|
|
|
|
|
won't fail. In the end these scales aren't used, since the ops that
|
|
|
|
|
have an unassigned output scale will have a force_fp32_output attr
|
|
|
|
|
set to True.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def _set_scale(op, op_types, input_names, output_name):
|
|
|
|
|
scales = self._var_quant_scales
|
|
|
|
|
should_set = op.name() in op_types \
|
|
|
|
|
and op.output(output_name)[0] not in scales \
|
|
|
|
|
and all(op.input(input_name)[0] in scales for input_name in input_names)
|
|
|
|
|
if should_set:
|
|
|
|
|
output_var_name = op.output(output_name)[0]
|
|
|
|
|
input_var_name = op.input(input_names[0])[0]
|
|
|
|
|
scales[output_var_name] = scales[input_var_name]
|
|
|
|
|
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.name() in self._fc_ops:
|
|
|
|
|
input_name = op.input("Input")[0]
|
|
|
|
|
output_name = op.output("Out")[0]
|
|
|
|
|
if input_name in self._var_quant_scales and \
|
|
|
|
|
output_name not in self._var_quant_scales:
|
|
|
|
|
# use input scale as a "dummy" scale
|
|
|
|
|
self._var_quant_scales[
|
|
|
|
|
output_name] = self._var_quant_scales[input_name]
|
|
|
|
|
_set_scale(op, self._conv_ops, ["Input"], "Output")
|
|
|
|
|
_set_scale(op, self._fc_ops, ["Input"], "Out")
|
|
|
|
|
_set_scale(op, self._matmul_ops, ["X", "Y"], "Out")
|
|
|
|
|
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
@ -358,6 +366,15 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
self._remove_unused_var_nodes(graph)
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _cleanup(self, graph):
    """Run the final clean-up steps on ``graph`` and return the result.

    The steps run in a fixed order: the IR clean-up passes are applied,
    unused variable nodes are removed, and finally every op gets its
    ``op_role`` set through ``_set_op_role_forward``.
    """
    cleanup_passes = (
        # remove dropout ops
        'simplify_with_basic_ops_pass',
        # make some MKL-DNN ops work in place
        'mkldnn_inplace_pass',
    )
    for pass_name in cleanup_passes:
        graph = self._apply_pass(graph, pass_name)

    pruned_graph = self._remove_unused_var_nodes(graph)
    return self._set_op_role_forward(pruned_graph)
|
|
|
|
|
|
|
|
|
|
def _remove_unused_var_nodes(self, graph):
|
|
|
|
|
all_used_vars = set()
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
@ -376,8 +393,14 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
graph.safe_remove_nodes(all_unused_vars)
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _set_op_role_forward(self, graph):
    """Set the ``op_role`` attribute of every op node to ``OpRole.Forward``.

    Each node returned by ``graph.all_op_nodes()`` is updated, and the
    same ``graph`` object is returned to the caller.
    """
    for node in graph.all_op_nodes():
        node.set_attr("op_role", OpRole.Forward)
    return graph
|
|
|
|
|
|
|
|
|
|
def _compute_weight_scales(self, graph):
|
|
|
|
|
def _compute_var_scales(ops, out_name, w_name, axis):
|
|
|
|
|
def _compute_var_scales(ops, w_name, axis):
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.op().type() in ops:
|
|
|
|
|
weight_var_name = op.input(w_name)[0]
|
|
|
|
|
@ -394,8 +417,8 @@ class Qat2Int8MkldnnPass(object):
|
|
|
|
|
self._var_quant_scales[weight_var_name] = (use_unsigned_int,
|
|
|
|
|
lod_tensor)
|
|
|
|
|
|
|
|
|
|
_compute_var_scales(self._conv_ops, "Output", "Filter", axis=1)
|
|
|
|
|
_compute_var_scales(self._fc_ops, "Out", "W", axis=0)
|
|
|
|
|
_compute_var_scales(self._conv_ops, "Filter", axis=1)
|
|
|
|
|
_compute_var_scales(self._fc_ops, "W", axis=0)
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _find_avg_pooling_ids(self, graph):
|
|
|
|
|
|