|
|
|
@ -1964,6 +1964,28 @@ class IrOpNode(IrNode):
|
|
|
|
|
else:
|
|
|
|
|
desc._set_attr(name, val)
|
|
|
|
|
|
|
|
|
|
def input_arg_names(self):
|
|
|
|
|
"""
|
|
|
|
|
Return input arguments' names of this op node.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list(str): input arguments' names of this op node.
|
|
|
|
|
"""
|
|
|
|
|
assert self.node.op() is not None, \
|
|
|
|
|
"The node operator description cannot be None."
|
|
|
|
|
return self.node.op().input_arg_names()
|
|
|
|
|
|
|
|
|
|
def output_arg_names(self):
|
|
|
|
|
"""
|
|
|
|
|
Return output arguments' names of this op node.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list(str): output arguments' names of this op node.
|
|
|
|
|
"""
|
|
|
|
|
assert self.node.op() is not None, \
|
|
|
|
|
"The node operator description cannot be None."
|
|
|
|
|
return self.node.op().output_arg_names()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def inputs(self):
|
|
|
|
|
"""
|
|
|
|
@ -2054,31 +2076,38 @@ class IrGraph(object):
|
|
|
|
|
"""
|
|
|
|
|
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
|
|
|
|
|
|
|
|
|
|
def var_node(self, name):
|
|
|
|
|
def _find_var_node(self, key):
|
|
|
|
|
"""
|
|
|
|
|
Get a variable node by name from the graph.
|
|
|
|
|
Get a variable node by the `key` from this graph. The key
|
|
|
|
|
can be a node name or a node id.
|
|
|
|
|
|
|
|
|
|
WARNS:
|
|
|
|
|
There are some nodes may have the same name. So, be
|
|
|
|
|
cautious about using this method when you find the
|
|
|
|
|
target var node by its name.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name(str): the name of the variable node.
|
|
|
|
|
key(str|int): The str type denotes that the target variable node's name.
|
|
|
|
|
And the int type denotes that the target variable node's id.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: The If input's type is not str, or this graph
|
|
|
|
|
doesn't have a variable with the giving name.
|
|
|
|
|
ValueError: If this graph doesn't have a variable with the giving name or id.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
IrVarNode: the variable node with the giving name.
|
|
|
|
|
IrVarNode: the variable node with the giving name or id.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(name, six.string_types):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"var require string as parameter, but get %s instead." %
|
|
|
|
|
(type(name)))
|
|
|
|
|
target_var_node = None
|
|
|
|
|
var_nodes = self.all_var_nodes()
|
|
|
|
|
for var_node in var_nodes:
|
|
|
|
|
if var_node.name() == name:
|
|
|
|
|
target_var_node = var_node
|
|
|
|
|
if isinstance(key, six.string_types):
|
|
|
|
|
for var_node in var_nodes:
|
|
|
|
|
if var_node.name() == key:
|
|
|
|
|
target_var_node = var_node
|
|
|
|
|
elif isinstance(key, int):
|
|
|
|
|
for var_node in var_nodes:
|
|
|
|
|
if var_node.id() == key:
|
|
|
|
|
target_var_node = var_node
|
|
|
|
|
if target_var_node is None:
|
|
|
|
|
raise ValueError("var_node %s not in this graph" % name)
|
|
|
|
|
raise ValueError("var_node %s not in this graph" % key)
|
|
|
|
|
return target_var_node
|
|
|
|
|
|
|
|
|
|
def create_persistable_node(self, name, var_type, shape, var_dtype):
|
|
|
|
|