|
|
|
@ -23,6 +23,7 @@ import traceback
|
|
|
|
|
import six
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import subprocess
|
|
|
|
|
|
|
|
|
|
from .. import compat as cpt
|
|
|
|
|
from .proto import framework_pb2
|
|
|
|
@ -381,27 +382,6 @@ class Variable(object):
|
|
|
|
|
self._ivar.desc = self.desc
|
|
|
|
|
self._ivar.stop_gradient = stop_gradient
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def construct_from_desc(block, desc):
|
|
|
|
|
"""
|
|
|
|
|
Construct a Variable from variable desc.
|
|
|
|
|
Args:
|
|
|
|
|
desc(core.VarDesc): The variable desc for constructing.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Variable: A variable.
|
|
|
|
|
"""
|
|
|
|
|
v = Variable(
|
|
|
|
|
block=block,
|
|
|
|
|
type=desc.type(),
|
|
|
|
|
name=desc.name(),
|
|
|
|
|
shape=desc.shape(),
|
|
|
|
|
dtype=desc.dtype(),
|
|
|
|
|
lod_level=desc.lod_level(),
|
|
|
|
|
persistable=desc.persistable())
|
|
|
|
|
v.desc = desc
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
def _numpy(self):
|
|
|
|
|
tensor = self._ivar.value().get_tensor()
|
|
|
|
|
return np.array(tensor)
|
|
|
|
@ -1533,6 +1513,154 @@ class Block(object):
|
|
|
|
|
return ret_var
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IrGraph(object):
|
|
|
|
|
"""
|
|
|
|
|
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, graph, for_test=False):
|
|
|
|
|
"""
|
|
|
|
|
Construct the IrGraph using core.Graph.
|
|
|
|
|
Args:
|
|
|
|
|
graph(core.Graph): C++ Graph.
|
|
|
|
|
for_test(bool): True for the test graph and false for the train graph.
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(
|
|
|
|
|
graph, core.Graph), 'graph must be the instance of core.Graph.'
|
|
|
|
|
self.graph = graph
|
|
|
|
|
self._for_test = for_test
|
|
|
|
|
|
|
|
|
|
def is_test(self):
|
|
|
|
|
return self._for_test
|
|
|
|
|
|
|
|
|
|
def all_parameters(self):
|
|
|
|
|
param_nodes = set()
|
|
|
|
|
for node in self.graph.nodes():
|
|
|
|
|
if node.is_var() and node.var() is not None and node.var(
|
|
|
|
|
).persistable():
|
|
|
|
|
param_nodes.add(node)
|
|
|
|
|
return param_nodes
|
|
|
|
|
|
|
|
|
|
def all_vars(self):
|
|
|
|
|
return {node for node in self.graph.nodes() if node.is_var()}
|
|
|
|
|
|
|
|
|
|
def all_ops(self):
|
|
|
|
|
return {node for node in self.graph.nodes() if node.is_op()}
|
|
|
|
|
|
|
|
|
|
def create_param_node(self, name, var_type, shape, var_dtype):
|
|
|
|
|
var_desc = core.VarDesc(name)
|
|
|
|
|
var_desc.set_type(var_type)
|
|
|
|
|
var_desc.set_shape(shape)
|
|
|
|
|
var_desc.set_dtype(var_dtype)
|
|
|
|
|
var_desc.set_persistable(True)
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_var_node(self, name, var_type, shape, var_dtype):
|
|
|
|
|
var_desc = core.VarDesc(name)
|
|
|
|
|
var_desc.set_type(var_type)
|
|
|
|
|
var_desc.set_shape(shape)
|
|
|
|
|
var_desc.set_dtype(var_dtype)
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_var_node_from_desc(self, var_desc):
|
|
|
|
|
return self.graph.create_var_node(var_desc)
|
|
|
|
|
|
|
|
|
|
def create_op_node(self, op_type, attrs, inputs, outputs):
|
|
|
|
|
op_desc = core.OpDesc()
|
|
|
|
|
op_desc.set_type(op_type)
|
|
|
|
|
for attr, value in attrs.iteritems():
|
|
|
|
|
self._update_desc_attr(op_desc, attr, value)
|
|
|
|
|
for input_name, var_nodes in inputs.iteritems():
|
|
|
|
|
if not isinstance(var_nodes, list):
|
|
|
|
|
var_nodes = [var_nodes]
|
|
|
|
|
op_desc.set_input(input_name,
|
|
|
|
|
[var_node.name() for var_node in var_nodes])
|
|
|
|
|
for output_name, var_nodes in outputs.iteritems():
|
|
|
|
|
if not isinstance(var_nodes, list):
|
|
|
|
|
var_nodes = [var_nodes]
|
|
|
|
|
op_desc.set_output(output_name,
|
|
|
|
|
[var_node.name() for var_node in var_nodes])
|
|
|
|
|
return self.graph.create_op_node(op_desc)
|
|
|
|
|
|
|
|
|
|
def create_op_node_from_desc(self, op_desc):
|
|
|
|
|
return self.graph.create_op_node(op_desc)
|
|
|
|
|
|
|
|
|
|
def update_input_link(self, old_input_node, new_input_node, op_node):
|
|
|
|
|
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
|
|
|
|
|
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
|
|
|
|
|
old_input_node.outputs_remove(op_node)
|
|
|
|
|
op_node.inputs_remove(old_input_node)
|
|
|
|
|
new_input_node.outputs_append(op_node)
|
|
|
|
|
op_node.inputs_append(new_input_node)
|
|
|
|
|
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
|
|
|
|
|
|
|
|
|
|
def link_to(self, node_in, node_out):
|
|
|
|
|
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
|
|
|
|
|
'Th two arguments must be in the graph nodes.'
|
|
|
|
|
node_in.outputs_append(node_out)
|
|
|
|
|
node_out.inputs_append(node_in)
|
|
|
|
|
|
|
|
|
|
def safe_remove_nodes(self, remove_nodes):
|
|
|
|
|
if not isinstance(remove_nodes, set):
|
|
|
|
|
remove_nodes = set(remove_nodes)
|
|
|
|
|
core.graph_safe_remove_nodes(self.graph, remove_nodes)
|
|
|
|
|
|
|
|
|
|
def draw(self, save_path, name, marked_nodes=None):
|
|
|
|
|
def _convert_to_pdf(dot_file_path):
|
|
|
|
|
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
|
|
|
|
|
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
|
|
|
|
|
+ ' -o ' + pdf_save_path, shell=True)
|
|
|
|
|
if exited_code != 0:
|
|
|
|
|
print('The dot command is needed for creating pdf files.')
|
|
|
|
|
print('The {} is saved as the dot filetype.'.format(
|
|
|
|
|
dot_file_path))
|
|
|
|
|
|
|
|
|
|
remove_ctr_vars = set()
|
|
|
|
|
ops_num = 0
|
|
|
|
|
for node in self.graph.nodes():
|
|
|
|
|
if node.is_ctrl_var():
|
|
|
|
|
remove_ctr_vars.add(node)
|
|
|
|
|
elif node.is_op():
|
|
|
|
|
ops_num += 1
|
|
|
|
|
print('Total ops num = {}.'.format(ops_num))
|
|
|
|
|
self.safe_remove_nodes(remove_ctr_vars)
|
|
|
|
|
if marked_nodes is not None:
|
|
|
|
|
if not isinstance(marked_nodes, set):
|
|
|
|
|
marked_nodes = set(marked_nodes)
|
|
|
|
|
marked_nodes = marked_nodes - remove_ctr_vars
|
|
|
|
|
if self.graph.has('__graphviz__marked_node__'):
|
|
|
|
|
self.graph.erase('__graphviz__marked_node__')
|
|
|
|
|
self.graph.set('__graphviz__marked_node__', marked_nodes)
|
|
|
|
|
viz_dot_path = os.path.join(save_path, name) + '.dot'
|
|
|
|
|
viz_pass = core.get_pass('graph_viz_pass')
|
|
|
|
|
viz_pass.set('graph_viz_path', viz_dot_path)
|
|
|
|
|
viz_pass.apply(self.graph)
|
|
|
|
|
_convert_to_pdf(viz_dot_path)
|
|
|
|
|
|
|
|
|
|
def to_program(self):
|
|
|
|
|
convert_pass = core.get_pass('graph_to_program_pass')
|
|
|
|
|
convert_pass.set('program', Program().desc)
|
|
|
|
|
convert_pass.apply(self.graph)
|
|
|
|
|
desc = convert_pass.get_program('program')
|
|
|
|
|
program = Program._construct_from_desc(desc)
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
def _update_desc_attr(self, desc, name, val):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of desc's attribute by attribute's name.
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(val, Block):
|
|
|
|
|
desc.set_block_attr(name, val.desc)
|
|
|
|
|
elif isinstance(val, list) and val and all(
|
|
|
|
|
isinstance(v, Block) for v in val):
|
|
|
|
|
desc.set_blocks_attr(name, [v.desc for v in val])
|
|
|
|
|
elif isinstance(val, core.BlockDesc) or \
|
|
|
|
|
isinstance(val, core.ProgramDesc):
|
|
|
|
|
desc.set_serialized_attr(name, val.serialize_to_string())
|
|
|
|
|
else:
|
|
|
|
|
desc._set_attr(name, val)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Program(object):
|
|
|
|
|
"""
|
|
|
|
|
Python Program. Beneath it is a ProgramDesc, which is used for
|
|
|
|
@ -1958,12 +2086,10 @@ class Program(object):
|
|
|
|
|
return p
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def construct_from_desc(desc):
|
|
|
|
|
def _construct_from_desc(desc):
|
|
|
|
|
"""
|
|
|
|
|
Construct a program from program desc.
|
|
|
|
|
|
|
|
|
|
Notes: All information about parameters will be lost.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
desc(core.ProgramDesc): The program desc for constructing.
|
|
|
|
|
|
|
|
|
|