parent
2bd92754e8
commit
8894c67d71
@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from graphviz import GraphPreviewGenerator
|
||||
import proto.framework_pb2 as framework_pb2
|
||||
|
||||
|
||||
def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
|
||||
'''
|
||||
Generate a debug graph for block.
|
||||
Args:
|
||||
block(Block): a block.
|
||||
'''
|
||||
graph = GraphPreviewGenerator("some graph")
|
||||
# collect parameters and args
|
||||
protostr = block.desc.serialize_to_string()
|
||||
desc = framework_pb2.BlockDesc.FromString(str(protostr))
|
||||
|
||||
def need_highlight(name):
|
||||
if highlights is None: return False
|
||||
for pattern in highlights:
|
||||
assert type(pattern) is str
|
||||
if re.match(pattern, name):
|
||||
return True
|
||||
return False
|
||||
|
||||
# draw parameters and args
|
||||
vars = {}
|
||||
for var in desc.vars:
|
||||
shape = [str(i) for i in var.lod_tensor.tensor.dims]
|
||||
if not shape:
|
||||
shape = ['null']
|
||||
# create var
|
||||
if var.persistable:
|
||||
varn = graph.add_param(
|
||||
var.name, var.type, shape, highlight=need_highlight(var.name))
|
||||
else:
|
||||
varn = graph.add_arg(var.name, highlight=need_highlight(var.name))
|
||||
vars[var.name] = varn
|
||||
|
||||
def add_op_link_var(op, var, op2var=False):
|
||||
for arg in var.arguments:
|
||||
if arg not in vars:
|
||||
# add missing variables as argument
|
||||
vars[arg] = graph.add_arg(arg, highlight=need_highlight(arg))
|
||||
varn = vars[arg]
|
||||
highlight = need_highlight(op.description) or need_highlight(
|
||||
varn.description)
|
||||
if op2var:
|
||||
graph.add_edge(op, varn, highlight=highlight)
|
||||
else:
|
||||
graph.add_edge(varn, op, highlight=highlight)
|
||||
|
||||
for op in desc.ops:
|
||||
opn = graph.add_op(op.type, highlight=need_highlight(op.type))
|
||||
for var in op.inputs:
|
||||
add_op_link_var(opn, var, False)
|
||||
for var in op.outputs:
|
||||
add_op_link_var(opn, var, True)
|
||||
|
||||
graph(path, show=True)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue