You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/graphviz.py

273 lines
7.6 KiB

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import random
import six
import functools
import subprocess
import logging
def crepr(v):
if isinstance(v, six.string_types):
return '"%s"' % v
return str(v)
class Rank(object):
def __init__(self, kind, name, priority):
'''
kind: str
name: str
priority: int
'''
self.kind = kind
self.name = name
self.priority = priority
self.nodes = []
def __str__(self):
if not self.nodes:
return ''
return '{' + 'rank={};'.format(self.kind) + \
','.join([node.name for node in self.nodes]) + '}'
class Graph(object):
rank_counter = 0
def __init__(self, title, **attrs):
self.title = title
self.attrs = attrs
self.nodes = []
self.edges = []
self.rank_groups = {}
def code(self):
return self.__str__()
def rank_group(self, kind, priority):
name = "rankgroup-%d" % Graph.rank_counter
Graph.rank_counter += 1
rank = Rank(kind, name, priority)
self.rank_groups[name] = rank
return name
def node(self, label, prefix, description="", **attrs):
node = Node(label, prefix, description, **attrs)
if 'rank' in attrs:
rank = self.rank_groups[attrs['rank']]
del attrs['rank']
rank.nodes.append(node)
self.nodes.append(node)
return node
def edge(self, source, target, **attrs):
edge = Edge(source, target, **attrs)
self.edges.append(edge)
return edge
def compile(self, dot_path):
file = open(dot_path, 'w')
file.write(self.__str__())
image_path = os.path.join(
os.path.dirname(dot_path), dot_path[:-3] + "pdf")
cmd = ["dot", "-Tpdf", dot_path, "-o", image_path]
subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logging.warning("write block debug graph to {}".format(image_path))
return image_path
def show(self, dot_path):
image = self.compile(dot_path)
cmd = ["open", image]
subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
def _rank_repr(self):
ranks = sorted(
six.iteritems(self.rank_groups),
key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority))
repr = []
for x in ranks:
repr.append(str(x[1]))
return '\n'.join(repr) + '\n'
def __str__(self):
reprs = [
'digraph G {',
'title = {}'.format(crepr(self.title)),
]
for attr in self.attrs:
reprs.append("{key}={value};".format(
key=attr, value=crepr(self.attrs[attr])))
reprs.append(self._rank_repr())
random.shuffle(self.nodes)
reprs += [str(node) for node in self.nodes]
for x in self.edges:
reprs.append(str(x))
reprs.append('}')
return '\n'.join(reprs)
class Node(object):
counter = 1
def __init__(self, label, prefix, description="", **attrs):
self.label = label
self.name = "%s_%d" % (prefix, Node.counter)
self.description = description
self.attrs = attrs
Node.counter += 1
def __str__(self):
reprs = '{name} [label={label} {extra} ];'.format(
name=self.name,
label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value))
for key, value in six.iteritems(self.attrs))
if self.attrs else "")
return reprs
class Edge(object):
def __init__(self, source, target, **attrs):
'''
Link source to target.
:param source: Node
:param target: Node
:param graph: Graph
:param attrs: dic
'''
self.source = source
self.target = target
self.attrs = attrs
def __str__(self):
repr = "{source} -> {target} {extra}".format(
source=self.source.name,
target=self.target.name,
extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in six.iteritems(self.attrs)) + "]")
return repr
class GraphPreviewGenerator(object):
'''
Generate a graph image for ONNX proto.
'''
def __init__(self, title):
# init graphviz graph
self.graph = Graph(
title,
layout="dot",
concentrate="true",
rankdir="TB", )
self.op_rank = self.graph.rank_group('same', 2)
self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0)
def __call__(self, path='temp.dot', show=False):
if not show:
self.graph.compile(path)
else:
self.graph.show(path)
def add_param(self, name, data_type, highlight=False):
label = '\n'.join([
'<<table cellpadding="5">',
' <tr>',
' <td bgcolor="#2b787e">',
' <b>',
name,
' </b>',
' </td>',
' </tr>',
' <tr>',
' <td>',
str(data_type),
' </td>'
' </tr>',
'</table>>',
])
return self.graph.node(
label,
prefix="param",
description=name,
shape="none",
style="rounded,filled,bold",
width="1.3",
color="#148b97" if not highlight else "orange",
fontcolor="#ffffff",
fontname="Arial")
def add_op(self, opType, **kwargs):
highlight = False
if 'highlight' in kwargs:
highlight = kwargs['highlight']
del kwargs['highlight']
return self.graph.node(
"<<B>%s</B>>" % opType,
prefix="op",
description=opType,
shape="box",
style="rounded, filled, bold",
color="#303A3A" if not highlight else "orange",
fontname="Arial",
fontcolor="#ffffff",
width="1.3",
height="0.84", )
def add_arg(self, name, highlight=False):
return self.graph.node(
crepr(name),
prefix="arg",
description=name,
shape="box",
style="rounded,filled,bold",
fontname="Arial",
fontcolor="#999999",
color="#dddddd" if not highlight else "orange")
def add_edge(self, source, target, **kwargs):
highlight = False
if 'highlight' in kwargs:
highlight = kwargs['highlight']
del kwargs['highlight']
return self.graph.edge(
source,
target,
color="#00000" if not highlight else "orange",
**kwargs)