|
|
|
@ -12,10 +12,202 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
import re
|
|
|
|
|
from graphviz import GraphPreviewGenerator
|
|
|
|
|
import proto.framework_pb2 as framework_pb2
|
|
|
|
|
|
|
|
|
|
_vartype2str_ = [
|
|
|
|
|
"UNK",
|
|
|
|
|
"LoDTensor",
|
|
|
|
|
"SelectedRows",
|
|
|
|
|
"FeedMinibatch",
|
|
|
|
|
"FetchList",
|
|
|
|
|
"StepScopes",
|
|
|
|
|
"LodRankTable",
|
|
|
|
|
"LoDTensorArray",
|
|
|
|
|
"PlaceList",
|
|
|
|
|
]
|
|
|
|
|
_dtype2str_ = [
|
|
|
|
|
"bool",
|
|
|
|
|
"int16",
|
|
|
|
|
"int32",
|
|
|
|
|
"int64",
|
|
|
|
|
"float16",
|
|
|
|
|
"float32",
|
|
|
|
|
"float64",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_data_type(type):
|
|
|
|
|
return _dtype2str_[type]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_tensor(proto):
|
|
|
|
|
return "tensor(type={}, shape={})".format(_dtype2str_[int(proto.data_type)],
|
|
|
|
|
str(proto.dims))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reprtpl = "{ttype} {name} ({reprs})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_lodtensor(proto):
|
|
|
|
|
if not proto.lod_tensor: return
|
|
|
|
|
level = proto.lod_tensor.lod_level
|
|
|
|
|
reprs = repr_tensor(proto.lod_tensor.tensor)
|
|
|
|
|
return reprtpl.format(
|
|
|
|
|
ttype="LoDTensor" if level > 0 else "Tensor",
|
|
|
|
|
name=proto.name,
|
|
|
|
|
reprs="level=%d, %s" % (level, reprs) if level > 0 else reprs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_selected_rows(proto):
|
|
|
|
|
if not proto.selected_rows: return
|
|
|
|
|
return reprtpl.format(
|
|
|
|
|
ttype="SelectedRows",
|
|
|
|
|
name=proto.name,
|
|
|
|
|
reprs=repr_tensor(proto.selected_rows))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_tensor_array(proto):
|
|
|
|
|
if not proto.tensor_array: return
|
|
|
|
|
return reprtpl.format(
|
|
|
|
|
ttype="TensorArray",
|
|
|
|
|
name=proto.name,
|
|
|
|
|
reprs="level=%d, %s" % (proto.tensor_array.lod_level,
|
|
|
|
|
repr_tensor(proto.lod_tensor)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type_handlers = [
|
|
|
|
|
repr_lodtensor,
|
|
|
|
|
repr_selected_rows,
|
|
|
|
|
repr_tensor_array,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_var(vardesc):
|
|
|
|
|
for handler in type_handlers:
|
|
|
|
|
res = handler(vardesc)
|
|
|
|
|
if res:
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pprint_program_codes(program_desc):
|
|
|
|
|
reprs = []
|
|
|
|
|
for block_idx in range(program_desc.num_blocks()):
|
|
|
|
|
block_desc = program_desc.block(block_idx)
|
|
|
|
|
block_repr = pprint_block_codes(block_desc)
|
|
|
|
|
reprs.append(block_repr)
|
|
|
|
|
return '\n'.join(reprs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pprint_block_codes(block_desc, show_backward=False):
|
|
|
|
|
def is_op_backward(op_desc):
|
|
|
|
|
if op_desc.type.endswith('_grad'): return True
|
|
|
|
|
|
|
|
|
|
def is_var_backward(var):
|
|
|
|
|
if "@GRAD" in var.parameter: return True
|
|
|
|
|
for arg in var.arguments:
|
|
|
|
|
if "@GRAD" in arg: return True
|
|
|
|
|
|
|
|
|
|
for var in op_desc.inputs:
|
|
|
|
|
if is_var_backward(var): return True
|
|
|
|
|
for var in op_desc.outputs:
|
|
|
|
|
if is_var_backward(var): return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def is_var_backward(var_desc):
|
|
|
|
|
return "@GRAD" in var_desc.name
|
|
|
|
|
|
|
|
|
|
if type(block_desc) is not framework_pb2.BlockDesc:
|
|
|
|
|
block_desc = framework_pb2.BlockDesc.FromString(
|
|
|
|
|
block_desc.serialize_to_string())
|
|
|
|
|
var_reprs = []
|
|
|
|
|
op_reprs = []
|
|
|
|
|
for var in block_desc.vars:
|
|
|
|
|
if not show_backward and is_var_backward(var):
|
|
|
|
|
continue
|
|
|
|
|
var_reprs.append(repr_var(var))
|
|
|
|
|
|
|
|
|
|
for op in block_desc.ops:
|
|
|
|
|
if not show_backward and is_op_backward(op): continue
|
|
|
|
|
op_reprs.append(repr_op(op))
|
|
|
|
|
|
|
|
|
|
tpl = "// block-{idx} parent-{pidx}\n// variables\n{vars}\n\n// operators\n{ops}\n"
|
|
|
|
|
return tpl.format(
|
|
|
|
|
idx=block_desc.idx,
|
|
|
|
|
pidx=block_desc.parent_idx,
|
|
|
|
|
vars='\n'.join(var_reprs),
|
|
|
|
|
ops='\n'.join(op_reprs), )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_attr(desc):
|
|
|
|
|
tpl = "{key}={value}"
|
|
|
|
|
valgetter = [
|
|
|
|
|
lambda attr: attr.i,
|
|
|
|
|
lambda attr: attr.f,
|
|
|
|
|
lambda attr: attr.s,
|
|
|
|
|
lambda attr: attr.ints,
|
|
|
|
|
lambda attr: attr.floats,
|
|
|
|
|
lambda attr: attr.strings,
|
|
|
|
|
lambda attr: attr.b,
|
|
|
|
|
lambda attr: attr.bools,
|
|
|
|
|
lambda attr: attr.block_idx,
|
|
|
|
|
lambda attr: attr.l,
|
|
|
|
|
]
|
|
|
|
|
key = desc.name
|
|
|
|
|
value = valgetter[desc.type](desc)
|
|
|
|
|
if key == "dtype":
|
|
|
|
|
value = repr_data_type(value)
|
|
|
|
|
return tpl.format(key=key, value=str(value)), (key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _repr_op_fill_constant(optype, inputs, outputs, attrs):
|
|
|
|
|
if optype == "fill_constant":
|
|
|
|
|
return "{output} = {data} [shape={shape}]".format(
|
|
|
|
|
output=','.join(outputs),
|
|
|
|
|
data=attrs['value'],
|
|
|
|
|
shape=str(attrs['shape']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
op_repr_handlers = [_repr_op_fill_constant, ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repr_op(opdesc):
|
|
|
|
|
optype = None
|
|
|
|
|
attrs = []
|
|
|
|
|
attr_dict = {}
|
|
|
|
|
is_target = None
|
|
|
|
|
inputs = []
|
|
|
|
|
outputs = []
|
|
|
|
|
|
|
|
|
|
tpl = "{outputs} = {optype}({inputs}{is_target}) [{attrs}]"
|
|
|
|
|
args2value = lambda args: args[0] if len(args) == 1 else str(list(args))
|
|
|
|
|
for var in opdesc.inputs:
|
|
|
|
|
key = var.parameter
|
|
|
|
|
value = args2value(var.arguments)
|
|
|
|
|
inputs.append("%s=%s" % (key, value))
|
|
|
|
|
for var in opdesc.outputs:
|
|
|
|
|
value = args2value(var.arguments)
|
|
|
|
|
outputs.append(value)
|
|
|
|
|
for attr in opdesc.attrs:
|
|
|
|
|
attr_repr, attr_pair = repr_attr(attr)
|
|
|
|
|
attrs.append(attr_repr)
|
|
|
|
|
attr_dict[attr_pair[0]] = attr_pair[1]
|
|
|
|
|
|
|
|
|
|
is_target = opdesc.is_target
|
|
|
|
|
|
|
|
|
|
for handler in op_repr_handlers:
|
|
|
|
|
res = handler(opdesc.type, inputs, outputs, attr_dict)
|
|
|
|
|
if res: return res
|
|
|
|
|
|
|
|
|
|
return tpl.format(
|
|
|
|
|
outputs=', '.join(outputs),
|
|
|
|
|
optype=opdesc.type,
|
|
|
|
|
inputs=', '.join(inputs),
|
|
|
|
|
attrs="{%s}" % ','.join(attrs),
|
|
|
|
|
is_target=", is_target" if is_target else "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
|
|
|
|
|
'''
|
|
|
|
|