# pytlib/visualization/graph_visualizer.py
# Modified from https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
 2  
 3  from graphviz import Digraph
 4  import torch
 5  
 6  def compute_graph(var, params=None, output_file=None,view=False):
 7      """ Produces Graphviz representation of PyTorch autograd graph
 8  
 9      Blue nodes are the Variables that require grad, orange are Tensors
10      saved for backward in torch.autograd.Function
11  
12      Args:
13          var: output Variable
14          params: dict of (name, Variable) to add names to node that
15              require grad (TODO: make optional)
16      """
17      if params is not None:
18          assert isinstance(params.values()[0], Variable)
19          param_map = {id(v): k for k, v in params.items()}
20  
21      node_attr = dict(style='filled',
22                       shape='box',
23                       align='left',
24                       fontsize='12',
25                       ranksep='0.1',
26                       height='0.2')
27      dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"),format='svg')
28      seen = set()
29  
30      def size_to_str(size):
31          return '('+(', ').join(['%d' % v for v in size])+')'
32  
33      def add_nodes(var):
34          if var not in seen:
35              if torch.is_tensor(var):
36                  dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
37              elif hasattr(var, 'variable'):
38                  u = var.variable
39                  if u is not None:
40                      name = param_map[id(u)] if params is not None else ''
41                      node_name = '%s\n %s' % (name, size_to_str(u.size()))
42                      dot.node(str(id(var)), node_name, fillcolor='lightblue')
43              else:
44                  dot.node(str(id(var)), str(type(var).__name__))
45              seen.add(var)
46              if hasattr(var, 'next_functions'):
47                  for u in var.next_functions:
48                      if u[0] is not None:
49                          dot.edge(str(id(u[0])), str(id(var)))
50                          add_nodes(u[0])
51              if hasattr(var, 'saved_tensors'):
52                  for t in var.saved_tensors:
53                      dot.edge(str(id(t)), str(id(var)))
54                      add_nodes(t)
55      add_nodes(var.grad_fn)
56      if output_file:
57          dot.render(output_file,view=view)
58      return dot