/ benchmark / benchmarks_visualizer.py
benchmarks_visualizer.py
  1  import json
  2  import os
  3  from argparse import ArgumentParser
  4  from dataclasses import dataclass
  5  
  6  import matplotlib.pyplot as plt
  7  import pandas as pd
  8  import seaborn as sns
  9  
# Benchmark results CSV read by load_data(); a relative path, so it is resolved
# against the current working directory.
DATA_PATH = "data/all_benchmark_data.csv"
# Directory where plot_data() writes PNG plots (created on demand if missing).
VISUALIZATIONS_PATH = "visualizations/"
 12  
 13  
 14  @dataclass
 15  class VisualizationsConfig:
 16      """
 17      Configuration for the visualizations script.
 18  
 19      Args:
 20          kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`)
 21          metric_name (str): Metric name to visualize (speed/memory)
 22          kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
 23          display (bool): Display the visualization. Defaults to False
 24          overwrite (bool): Overwrite existing visualization, if none exist this flag has no effect as ones are always created and saved. Defaults to False
 25  
 26      """
 27  
 28      kernel_name: str
 29      metric_name: str
 30      kernel_operation_mode: str = "full"
 31      display: bool = False
 32      overwrite: bool = False
 33  
 34  
 35  def parse_args() -> VisualizationsConfig:
 36      """Parse command line arguments into a configuration object.
 37  
 38      Returns:
 39          VisualizationsConfig: Configuration object for the visualizations script.
 40      """
 41      parser = ArgumentParser()
 42      parser.add_argument(
 43          "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
 44      )
 45      parser.add_argument(
 46          "--metric-name",
 47          type=str,
 48          required=True,
 49          help="Metric name to visualize (speed/memory)",
 50      )
 51      parser.add_argument(
 52          "--kernel-operation-mode",
 53          type=str,
 54          required=True,
 55          help="Kernel operation mode to visualize (forward/backward/full)",
 56      )
 57      parser.add_argument(
 58          "--display", action="store_true", help="Display the visualization"
 59      )
 60      parser.add_argument(
 61          "--overwrite",
 62          action="store_true",
 63          help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
 64      )
 65  
 66      args = parser.parse_args()
 67  
 68      return VisualizationsConfig(**dict(args._get_kwargs()))
 69  
 70  
 71  def load_data(config: VisualizationsConfig) -> pd.DataFrame:
 72      """Loads the benchmark data from the CSV file and filters it based on the configuration.
 73  
 74      Args:
 75          config (VisualizationsConfig): Configuration object for the visualizations script.
 76  
 77      Raises:
 78          ValueError: If no data is found for the given filters.
 79  
 80      Returns:
 81          pd.DataFrame: Filtered benchmark dataframe.
 82      """
 83      df = pd.read_csv(DATA_PATH)
 84      df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
 85  
 86      filtered_df = df[
 87          (df["kernel_name"] == config.kernel_name)
 88          & (df["metric_name"] == config.metric_name)
 89          & (df["kernel_operation_mode"] == config.kernel_operation_mode)
 90          # Use this to filter by extra benchmark configuration property
 91          # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
 92          # FIXME: maybe add a way to filter using some configuration, except of hardcoding it
 93      ]
 94  
 95      if filtered_df.empty:
 96          raise ValueError("No data found for the given filters")
 97  
 98      return filtered_df
 99  
100  
def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
    """Plots the benchmark data, saving the result if needed.

    One line per ``kernel_provider`` is drawn through ``y_value_50`` against ``x_value``.
    Asymmetric error bars span from ``y_value_20`` to ``y_value_80``. The figure is
    saved under ``VISUALIZATIONS_PATH`` if the file does not exist yet or
    ``config.overwrite`` is set. It is shown interactively when ``config.display`` is set.

    Args:
        df (pd.DataFrame): Filtered benchmark dataframe.
        config (VisualizationsConfig): Configuration object for the visualizations script.

    Raises:
        ValueError: If ``df`` is empty.
    """
    if df.empty:
        raise ValueError("No data to plot")

    xlabel = df["x_label"].iloc[0]
    ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
    # Sort by "kernel_provider" to ensure consistent color assignment
    df = df.sort_values(by="kernel_provider")

    # Explicit provider -> color mapping shared by the seaborn lines and the manual error
    # bars, instead of relying on the order of ax.get_lines().
    providers = df["kernel_provider"].dropna().unique().tolist()
    palette = dict(zip(providers, sns.color_palette("tab10", n_colors=len(providers))))

    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.lineplot(
            data=df,
            x="x_value",
            y="y_value_50",
            hue="kernel_provider",
            hue_order=providers,
            palette=palette,
            marker="o",
            # No seaborn-computed error bars: precomputed ones are drawn below.
            errorbar=None,
            ax=ax,
        )

        # Seaborn can't plot pre-computed error bars, so we need to do it manually
        for provider, group_data in df.groupby("kernel_provider"):
            y_error = [
                group_data["y_value_50"] - group_data["y_value_20"],
                group_data["y_value_80"] - group_data["y_value_50"],
            ]
            ax.errorbar(
                group_data["x_value"],
                group_data["y_value_50"],
                yerr=y_error,
                fmt="o",
                color=palette[provider],
                capsize=5,
            )
        ax.legend(title="Kernel Provider")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.tight_layout()

        # NOTE(review): the file name does not include config.kernel_operation_mode, so
        # plots of different modes for the same kernel/metric share one path, and without
        # --overwrite only the first one is ever saved.
        out_path = os.path.join(
            VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
        )

        # Save BEFORE showing: with interactive backends the figure may be torn down when
        # the window closes, so a savefig after plt.show() can write an empty image.
        if config.overwrite or not os.path.exists(out_path):
            os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
            fig.savefig(out_path)
        if config.display:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
160  
161  
def main():
    """Command-line entry point: parse arguments, load the matching rows, then plot them."""
    cli_config = parse_args()
    plot_data(load_data(cli_config), cli_config)


if __name__ == "__main__":
    main()