|
|
|
@ -13,8 +13,81 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
from ....framework import Program
|
|
|
|
|
from ....framework import Block
|
|
|
|
|
from .... import core
|
|
|
|
|
|
|
|
|
|
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
|
|
|
|
|
__all__ = ['Graph', 'ImitationGraph', 'PyGraph']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyGraph(object):
|
|
|
|
|
"""
|
|
|
|
|
PyGraph uses core.Graph as the delegation to accomplish the manipulation.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, graph):
|
|
|
|
|
assert isinstance(
|
|
|
|
|
graph, core.Graph), 'graph must be the instance of core.Graph.'
|
|
|
|
|
self.graph = graph
|
|
|
|
|
|
|
|
|
|
def all_parameters(self):
|
|
|
|
|
params = []
|
|
|
|
|
for node in self.graph.nodes():
|
|
|
|
|
if node.is_var() and node.var().persistable():
|
|
|
|
|
params.append(node)
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
def all_vars(self):
|
|
|
|
|
return [node for node in self.graph.nodes() if node.is_var()]
|
|
|
|
|
|
|
|
|
|
def all_ops(self):
|
|
|
|
|
return [node for node in self.graph.nodes() if node.is_op()]
|
|
|
|
|
|
|
|
|
|
def create_param_node(self, name, var_type, shape, var_dtype):
|
|
|
|
|
var_desc = core.VarDesc(name)
|
|
|
|
|
var_desc.set_type(var_type)
|
|
|
|
|
var_desc.set_shape(shape)
|
|
|
|
|
var_desc.set_dtype(var_dtype)
|
|
|
|
|
var_desc.set_persistable(True)
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_var_node(self, name, var_type, shape, var_dtype):
|
|
|
|
|
var_desc = core.VarDesc(name)
|
|
|
|
|
var_desc.set_type(var_type)
|
|
|
|
|
var_desc.set_shape(shape)
|
|
|
|
|
var_desc.set_dtype(var_dtype)
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_var_node_from_desc(self, var_desc):
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_op_node(self, op_type, attrs, inputs, outputs):
|
|
|
|
|
op_desc = core.OpDesc()
|
|
|
|
|
op_desc.set_type(op_type)
|
|
|
|
|
for attr, value in attrs.iteritems():
|
|
|
|
|
self._update_desc_attr(op_desc, attr, value)
|
|
|
|
|
for input_name, var_node in inputs.iteritems():
|
|
|
|
|
op_desc.set_input(input_name, [var_node.name()])
|
|
|
|
|
for output_name, var_node in outputs.iteritems():
|
|
|
|
|
op_desc.set_output(output_name, [var_node.name()])
|
|
|
|
|
return self.graph.create_op_node(op_desc)
|
|
|
|
|
|
|
|
|
|
def create_op_node_from_desc(self, op_desc):
|
|
|
|
|
return self.graph.create_op_node(op_desc)
|
|
|
|
|
|
|
|
|
|
def _update_desc_attr(self, desc, name, val):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of desc's attribute by attribute's name.
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(val, Block):
|
|
|
|
|
desc.set_block_attr(name, val.desc)
|
|
|
|
|
elif isinstance(val, list) and val and all(
|
|
|
|
|
isinstance(v, Block) for v in val):
|
|
|
|
|
desc.set_blocks_attr(name, [v.desc for v in val])
|
|
|
|
|
elif isinstance(val, core.BlockDesc) or \
|
|
|
|
|
isinstance(val, core.ProgramDesc):
|
|
|
|
|
desc.set_serialized_attr(name, val.serialize_to_string())
|
|
|
|
|
else:
|
|
|
|
|
desc._set_attr(name, val)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Graph(object):
|
|
|
|
@ -39,7 +112,3 @@ class ImitationGraph(Graph):
|
|
|
|
|
|
|
|
|
|
def all_parameters(self):
|
|
|
|
|
return self.program.global_block().all_parameters()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IRGraph(Graph):
|
|
|
|
|
pass
|
|
|
|
|