Fix some bugs for quantization passes.

move-code
Zhen Wang 6 years ago
parent ec11135d54
commit 2ccbfd5e10

@ -1964,6 +1964,28 @@ class IrOpNode(IrNode):
else:
desc._set_attr(name, val)
def input_arg_names(self):
"""
Return input arguments' names of this op node.
Returns:
list(str): input arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input_arg_names()
def output_arg_names(self):
"""
Return output arguments' names of this op node.
Returns:
list(str): output arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output_arg_names()
@property
def inputs(self):
"""
@ -2054,31 +2076,38 @@ class IrGraph(object):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
def _find_var_node(self, key):
"""
Get a variable node by name from the graph.
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
name(str): the name of the variable node.
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name.
IrVarNode: the variable node with the giving name or id.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_var_nodes()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype):

Loading…
Cancel
Save