Use the resolve hazard method.

move-code
Zhen Wang 6 years ago
parent 2ccbfd5e10
commit 1c11f817e9

@ -26,35 +26,6 @@ __all__ = [
]
def _resolve_hazard(graph):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n.node
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = graph.topology_sort()
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
graph.graph.resolve_hazard(var_nodes)
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
@ -150,8 +121,8 @@ class QuantizationTransformPass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
sequential_execution = core.get_pass('sequential_execution_pass')
sequential_execution.apply(graph.graph)
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
@ -216,7 +187,7 @@ class QuantizationTransformPass(object):
for op in ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)
_resolve_hazard(graph)
graph.resolve_hazard()
return graph
def _create_global_step(self, graph):
@ -652,6 +623,7 @@ class QuantizationFreezePass(object):
# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
@ -895,6 +867,7 @@ class ConvertToInt8Pass(object):
# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph
def _convert_to_int8(self, graph, var_node):
@ -977,5 +950,5 @@ class TransformForMobilePass(object):
for output_node in op_node.outputs:
graph.link_to(dequant_node, output_node)
graph.safe_remove_nodes(op_node)
graph.resolve_hazard()
return graph

@ -2253,6 +2253,34 @@ class IrGraph(object):
original_nodes = {n.node for n in remove_nodes}
core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self):
"""
Check if the graph has a circle.

Loading…
Cancel
Save