parent
2bd92754e8
commit
8894c67d71
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
|
from graphviz import GraphPreviewGenerator
|
||||||
|
import proto.framework_pb2 as framework_pb2
|
||||||
|
|
||||||
|
|
||||||
|
def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
|
||||||
|
'''
|
||||||
|
Generate a debug graph for block.
|
||||||
|
Args:
|
||||||
|
block(Block): a block.
|
||||||
|
'''
|
||||||
|
graph = GraphPreviewGenerator("some graph")
|
||||||
|
# collect parameters and args
|
||||||
|
protostr = block.desc.serialize_to_string()
|
||||||
|
desc = framework_pb2.BlockDesc.FromString(str(protostr))
|
||||||
|
|
||||||
|
def need_highlight(name):
|
||||||
|
if highlights is None: return False
|
||||||
|
for pattern in highlights:
|
||||||
|
assert type(pattern) is str
|
||||||
|
if re.match(pattern, name):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# draw parameters and args
|
||||||
|
vars = {}
|
||||||
|
for var in desc.vars:
|
||||||
|
shape = [str(i) for i in var.lod_tensor.tensor.dims]
|
||||||
|
if not shape:
|
||||||
|
shape = ['null']
|
||||||
|
# create var
|
||||||
|
if var.persistable:
|
||||||
|
varn = graph.add_param(
|
||||||
|
var.name, var.type, shape, highlight=need_highlight(var.name))
|
||||||
|
else:
|
||||||
|
varn = graph.add_arg(var.name, highlight=need_highlight(var.name))
|
||||||
|
vars[var.name] = varn
|
||||||
|
|
||||||
|
def add_op_link_var(op, var, op2var=False):
|
||||||
|
for arg in var.arguments:
|
||||||
|
if arg not in vars:
|
||||||
|
# add missing variables as argument
|
||||||
|
vars[arg] = graph.add_arg(arg, highlight=need_highlight(arg))
|
||||||
|
varn = vars[arg]
|
||||||
|
highlight = need_highlight(op.description) or need_highlight(
|
||||||
|
varn.description)
|
||||||
|
if op2var:
|
||||||
|
graph.add_edge(op, varn, highlight=highlight)
|
||||||
|
else:
|
||||||
|
graph.add_edge(varn, op, highlight=highlight)
|
||||||
|
|
||||||
|
for op in desc.ops:
|
||||||
|
opn = graph.add_op(op.type, highlight=need_highlight(op.type))
|
||||||
|
for var in op.inputs:
|
||||||
|
add_op_link_var(opn, var, False)
|
||||||
|
for var in op.outputs:
|
||||||
|
add_op_link_var(opn, var, True)
|
||||||
|
|
||||||
|
graph(path, show=True)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue