"add net drawer for visualizing the graph" (#5292)
* "add net drawer for visualizing the graph" * "fix " * "add dep"fix-typo
parent
0a32e74d13
commit
e0c3a6683c
@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import paddle.v2.framework.core as core
|
||||
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logger.info(
|
||||
'Cannot import graphviz, which is required for drawing a network. This '
|
||||
'can usually be installed in python with "pip install graphviz". Also, '
|
||||
'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
|
||||
'can usually be installed with "sudo apt-get install graphviz".')
|
||||
print('net_drawer will not run correctly. Please install the correct '
|
||||
'dependencies.')
|
||||
exit(0)
|
||||
|
||||
OP_STYLE = {
|
||||
'shape': 'oval',
|
||||
'color': '#0F9D58',
|
||||
'style': 'filled',
|
||||
'fontcolor': '#FFFFFF'
|
||||
}
|
||||
|
||||
VAR_STYLE = {}
|
||||
|
||||
GRAPH_STYLE = {"rankdir": "TB", }
|
||||
|
||||
GRAPH_ID = 0
|
||||
|
||||
|
||||
def unique_id():
|
||||
def generator():
|
||||
GRAPH_ID += 1
|
||||
return GRAPH_ID
|
||||
|
||||
return generator
|
||||
|
||||
|
||||
def draw_node(op):
|
||||
node = OP_STYLE
|
||||
node["name"] = op.type
|
||||
node["label"] = op.type
|
||||
return node
|
||||
|
||||
|
||||
def draw_edge(var_parent, op, var, arg):
|
||||
edge = VAR_STYLE
|
||||
edge["label"] = "%s(%s)" % (var.parameter, arg)
|
||||
edge["head_name"] = op.type
|
||||
edge["tail_name"] = var_parent[arg]
|
||||
return edge
|
||||
|
||||
|
||||
def parse_graph(program, graph, var_dict, **kwargs):
|
||||
|
||||
# fill the known variables
|
||||
for block in program.blocks:
|
||||
for var in block.vars:
|
||||
if not var_dict.has_key(var):
|
||||
var_dict[var] = "Feed"
|
||||
|
||||
proto = framework_pb2.ProgramDesc.FromString(
|
||||
program.desc.serialize_to_string())
|
||||
for block in proto.blocks:
|
||||
for op in block.ops:
|
||||
graph.node(**draw_node(op))
|
||||
for o in op.outputs:
|
||||
for arg in o.arguments:
|
||||
var_dict[arg] = op.type
|
||||
for e in op.inputs:
|
||||
for arg in e.arguments:
|
||||
if var_dict.has_key(arg):
|
||||
graph.edge(**draw_edge(var_dict, op, e, arg))
|
||||
|
||||
|
||||
def draw_graph(init_program, program, **kwargs):
|
||||
if kwargs.has_key("graph_attr"):
|
||||
GRAPH_STYLE.update(kwargs[graph_attr])
|
||||
if kwargs.has_key("node_attr"):
|
||||
OP_STYLE.update(kwargs[node_attr])
|
||||
if kwargs.has_key("edge_attr"):
|
||||
VAR_STYLE.update(kwargs[edge_attr])
|
||||
|
||||
graph_id = unique_id()
|
||||
filename = kwargs.get("filename")
|
||||
if filename == None:
|
||||
filename = str(graph_id) + ".gv"
|
||||
g = Digraph(
|
||||
name=str(graph_id),
|
||||
filename=filename,
|
||||
graph_attr=GRAPH_STYLE,
|
||||
node_attr=OP_STYLE,
|
||||
edge_attr=VAR_STYLE,
|
||||
**kwargs)
|
||||
|
||||
var_dict = {}
|
||||
parse_graph(init_program, g, var_dict)
|
||||
parse_graph(program, g, var_dict)
|
||||
|
||||
if filename != None:
|
||||
g.save()
|
||||
return g
|
Loading…
Reference in new issue