Fix some bugs for quantization passes.

move-code
Zhen Wang 6 years ago
parent ec11135d54
commit 2ccbfd5e10

@ -1964,6 +1964,28 @@ class IrOpNode(IrNode):
else: else:
desc._set_attr(name, val) desc._set_attr(name, val)
def input_arg_names(self):
"""
Return input arguments' names of this op node.
Returns:
list(str): input arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input_arg_names()
def output_arg_names(self):
"""
Return output arguments' names of this op node.
Returns:
list(str): output arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output_arg_names()
@property @property
def inputs(self): def inputs(self):
""" """
@ -2054,31 +2076,38 @@ class IrGraph(object):
""" """
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()} return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def var_node(self, name): def _find_var_node(self, key):
""" """
Get a variable node by name from the graph. Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args: Args:
name(str): the name of the variable node. key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises: Raises:
ValueError: The If input's type is not str, or this graph ValueError: If this graph doesn't have a variable with the giving name or id.
doesn't have a variable with the giving name.
Returns: Returns:
IrVarNode: the variable node with the giving name. IrVarNode: the variable node with the giving name or id.
""" """
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None target_var_node = None
var_nodes = self.all_var_nodes() var_nodes = self.all_var_nodes()
for var_node in var_nodes: if isinstance(key, six.string_types):
if var_node.name() == name: for var_node in var_nodes:
target_var_node = var_node if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None: if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name) raise ValueError("var_node %s not in this graph" % key)
return target_var_node return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype): def create_persistable_node(self, name, var_type, shape, var_dtype):

Loading…
Cancel
Save