src/sirocco/vizgraph.py
vizgraph.py
  1  from __future__ import annotations
  2  
  3  from colorsys import hsv_to_rgb
  4  from itertools import chain
  5  from pathlib import Path
  6  from typing import TYPE_CHECKING, Any, ClassVar
  7  
  8  from lxml import etree
  9  from pygraphviz import AGraph
 10  
 11  if TYPE_CHECKING:
 12      from sirocco.core.graph_items import Store
 13  from sirocco import core
 14  
 15  
 16  def hsv_to_hex(h: float, s: float, v: float) -> str:
 17      r, g, b = hsv_to_rgb(h, s, v)
 18      return "#{:02x}{:02x}{:02x}".format(*map(round, (255 * r, 255 * g, 255 * b)))
 19  
 20  
 21  def node_colors(h: float) -> dict[str, str]:
 22      fill = hsv_to_hex(h / 365, 0.15, 1)
 23      border = hsv_to_hex(h / 365, 1, 0.20)
 24      font = hsv_to_hex(h / 365, 1, 0.15)
 25      return {"fillcolor": fill, "color": border, "fontcolor": font}
 26  
 27  
 28  class VizGraph:
 29      """Class for visualizing a Sirocco workflow"""
 30  
 31      node_base_kw: ClassVar[dict[str, Any]] = {"style": "filled", "fontname": "Fira Sans", "fontsize": 14, "penwidth": 2}
 32      edge_base_kw: ClassVar[dict[str, Any]] = {"color": "#77767B", "penwidth": 1.5}
 33      data_node_base_kw: ClassVar[dict[str, Any]] = node_base_kw | {"shape": "ellipse"}
 34  
 35      data_av_node_kw: ClassVar[dict[str, Any]] = data_node_base_kw | node_colors(116)
 36      data_gen_node_kw: ClassVar[dict[str, Any]] = data_node_base_kw | node_colors(214)
 37      task_node_kw: ClassVar[dict[str, Any]] = node_base_kw | {"shape": "box"} | node_colors(354)
 38      io_edge_kw: ClassVar[dict[str, Any]] = edge_base_kw
 39      wait_on_edge_kw: ClassVar[dict[str, Any]] = edge_base_kw | {"style": "dashed"}
 40      cluster_kw: ClassVar[dict[str, Any]] = {"bgcolor": "#F6F5F4", "color": None, "fontsize": 16}
 41  
 42      def __init__(self, name: str, cycles: Store, data: Store) -> None:
 43          self.name = name
 44          self.agraph = AGraph(name=name, fontname="Fira Sans", newrank=True)
 45          for data_node in data:
 46              gv_kw = self.data_av_node_kw if isinstance(data_node, core.AvailableData) else self.data_gen_node_kw
 47              self.agraph.add_node(data_node, tooltip=self.tooltip(data_node), label=data_node.name, **gv_kw)
 48  
 49          k = 1
 50          for cycle in cycles:
 51              # NOTE: For some reason, clusters need to have a unique name that starts with 'cluster'
 52              #       otherwise they are not taken into account. Hence the k index.
 53              cluster_nodes = []
 54              for task_node in cycle.tasks:
 55                  cluster_nodes.append(task_node)
 56                  self.agraph.add_node(
 57                      task_node, label=task_node.name, tooltip=self.tooltip(task_node), **self.task_node_kw
 58                  )
 59                  for data_node in task_node.input_data_nodes():
 60                      self.agraph.add_edge(data_node, task_node, **self.io_edge_kw)
 61                  for data_node in task_node.output_data_nodes():
 62                      self.agraph.add_edge(task_node, data_node, **self.io_edge_kw)
 63                      cluster_nodes.append(data_node)
 64                  for wait_task_node in task_node.wait_on:
 65                      self.agraph.add_edge(wait_task_node, task_node, **self.wait_on_edge_kw)
 66              self.agraph.add_subgraph(
 67                  cluster_nodes,
 68                  name=f"cluster_{cycle.name}_{k}",
 69                  clusterrank="global",
 70                  label=self.tooltip(cycle),
 71                  tooltip=self.tooltip(cycle),
 72                  **self.cluster_kw,
 73              )
 74              k += 1
 75  
 76      @staticmethod
 77      def tooltip(node) -> str:
 78          return "\n".join(chain([node.name], (f"  {k}: {v}" for k, v in node.coordinates.items())))
 79  
 80      def draw(self, file_path: Path | None = None, **kwargs):
 81          # draw graphviz dot graph to svg file
 82          self.agraph.layout(prog="dot")
 83          if file_path is None:
 84              file_path = Path(f"./{self.name}.svg")
 85  
 86          self.agraph.draw(path=file_path, format="svg", **kwargs)
 87  
 88          # Add interactive capabilities to the svg graph thanks to
 89          # https://github.com/BartBrood/dynamic-SVG-from-Graphviz
 90  
 91          # Parse svg
 92          svg = etree.parse(file_path)  # noqa: S320 this svg is safe as generated internaly
 93          svg_root = svg.getroot()
 94          # Add 'onload' tag
 95          svg_root.set("onload", "addInteractivity(evt)")
 96          # Add css style for interactivity
 97          this_dir = Path(__file__).parent
 98          style_file_path = this_dir / "svg-interactive-style.css"
 99          node = etree.Element("style")
100          node.text = style_file_path.read_text()
101          svg_root.append(node)
102          # Add scripts
103          js_file_path = this_dir / "svg-interactive-script.js"
104          node = etree.Element("script")
105          node.text = etree.CDATA(js_file_path.read_text())
106          svg_root.append(node)
107  
108          # write svg again
109          svg.write(file_path)
110  
111      @classmethod
112      def from_core_workflow(cls, workflow: core.Workflow):
113          return cls(workflow.name, workflow.cycles, workflow.data)
114  
115      @classmethod
116      def from_config_file(cls, config_path: str):
117          return cls.from_core_workflow(core.Workflow.from_config_file(config_path))