/ src / analysis / utils / plot_utils.py
plot_utils.py
 1  # Python Imports
 2  from matplotlib import patheffects as path_effects
 3  from matplotlib import pyplot as plt
 4  from result import Err, Ok, Result
 5  
 6  # Project Imports
 7  
 8  
 9  def add_boxplot_stat_labels(
10      ax: plt.Axes, fmt: str = ".3f", value_type: str = "median", scale_by: float = 1.0
11  ) -> Result[None, str]:
12      # Refactor from https://stackoverflow.com/a/63295846
13      """
14      Add text labels to the median, minimum, or maximum lines of a seaborn boxplot.
15  
16      Args:
17          ax: plt.Axes, e.g., the return value of sns.boxplot()
18          fmt: Format string for the value (e.g., min/max/median).
19          value_type: The type of value to label. Can be 'median', 'min', or 'max'.
20          scale_by: Scales the written value of the value type by 1 / this factor.
21      """
22      lines = ax.get_lines()
23      boxes = [c for c in ax.get_children() if "Patch" in str(c)]  # Get box patches
24      start = 4
25      if not boxes:  # seaborn v0.13 or above (no patches => need to shift index)
26          boxes = [c for c in ax.get_lines() if len(c.get_xdata()) == 5]
27          start += 1
28      lines_per_box = len(lines) // len(boxes)
29  
30      if value_type == "median":
31          line_idx = start
32      elif value_type == "min":
33          line_idx = start - 2  # min line comes 2 positions before the median
34      elif value_type == "max":
35          line_idx = start - 1  # max line comes 1 position before the median
36      else:
37          return Err("Invalid value_type. Must be 'min', 'max', or 'median'.")
38  
39      for value_line in lines[line_idx::lines_per_box]:
40          x, y = (data.mean() for data in value_line.get_data())
41          # choose value depending on horizontal or vertical plot orientation
42          value = x if len(set(value_line.get_xdata())) == 1 else y
43          text = ax.text(
44              x,
45              y,
46              f"{value/scale_by:{fmt}}",
47              ha="center",
48              va="center",
49              fontweight="bold",
50              color="white",
51              size=10,
52          )
53          # create colored border around white text for contrast
54          text.set_path_effects(
55              [
56                  path_effects.Stroke(linewidth=3, foreground=value_line.get_color()),
57                  path_effects.Normal(),
58              ]
59          )
60  
61      return Ok(None)