"""vizgraph.py: draw a Sirocco workflow as an interactive Graphviz SVG."""

from __future__ import annotations

from colorsys import hsv_to_rgb
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar

from lxml import etree
from pygraphviz import AGraph

# `core` is used at runtime (isinstance check and Workflow construction below),
# so it must be imported outside of the TYPE_CHECKING guard.
from sirocco import core

if TYPE_CHECKING:
    from sirocco.core.graph_items import Store


def hsv_to_hex(h: float, s: float, v: float) -> str:
    """Convert an HSV colour to a ``#rrggbb`` hexadecimal string.

    Args:
        h: hue, as expected by ``colorsys.hsv_to_rgb``.
        s: saturation, as expected by ``colorsys.hsv_to_rgb``.
        v: value, as expected by ``colorsys.hsv_to_rgb``.

    Returns:
        The colour as a lowercase ``#rrggbb`` string, each channel being the
        RGB component scaled by 255 and rounded.
    """
    r, g, b = hsv_to_rgb(h, s, v)
    return "#{:02x}{:02x}{:02x}".format(*map(round, (255 * r, 255 * g, 255 * b)))


def node_colors(h: float) -> dict[str, str]:
    """Build the Graphviz colour attributes of a node from a single hue.

    The same hue is used for a pale fill colour, a dark border colour and a
    dark font colour.

    Args:
        h: hue; it is divided by 365 before being handed to ``hsv_to_hex``.

    Returns:
        A dict with the keys ``fillcolor``, ``color`` and ``fontcolor``.
    """
    # NOTE(review): hue angles are conventionally expressed in 0-360 degrees.
    # The divisor 365 is kept unchanged because modifying it would alter every
    # rendered colour -- confirm whether 360 was intended.
    hue = h / 365
    return {
        "fillcolor": hsv_to_hex(hue, 0.15, 1),
        "color": hsv_to_hex(hue, 1, 0.20),
        "fontcolor": hsv_to_hex(hue, 1, 0.15),
    }


class VizGraph:
    """Class for visualizing a Sirocco workflow"""

    # Graphviz attributes shared by all nodes / all edges.
    node_base_kw: ClassVar[dict[str, Any]] = {"style": "filled", "fontname": "Fira Sans", "fontsize": 14, "penwidth": 2}
    edge_base_kw: ClassVar[dict[str, Any]] = {"color": "#77767B", "penwidth": 1.5}
    data_node_base_kw: ClassVar[dict[str, Any]] = node_base_kw | {"shape": "ellipse"}

    # Attributes per node kind: data nodes that are instances of
    # core.AvailableData, all other data nodes, and task nodes.
    data_av_node_kw: ClassVar[dict[str, Any]] = data_node_base_kw | node_colors(116)
    data_gen_node_kw: ClassVar[dict[str, Any]] = data_node_base_kw | node_colors(214)
    task_node_kw: ClassVar[dict[str, Any]] = node_base_kw | {"shape": "box"} | node_colors(354)
    # Attributes per edge kind: input/output edges and "wait on" edges.
    io_edge_kw: ClassVar[dict[str, Any]] = edge_base_kw
    wait_on_edge_kw: ClassVar[dict[str, Any]] = edge_base_kw | {"style": "dashed"}
    # Attributes of the subgraph (cluster) created for each cycle.
    cluster_kw: ClassVar[dict[str, Any]] = {"bgcolor": "#F6F5F4", "color": None, "fontsize": 16}

    def __init__(self, name: str, cycles: Store, data: Store) -> None:
        """Build the Graphviz graph of a workflow.

        Args:
            name: name of the graph, also used for the default output file name
                in ``draw``.
            cycles: iterable of cycles; each one exposes ``name``, ``tasks`` and
                whatever ``tooltip`` reads.
            data: iterable of data nodes.
        """
        self.name = name
        self.agraph = AGraph(name=name, fontname="Fira Sans", newrank=True)

        for data_node in data:
            gv_kw = self.data_av_node_kw if isinstance(data_node, core.AvailableData) else self.data_gen_node_kw
            self.agraph.add_node(data_node, tooltip=self.tooltip(data_node), label=data_node.name, **gv_kw)

        # NOTE: For some reason, clusters need to have a unique name that starts with 'cluster'
        #       otherwise they are not taken into account. Hence the running index.
        for index, cycle in enumerate(cycles, start=1):
            cluster_nodes = []
            for task_node in cycle.tasks:
                cluster_nodes.append(task_node)
                self.agraph.add_node(
                    task_node, label=task_node.name, tooltip=self.tooltip(task_node), **self.task_node_kw
                )
                for data_node in task_node.input_data_nodes():
                    self.agraph.add_edge(data_node, task_node, **self.io_edge_kw)
                for data_node in task_node.output_data_nodes():
                    self.agraph.add_edge(task_node, data_node, **self.io_edge_kw)
                    # Only output data is placed in the cluster of the cycle.
                    cluster_nodes.append(data_node)
                for wait_task_node in task_node.wait_on:
                    self.agraph.add_edge(wait_task_node, task_node, **self.wait_on_edge_kw)
            cycle_text = self.tooltip(cycle)
            self.agraph.add_subgraph(
                cluster_nodes,
                name=f"cluster_{cycle.name}_{index}",
                clusterrank="global",
                label=cycle_text,
                tooltip=cycle_text,
                **self.cluster_kw,
            )

    @staticmethod
    def tooltip(node) -> str:
        """Return the multi-line description of a node.

        Args:
            node: object exposing ``name`` and a ``coordinates`` mapping.

        Returns:
            The node name followed by one ``key: value`` line per coordinate.
        """
        return "\n".join(chain([node.name], (f" {key}: {value}" for key, value in node.coordinates.items())))

    def draw(self, file_path: Path | None = None, **kwargs: Any) -> None:
        """Lay out the graph with ``dot`` and write it as an interactive SVG file.

        Args:
            file_path: destination of the SVG file; defaults to
                ``./<graph name>.svg``.
            **kwargs: forwarded to ``AGraph.draw``.
        """
        # draw graphviz dot graph to svg file
        self.agraph.layout(prog="dot")
        if file_path is None:
            file_path = Path(f"./{self.name}.svg")

        self.agraph.draw(path=file_path, format="svg", **kwargs)

        # Add interactive capabilities to the svg graph thanks to
        # https://github.com/BartBrood/dynamic-SVG-from-Graphviz

        # Parse svg
        svg = etree.parse(file_path)  # noqa: S320 this svg is safe as generated internally
        svg_root = svg.getroot()
        # Add 'onload' attribute
        svg_root.set("onload", "addInteractivity(evt)")

        # The style sheet and script live next to this module. They are read
        # explicitly as UTF-8 so that the result does not depend on the locale
        # encoding of the platform.
        # NOTE(review): assumes both asset files are UTF-8 encoded -- confirm.
        this_dir = Path(__file__).parent
        # Add css style for interactivity
        style_node = etree.Element("style")
        style_node.text = (this_dir / "svg-interactive-style.css").read_text(encoding="utf-8")
        svg_root.append(style_node)
        # Add scripts
        script_node = etree.Element("script")
        script_node.text = etree.CDATA((this_dir / "svg-interactive-script.js").read_text(encoding="utf-8"))
        svg_root.append(script_node)

        # write svg again
        svg.write(file_path)

    @classmethod
    def from_core_workflow(cls, workflow: core.Workflow) -> VizGraph:
        """Create a ``VizGraph`` from the name, cycles and data of a core workflow."""
        return cls(workflow.name, workflow.cycles, workflow.data)

    @classmethod
    def from_config_file(cls, config_path: str) -> VizGraph:
        """Create a ``VizGraph`` from the workflow built by ``core.Workflow.from_config_file``."""
        return cls.from_core_workflow(core.Workflow.from_config_file(config_path))