Merge pull request #16456 from wzzju/fix_quan_hang

Fix quantization hang bugs.
move-code
Zhen Wang 6 years ago committed by GitHub
commit d68a02af17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2052,6 +2052,28 @@ class IrOpNode(IrNode):
else:
desc._set_attr(name, val)
def input_arg_names(self):
"""
Return input arguments' names of this op node.
Returns:
list(str): input arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input_arg_names()
def output_arg_names(self):
"""
Return output arguments' names of this op node.
Returns:
list(str): output arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output_arg_names()
@property
def inputs(self):
"""
@ -2142,31 +2164,38 @@ class IrGraph(object):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
def _find_var_node(self, key):
"""
Get a variable node by name from the graph.
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
name(str): the name of the variable node.
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name.
IrVarNode: the variable node with the giving name or id.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_var_nodes()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype):
@ -2312,6 +2341,34 @@ class IrGraph(object):
original_nodes = {n.node for n in remove_nodes}
core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self):
"""
Check if the graph has a circle.

Loading…
Cancel
Save