# graph_visualizer.py
## modified from https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py

from graphviz import Digraph
import torch


def compute_graph(var, params=None, output_file=None, view=False):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are leaf tensors that require grad (reached through an
    autograd node exposing a ``.variable`` attribute); orange nodes are
    tensors saved for backward in a ``torch.autograd.Function`` (exposed via
    ``.saved_tensors``); every other node is labelled with the type name of
    the autograd node.

    Args:
        var: output tensor; the graph reachable from ``var.grad_fn`` is drawn.
            If ``var.grad_fn`` is None, the returned graph has no nodes.
        params: optional dict of ``name -> tensor``. Leaf tensors found in it
            are labelled with their name; leaves not in it get an empty name.
        output_file: if truthy, the graph is rendered to this path with
            ``Digraph.render`` (the Digraph is created with format 'svg').
        view: forwarded to ``Digraph.render``; only used with ``output_file``.

    Returns:
        The ``graphviz.Digraph`` describing the graph.

    Raises:
        TypeError: if a value in ``params`` is not a ``torch.Tensor``.
    """
    param_map = {}
    if params is not None:
        # Validate explicitly: dict views are not indexable in Python 3 and
        # `assert` is stripped under -O.
        for key, value in params.items():
            if not isinstance(value, torch.Tensor):
                raise TypeError('params[%r] must be a torch.Tensor, got %s'
                                % (key, type(value).__name__))
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"), format='svg')
    # Holds the visited objects themselves (not just their ids) so they stay
    # alive and their id()-based node names cannot be reused mid-traversal.
    seen = set()

    def size_to_str(size):
        """Format a tensor size as '(d0, d1, ...)'."""
        return '(' + ', '.join('%d' % v for v in size) + ')'

    def add_node(obj):
        """Emit the Graphviz node for one tensor or autograd node."""
        if torch.is_tensor(obj):
            dot.node(str(id(obj)), size_to_str(obj.size()), fillcolor='orange')
        elif hasattr(obj, 'variable'):
            u = obj.variable
            if u is not None:
                # Leaves missing from `params` get an empty name instead of
                # raising KeyError.
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(obj)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(obj)), str(type(obj).__name__))

    def dependencies(obj):
        """Return the autograd nodes and saved tensors `obj` points back to."""
        deps = []
        if hasattr(obj, 'next_functions'):
            deps.extend(u[0] for u in obj.next_functions if u[0] is not None)
        if hasattr(obj, 'saved_tensors'):
            deps.extend(obj.saved_tensors)
        return deps

    # Iterative DFS: deep graphs (e.g. long unrolled RNNs) would exceed
    # Python's recursion limit with a recursive walk.
    root = var.grad_fn
    stack = [root] if root is not None else []
    while stack:
        obj = stack.pop()
        if obj in seen:
            continue
        seen.add(obj)
        add_node(obj)
        deps = dependencies(obj)
        for dep in deps:
            # Edges are drawn even to already-visited nodes, one per reference.
            dot.edge(str(id(dep)), str(id(obj)))
        # Reverse so the first dependency is visited first (pre-order DFS).
        stack.extend(reversed(deps))

    if output_file:
        dot.render(output_file, view=view)
    return dot