@@ -17,7 +17,9 @@ import numpy as np
 import six
 from ..... import compat as cpt
 from .... import core
+from .... import Executor
 from ....framework import IrGraph
+from ....framework import IrNode
 from ....framework import Program
 from ....initializer import Constant
 from .... import unique_name
@@ -31,7 +33,7 @@ __all__ = [
 class QuantizationTransformPass(object):
     def __init__(self,
                  scope=None,
-                 program_exe=None,
+                 place=None,
                  weight_bits=8,
                  activation_bits=8,
                  activation_quantize_type='abs_max',
@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
             scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
                 type, this pass will create some new parameters. The scope is used to
                 initialize these new parameters.
-            program_exe(fluid.Executor): program_exe is used to initialize new
+            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                 parameters described above.
             weight_bits (int): quantization bit number for weights,
                 the bias is not quantized.
@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
             from paddle.fluid import core
 
             graph = IrGraph(core.Graph(program.desc), for_test=False)
-            exe = fluid.Executor(fluid.CPUPlace())
+            place = fluid.CPUPlace()
             transform_pass = QuantizationTransformPass(fluid.global_scope(),
-            exe)
+            place)
             transform_pass.apply(graph)
         """
         self._scope = scope
-        self._program_exe = program_exe
+        self._place = place
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
 
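From a caller's point of view, the constructor change above means passing a device place instead of a pre-built executor. A minimal sketch of the migrated call site, consistent with the updated docstring example; `main_program` and the exact import paths are assumptions for illustration, not part of this diff:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

# `main_program` is assumed to be an already-built fluid.Program.
graph = IrGraph(core.Graph(main_program.desc), for_test=False)
place = fluid.CPUPlace()

# Previously: QuantizationTransformPass(fluid.global_scope(), program_exe=exe)
# Now the pass only needs the place and builds its own Executor when required.
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(), place=place)
transform_pass.apply(graph)
```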
@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
 
         def _transform_forward(graph, op):
             for var_node in op.inputs:
@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
 
         if not self._is_test:
             self._create_global_step(graph)
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
         for op in ops:
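This and the neighboring hunks switch to the renamed IrGraph traversal API: `all_ops()` becomes `all_op_nodes()`, `all_vars()` becomes `all_var_nodes()`, and `all_persistable_vars()` becomes `all_persistable_nodes()`. A minimal sketch of the traversal pattern the passes rely on; `graph` is assumed to be an `IrGraph`, and the op-type set is illustrative rather than the pass's real list:

```python
quantizable_ops = {'conv2d', 'depthwise_conv2d', 'mul'}  # illustrative only

persistable_names = {node.name() for node in graph.all_persistable_nodes()}
for op_node in graph.all_op_nodes():  # formerly graph.all_ops()
    if op_node.name() in quantizable_ops:
        # Inputs are IrNode wrappers; persistable inputs are the weights.
        weight_inputs = [
            var_node for var_node in op_node.inputs
            if var_node.name() in persistable_names
        ]
```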
@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
         if len(self._need_initialized) > 0:
             assert self._scope is not None, \
                 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._program_exe is not None, \
-                'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
+            assert self._place is not None, \
+                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
             init_program = Program()
             for var_desc, initializer in six.iteritems(self._need_initialized):
                 var = init_program.global_block().create_var(
@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
                     lod_level=var_desc.lod_level(),
                     persistable=var_desc.persistable())
                 initializer(var, init_program.global_block())
-            self._program_exe.run(program=init_program, scope=self._scope)
+            exe = Executor(self._place)
+            exe.run(program=init_program, scope=self._scope)
 
         return graph
 
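Instead of running the initialization program on an Executor supplied by the caller, the pass now creates a throwaway Executor from `self._place`. A standalone sketch of that pattern under the same assumptions; the helper name and the constant value are hypothetical:

```python
from paddle.fluid import Executor, Program
from paddle.fluid.initializer import Constant


def init_persistable_var(scope, place, name, shape, dtype):
    # Build a tiny program whose only job is to create and initialize one
    # persistable variable, then run it once so the value lands in `scope`.
    init_program = Program()
    var = init_program.global_block().create_var(
        name=name, shape=shape, dtype=dtype, persistable=True)
    Constant(value=0.001)(var, init_program.global_block())
    exe = Executor(place)  # created on demand, mirroring the hunk above
    exe.run(program=init_program, scope=scope)
```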
@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
         if self._weight_quantize_type == 'range_abs_max' or \
                 self._activation_quantize_type == 'range_abs_max':
             counter_name = cpt.to_text('@STEP_COUNTER@')
-            for node in graph.all_vars():
+            for node in graph.all_var_nodes():
                 if node.name() == counter_name:
                     self._global_step = node
             if self._global_step is None:
-                global_step_in = graph.create_param_node(
+                global_step_in = graph.create_persistable_node(
                     name=counter_name,
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     shape=[1],
@@ -262,7 +265,7 @@ class QuantizationTransformPass(object):
             shape=var_node.var().shape(),
             var_dtype=var_node.var().dtype())
 
-        scale_in_node = graph.create_param_node(
+        scale_in_node = graph.create_persistable_node(
             name=self._quantized_scale_name(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
@@ -275,7 +278,7 @@ class QuantizationTransformPass(object):
 
         if not self._is_test:
             # The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
-            scales_node = graph.create_param_node(
+            scales_node = graph.create_persistable_node(
                 name=unique_name.generate('scales'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
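Here and in the earlier hunks, `create_param_node` is renamed to `create_persistable_node`; the call shape is unchanged. A small sketch mirroring this hunk's call; `graph` and `window_size` are assumed to exist, and the dtype argument is an assumption since it falls outside the visible context:

```python
from paddle.fluid import core
from paddle.fluid import unique_name

scales_node = graph.create_persistable_node(  # was: graph.create_param_node(
    name=unique_name.generate('scales'),
    var_type=core.VarDesc.VarType.LOD_TENSOR,
    shape=[window_size],
    var_dtype=core.VarDesc.VarType.FP32)  # dtype chosen for illustration
```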
@@ -400,8 +403,8 @@ class QuantizationFreezePass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_quant_op_names:
@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
                                                     self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
 
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_dequant_op_names:
                 self._remove_fake_quant_and_dequant_op(graph, op_node)
 
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
     def _insert_post_dequant_op(self, graph, op_node):
         max_range = None
         scale_var_node = None
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
             if name in self._op_input_rename_map:
@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
                     original_var_name)
                 max_range = param_range * act_range / scale_v
             else:
-                assert isinstance(scale_v, core.Node)
+                assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]
 
         if len(op_node.outputs) != 1:
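The assertion above now checks the Python-level `IrNode` wrapper rather than the raw `core.Node`. In this branch, weight scales arrive as plain floats and are folded into `max_range`, while activation scales remain variable nodes in the graph. A hedged, standalone restatement of that distinction (an illustrative helper, not the pass's actual method):

```python
from paddle.fluid.framework import IrNode


def resolve_scale(scale_v, is_weight, param_range, act_range):
    # Weight scales are plain floats folded into the dequantization range;
    # activation scales stay in the graph as IrNode variable nodes
    # (previously exposed as raw core.Node objects).
    if is_weight:
        assert isinstance(scale_v, float)
        return param_range * act_range / scale_v, None
    assert isinstance(scale_v, IrNode)
    return None, scale_v
```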
@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
 
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)
 
-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)
 
     def _original_var_name(self, var_name):
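The rewritten cleanup above stops comparing node wrappers directly: `all_var_nodes()` and `op_node.inputs`/`op_node.outputs` return `IrNode` wrappers, and the new code compares the underlying `core.Node` handles via `n.node`, which suggests that distinct wrapper instances around the same node would not match in a plain set difference. A compact sketch of the same logic as a free function:

```python
def find_unused_var_nodes(graph):
    # Collect the raw core.Node handles of every variable an op touches,
    # then keep only the IrNode wrappers whose underlying node is unused.
    used = set()
    for op_node in graph.all_op_nodes():
        for var_node in op_node.inputs:
            used.add(var_node.node)
        for var_node in op_node.outputs:
            used.add(var_node.node)
    return {n for n in graph.all_var_nodes() if n.node not in used}
```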
@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         input_map = {}
         for op_node in ops:
             op_name = op_node.name()
@@ -605,7 +613,7 @@ class ConvertToInt8Pass(object):
 
     def _convert_to_int8(self, graph, var_node):
         int8_var_node_name = var_node.name() + ".int8"
-        int8_var_node = graph.create_param_node(
+        int8_var_node = graph.create_persistable_node(
             name=cpt.to_text(int8_var_node_name),
             var_type=var_node.var().type(),
             shape=var_node.var().shape(),
@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
 
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)
 
-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)
 
 
@@ -655,7 +668,7 @@ class TransformForMobilePass(object):
         Args:
             graph(IrGraph): the graph will be transformed.
         """
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
             if name in self._fake_quant_op_names: