|
|
|
@ -26,35 +26,6 @@ __all__ = [
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_hazard(graph):
|
|
|
|
|
def _to_node(nodes, node_name):
|
|
|
|
|
target_node = None
|
|
|
|
|
for n in nodes:
|
|
|
|
|
if n.name() == node_name:
|
|
|
|
|
target_node = n.node
|
|
|
|
|
assert target_node is not None, "Cannot find the target node in the giving set."
|
|
|
|
|
return target_node
|
|
|
|
|
|
|
|
|
|
ordered_nodes = graph.topology_sort()
|
|
|
|
|
var_nodes = dict()
|
|
|
|
|
for node in ordered_nodes:
|
|
|
|
|
if node.is_op() and node.op() is not None:
|
|
|
|
|
for each_var_name in node.op().input_arg_names():
|
|
|
|
|
if each_var_name not in var_nodes:
|
|
|
|
|
var_nodes[each_var_name] = [
|
|
|
|
|
_to_node(node.inputs, each_var_name)
|
|
|
|
|
]
|
|
|
|
|
for each_var_name in node.op().output_arg_names():
|
|
|
|
|
if each_var_name not in var_nodes:
|
|
|
|
|
var_nodes[each_var_name] = [
|
|
|
|
|
_to_node(node.outputs, each_var_name)
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
var_nodes[each_var_name].append(
|
|
|
|
|
_to_node(node.outputs, each_var_name))
|
|
|
|
|
graph.graph.resolve_hazard(var_nodes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizationTransformPass(object):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
scope=None,
|
|
|
|
@ -150,8 +121,8 @@ class QuantizationTransformPass(object):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(graph,
|
|
|
|
|
IrGraph), 'graph must be the instance of IrGraph.'
|
|
|
|
|
sequential_execution = core.get_pass('sequential_execution_pass')
|
|
|
|
|
sequential_execution.apply(graph.graph)
|
|
|
|
|
#sequential_execution = core.get_pass('sequential_execution_pass')
|
|
|
|
|
#sequential_execution.apply(graph.graph)
|
|
|
|
|
self._is_test = graph.is_test()
|
|
|
|
|
# marked the variable which has been dequantized.
|
|
|
|
|
dequantized_vars = collections.OrderedDict()
|
|
|
|
@ -216,7 +187,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_grad_ops:
|
|
|
|
|
_transform_backward(graph, op)
|
|
|
|
|
_resolve_hazard(graph)
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _create_global_step(self, graph):
|
|
|
|
@ -652,6 +623,7 @@ class QuantizationFreezePass(object):
|
|
|
|
|
|
|
|
|
|
# remove the unused var node in the graph
|
|
|
|
|
self._remove_unused_var_nodes(graph)
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
|
|
|
|
@ -895,6 +867,7 @@ class ConvertToInt8Pass(object):
|
|
|
|
|
|
|
|
|
|
# remove the unused var node in the graph
|
|
|
|
|
self._remove_unused_var_nodes(graph)
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _convert_to_int8(self, graph, var_node):
|
|
|
|
@ -977,5 +950,5 @@ class TransformForMobilePass(object):
|
|
|
|
|
for output_node in op_node.outputs:
|
|
|
|
|
graph.link_to(dequant_node, output_node)
|
|
|
|
|
graph.safe_remove_nodes(op_node)
|
|
|
|
|
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|