|
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from .graphviz import Digraph
|
|
|
|
|
from .graphviz import Graph
|
|
|
|
|
except ImportError:
|
|
|
|
|
logger.info(
|
|
|
|
|
'Cannot import graphviz, which is required for drawing a network. This '
|
|
|
|
@ -112,7 +112,7 @@ def draw_graph(startup_program, main_program, **kwargs):
|
|
|
|
|
filename = kwargs.get("filename")
|
|
|
|
|
if filename == None:
|
|
|
|
|
filename = str(graph_id) + ".gv"
|
|
|
|
|
g = Digraph(
|
|
|
|
|
g = Graph(
|
|
|
|
|
name=str(graph_id),
|
|
|
|
|
filename=filename,
|
|
|
|
|
graph_attr=GRAPH_STYLE,
|
|
|
|
|