|
|
|
@ -183,7 +183,7 @@ class QuantizeTranspiler(object):
|
|
|
|
|
block, idx + 1, quant_var, scale_var, quant_bits)
|
|
|
|
|
dequanted_vars[block_id][name] = dequant_var
|
|
|
|
|
# rename the forward op inputs
|
|
|
|
|
op.rename_input(name, dequant_var.name)
|
|
|
|
|
op._rename_input(name, dequant_var.name)
|
|
|
|
|
|
|
|
|
|
def _transpile_backward(block, op):
|
|
|
|
|
block_id = block.idx
|
|
|
|
@ -191,7 +191,7 @@ class QuantizeTranspiler(object):
|
|
|
|
|
for name in op.input_arg_names:
|
|
|
|
|
if name in dequanted_vars[block_id]:
|
|
|
|
|
dequant_var = dequanted_vars[block_id][name]
|
|
|
|
|
op.rename_input(name, dequant_var.name)
|
|
|
|
|
op._rename_input(name, dequant_var.name)
|
|
|
|
|
no_dequanted_input_vars = False
|
|
|
|
|
if no_dequanted_input_vars:
|
|
|
|
|
raise ValueError("There is no dequanted inputs for op %s." %
|
|
|
|
@ -262,7 +262,7 @@ class QuantizeTranspiler(object):
|
|
|
|
|
scale_var = None
|
|
|
|
|
for name in op.input_arg_names:
|
|
|
|
|
if name in op_in_rename_map[block_id]:
|
|
|
|
|
op.rename_input(name, op_in_rename_map[block_id][name])
|
|
|
|
|
op._rename_input(name, op_in_rename_map[block_id][name])
|
|
|
|
|
|
|
|
|
|
scale_v = var_scale_map[block_id][_original_var_name(name)]
|
|
|
|
|
if _original_var_name(name) in persistable_vars:
|
|
|
|
@ -312,7 +312,8 @@ class QuantizeTranspiler(object):
|
|
|
|
|
# input of the followed ops
|
|
|
|
|
for name in op.input_arg_names:
|
|
|
|
|
if name in op_out_rename_map[block_id]:
|
|
|
|
|
op.rename_input(name, op_out_rename_map[block_id][name])
|
|
|
|
|
op._rename_input(name,
|
|
|
|
|
op_out_rename_map[block_id][name])
|
|
|
|
|
|
|
|
|
|
if op_type in self.fake_quant_op_types:
|
|
|
|
|
in_arg_name = op.input('X')[0]
|
|
|
|
@ -378,10 +379,11 @@ class QuantizeTranspiler(object):
|
|
|
|
|
if name not in input_map:
|
|
|
|
|
int8_var = convert_to_int8(var)
|
|
|
|
|
input_map[name] = int8_var.name
|
|
|
|
|
op.rename_input(name, input_map[name])
|
|
|
|
|
op._rename_input(name, input_map[name])
|
|
|
|
|
self._remove_unused_var(program)
|
|
|
|
|
|
|
|
|
|
def _remove_unused_var(self, program):
|
|
|
|
|
all_remove_vars = []
|
|
|
|
|
for block in program.blocks:
|
|
|
|
|
args = []
|
|
|
|
|
for op in block.ops:
|
|
|
|
@ -389,9 +391,16 @@ class QuantizeTranspiler(object):
|
|
|
|
|
args += op.output_arg_names
|
|
|
|
|
args = list(set(args))
|
|
|
|
|
var_names = block.vars.keys()
|
|
|
|
|
sub_block_remove_vars = []
|
|
|
|
|
for var in var_names:
|
|
|
|
|
if var not in args:
|
|
|
|
|
block._remove_var(var)
|
|
|
|
|
sub_block_remove_vars.append(var)
|
|
|
|
|
all_remove_vars.append(sub_block_remove_vars)
|
|
|
|
|
|
|
|
|
|
remove_vars = [list(set(v)) for v in all_remove_vars]
|
|
|
|
|
for i, block in enumerate(program.blocks):
|
|
|
|
|
for v in remove_vars[i]:
|
|
|
|
|
block._remove_var(v)
|
|
|
|
|
|
|
|
|
|
def _insert_quant_abs_max_op(self, block, idx, var, quant_bits):
|
|
|
|
|
"""Insert fake_quantize_abs_max op.
|
|
|
|
|