|
|
|
@ -46,6 +46,7 @@ _op_real_in_out_name = {
|
|
|
|
|
"conv2d": [["Input", "Filter"], ["Output"]],
|
|
|
|
|
"depthwise_conv2d": [["Input"], ["Output"]],
|
|
|
|
|
"mul": [["X", "Y"], ["Out"]],
|
|
|
|
|
"matmul": [["X", "Y"], ["Out"]],
|
|
|
|
|
"pool2d": [["X"], ["Out"]],
|
|
|
|
|
"elementwise_add": [["X", "Y"], ["Out"]],
|
|
|
|
|
"concat": [["X"], ["Out"]],
|
|
|
|
@ -87,8 +88,25 @@ def _init_var_node(var_node, value, scope, place):
|
|
|
|
|
tensor.set(value, place)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_input_all_not_persistable(graph, op_node):
    """
    Check whether none of the real input variables of ``op_node`` is
    persistable.

    The real input slot names are looked up in the module-level
    ``_op_real_in_out_name`` table, so ``op_node.name()`` must be a key
    of that table.
    """
    persistable_flags = []
    for slot_name in _op_real_in_out_name[op_node.name()][0]:
        for arg_name in op_node.input(slot_name):
            # Resolve each argument name back to its var node in the graph.
            var_node = graph._find_node_by_name(op_node.inputs, arg_name)
            persistable_flags.append(var_node.persistable())
    # True only when no input var is persistable (vacuously True for no inputs).
    return not any(persistable_flags)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizationTransformPass(object):
|
|
|
|
|
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
_supported_quantizable_op_type = [
|
|
|
|
|
'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
scope=None,
|
|
|
|
@ -225,7 +243,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
dequant_var_node = dequantized_vars[var_node.name()]
|
|
|
|
|
else:
|
|
|
|
|
quant_bits = self._weight_bits if var_node.name() in persistable_vars \
|
|
|
|
|
else self._activation_bits
|
|
|
|
|
else self._activation_bits
|
|
|
|
|
quant_type = self._weight_quantize_type if var_node.name() \
|
|
|
|
|
in persistable_vars else self._activation_quantize_type
|
|
|
|
|
if quant_type == 'channel_wise_abs_max':
|
|
|
|
@ -252,17 +270,12 @@ class QuantizationTransformPass(object):
|
|
|
|
|
graph.update_input_link(var_node, dequant_var_node, op)
|
|
|
|
|
|
|
|
|
|
def _transform_backward(graph, op):
|
|
|
|
|
no_dequanted_input_vars = True
|
|
|
|
|
for var_node in op.inputs:
|
|
|
|
|
if var_node.name() not in op.input_arg_names():
|
|
|
|
|
continue
|
|
|
|
|
if var_node.name() in dequantized_vars:
|
|
|
|
|
dequant_var_node = dequantized_vars[var_node.name()]
|
|
|
|
|
graph.update_input_link(var_node, dequant_var_node, op)
|
|
|
|
|
no_dequanted_input_vars = False
|
|
|
|
|
if no_dequanted_input_vars:
|
|
|
|
|
raise ValueError("There is no dequanted inputs for op %s." %
|
|
|
|
|
(op.name()))
|
|
|
|
|
|
|
|
|
|
if not self._is_test:
|
|
|
|
|
self._create_global_step(graph)
|
|
|
|
@ -277,18 +290,11 @@ class QuantizationTransformPass(object):
|
|
|
|
|
# The loop for transforming the forward graph:
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_ops:
|
|
|
|
|
skipped = op.op().has_attr("skip_quant") and \
|
|
|
|
|
op.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
_transform_forward(graph, op)
|
|
|
|
|
if not QuantizationTransformPass._is_skip_quant(graph, op):
|
|
|
|
|
_transform_forward(graph, op)
|
|
|
|
|
# The loop for renaming the inputs of backward op.
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_grad_ops:
|
|
|
|
|
skipped = op.op().has_attr("skip_quant") and \
|
|
|
|
|
op.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
_transform_backward(graph, op)
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
@ -630,6 +636,22 @@ class QuantizationTransformPass(object):
|
|
|
|
|
"""
|
|
|
|
|
return "%s.scale" % (var_name)
|
|
|
|
|
|
|
|
|
|
@staticmethod
def _is_skip_quant(graph, op_node):
    """
    Decide whether this pass should skip quantizing ``op_node``.

    Returns True when the op carries a truthy ``skip_quant`` attribute,
    or when it is a mul/matmul op whose real inputs are all
    non-persistable.
    """
    op_desc = op_node.op()
    if op_desc.has_attr("skip_quant") and op_desc.attr("skip_quant"):
        return True
    # mul/matmul ops with no persistable inputs are quantized by
    # AddQuantDequantPass instead, so this pass leaves them alone.
    if op_node.name() in ("mul", "matmul") and \
            _is_input_all_not_persistable(graph, op_node):
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizationFreezePass(object):
|
|
|
|
|
_supported_quantizable_op_type = \
|
|
|
|
@ -733,14 +755,17 @@ class QuantizationFreezePass(object):
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
op_name = op_node.name()
|
|
|
|
|
if op_name in self._quantizable_ops:
|
|
|
|
|
skipped = op_node.op().has_attr("skip_quant") and \
|
|
|
|
|
op_node.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
|
|
|
|
|
self._insert_post_channel_dequant_op(graph, op_node)
|
|
|
|
|
else:
|
|
|
|
|
self._insert_post_dequant_op(graph, op_node)
|
|
|
|
|
# only process the node that is quantized by QuantizationTransformPass
|
|
|
|
|
is_op_node_quantized = False
|
|
|
|
|
for var_node in op_node.inputs:
|
|
|
|
|
var_name = var_node.name()
|
|
|
|
|
if var_name.endswith('.dequantized'):
|
|
|
|
|
is_op_node_quantized = True
|
|
|
|
|
if is_op_node_quantized:
|
|
|
|
|
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
|
|
|
|
|
self._insert_post_channel_dequant_op(graph, op_node)
|
|
|
|
|
else:
|
|
|
|
|
self._insert_post_dequant_op(graph, op_node)
|
|
|
|
|
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
|
|
|
|
@ -829,10 +854,6 @@ class QuantizationFreezePass(object):
|
|
|
|
|
|
|
|
|
|
def _insert_post_dequant_op(self, graph, op_node):
|
|
|
|
|
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
|
|
|
|
|
if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
|
|
|
|
|
raise ValueError("The op %s has more than one inputs "
|
|
|
|
|
"and all of them are not persistable. "
|
|
|
|
|
"Now, it is not supported!" % (op_node.name()))
|
|
|
|
|
max_range = 1
|
|
|
|
|
param_range = (1 << (self._weight_bits - 1)) - 1
|
|
|
|
|
act_range = (1 << (self._activation_bits - 1)) - 1
|
|
|
|
@ -987,9 +1008,7 @@ class ConvertToInt8Pass(object):
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
op_name = op_node.name()
|
|
|
|
|
if op_name in self._quantizable_ops:
|
|
|
|
|
skipped = op_node.op().has_attr("skip_quant") and \
|
|
|
|
|
op_node.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
if QuantizationTransformPass._is_skip_quant(graph, op_node):
|
|
|
|
|
continue
|
|
|
|
|
for var_node in op_node.inputs:
|
|
|
|
|
name = var_node.name()
|
|
|
|
@ -1240,7 +1259,7 @@ class AddQuantDequantPass(object):
|
|
|
|
|
"equal", "gather", "greater_equal", "greater_than", "less_equal",
|
|
|
|
|
"less_than", "mean", "not_equal", "reshape", "reshape2",
|
|
|
|
|
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
|
|
|
|
|
"squeeze", "elementwise_sub"
|
|
|
|
|
"squeeze", "elementwise_sub", "mul", "matmul"
|
|
|
|
|
]
|
|
|
|
|
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
|
|
|
|
|
|
|
|
|
@ -1317,34 +1336,38 @@ class AddQuantDequantPass(object):
|
|
|
|
|
all_op_nodes = graph.all_op_nodes()
|
|
|
|
|
for op_node in all_op_nodes:
|
|
|
|
|
if op_node.name() in self._quantizable_op_type:
|
|
|
|
|
user_skipped = False
|
|
|
|
|
is_skip = False
|
|
|
|
|
if isinstance(self._skip_pattern, list):
|
|
|
|
|
user_skipped = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
is_skip = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
|
|
|
|
|
elif isinstance(self._skip_pattern, str):
|
|
|
|
|
user_skipped = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
is_skip = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
|
|
|
|
|
|
|
|
|
|
if user_skipped:
|
|
|
|
|
continue
|
|
|
|
|
is_op_node_quantized = False
|
|
|
|
|
for var_node in op_node.inputs:
|
|
|
|
|
var_name = var_node.name()
|
|
|
|
|
if var_name.endswith('.dequantized'):
|
|
|
|
|
is_op_node_quantized = True
|
|
|
|
|
|
|
|
|
|
if not self._is_input_all_not_persistable(graph, op_node):
|
|
|
|
|
if is_skip or is_op_node_quantized or \
|
|
|
|
|
(not _is_input_all_not_persistable(graph, op_node)):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
input_name_list = _op_real_in_out_name[op_node.name()][0]
|
|
|
|
|
arg_names = []
|
|
|
|
|
for input_name in input_name_list:
|
|
|
|
|
for arg_name in op_node.input(input_name):
|
|
|
|
|
in_node = graph._find_node_by_name(op_node.inputs,
|
|
|
|
|
arg_name)
|
|
|
|
|
if arg_name in dequantized_vars_map:
|
|
|
|
|
quant_var_node = dequantized_vars_map[arg_name]
|
|
|
|
|
else:
|
|
|
|
|
quant_var_node, _ = \
|
|
|
|
|
self._inser_quant_dequant_moving_average_abs_max_op(
|
|
|
|
|
graph, in_node, self._quant_bits)
|
|
|
|
|
dequantized_vars_map[arg_name] = quant_var_node
|
|
|
|
|
graph.update_input_link(in_node, quant_var_node,
|
|
|
|
|
op_node)
|
|
|
|
|
arg_names.extend(op_node.input(input_name))
|
|
|
|
|
for arg_name in arg_names:
|
|
|
|
|
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
|
|
|
|
|
if arg_name in dequantized_vars_map:
|
|
|
|
|
quant_var_node = dequantized_vars_map[arg_name]
|
|
|
|
|
else:
|
|
|
|
|
quant_var_node, _ = \
|
|
|
|
|
self._inser_quant_dequant_moving_average_abs_max_op(
|
|
|
|
|
graph, in_node, self._quant_bits)
|
|
|
|
|
dequantized_vars_map[arg_name] = quant_var_node
|
|
|
|
|
graph.update_input_link(in_node, quant_var_node, op_node)
|
|
|
|
|
|
|
|
|
|
# Backward stage, update input link
|
|
|
|
|
for op_node in all_op_nodes:
|
|
|
|
@ -1360,21 +1383,6 @@ class AddQuantDequantPass(object):
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _is_input_all_not_persistable(self, graph, op_node):
    """
    Return True when every real input variable of ``op_node`` is
    non-persistable.

    The real input slot names come from the module-level
    ``_op_real_in_out_name`` table keyed by the op type.
    """
    not_persistable_flags = []
    for slot_name in _op_real_in_out_name[op_node.name()][0]:
        for arg_name in op_node.input(slot_name):
            # Map each argument name to its var node in the graph.
            node = graph._find_node_by_name(op_node.inputs, arg_name)
            not_persistable_flags.append(not node.persistable())
    # all([]) is True, matching the original's True default for no inputs.
    return all(not_persistable_flags)
|
|
|
|
|
|
|
|
|
|
def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
|
|
|
|
|
quant_bits):
|
|
|
|
|
"""Insert fake_quantize_dequantize_moving_average_abs_max op.
|
|
|
|
|