# benchmarks_visualizer.py
"""Plot kernel benchmark results stored in a CSV file.

The script filters the rows of ``DATA_PATH`` by kernel name, metric name and
kernel operation mode, draws one line per kernel provider (median values with
20th/80th percentile error bars) and optionally displays and/or saves the
figure under ``VISUALIZATIONS_PATH``.
"""

import json
import os
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

DATA_PATH = "data/all_benchmark_data.csv"
VISUALIZATIONS_PATH = "visualizations/"


@dataclass
class VisualizationsConfig:
    """
    Configuration for the visualizations script.

    Args:
        kernel_name (str): Kernel name whose benchmark rows are visualized
            (matched against the ``kernel_name`` column of the CSV).
        metric_name (str): Metric name to visualize (speed/memory)
        kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
        display (bool): Display the visualization. Defaults to False
        overwrite (bool): Overwrite an existing visualization. When no file
            exists yet this flag has no effect, because the plot is always
            created and saved in that case. Defaults to False

    """

    kernel_name: str
    metric_name: str
    kernel_operation_mode: str = "full"
    display: bool = False
    overwrite: bool = False


def parse_args(argv: Optional[Sequence[str]] = None) -> VisualizationsConfig:
    """Parse command line arguments into a configuration object.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the command line entry point
            relies on.

    Returns:
        VisualizationsConfig: Configuration object for the visualizations script.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
    )
    parser.add_argument(
        "--metric-name",
        type=str,
        required=True,
        help="Metric name to visualize (speed/memory)",
    )
    # Optional, so that the documented default of VisualizationsConfig
    # ("full") is actually reachable from the command line.
    parser.add_argument(
        "--kernel-operation-mode",
        type=str,
        default="full",
        help="Kernel operation mode to visualize (forward/backward/full)",
    )
    parser.add_argument(
        "--display", action="store_true", help="Display the visualization"
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
    )

    args = parser.parse_args(argv)

    # vars() is the public way to turn a Namespace into a dict; the argument
    # destinations match the dataclass field names one to one.
    return VisualizationsConfig(**vars(args))


def load_data(config: VisualizationsConfig) -> pd.DataFrame:
    """Loads the benchmark data from the CSV file and filters it based on the configuration.

    Args:
        config (VisualizationsConfig): Configuration object for the visualizations script.

    Raises:
        ValueError: If no data is found for the given filters.

    Returns:
        pd.DataFrame: Filtered benchmark dataframe.
    """
    df = pd.read_csv(DATA_PATH)
    # Decode the JSON column so rows can be filtered on individual extra
    # benchmark configuration properties (see the hint below).
    df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)

    filtered_df = df[
        (df["kernel_name"] == config.kernel_name)
        & (df["metric_name"] == config.metric_name)
        & (df["kernel_operation_mode"] == config.kernel_operation_mode)
        # Use this to filter by extra benchmark configuration property
        # & (df['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
        # FIXME: maybe add a way to filter using some configuration, instead of hardcoding it
    ]

    if filtered_df.empty:
        raise ValueError("No data found for the given filters")

    return filtered_df


def plot_data(df: pd.DataFrame, config: VisualizationsConfig) -> None:
    """Plots the benchmark data, saving the result if needed.

    The figure is written to disk *before* it is shown: with interactive
    backends the figure may be discarded once the window opened by
    ``plt.show()`` is closed, and saving afterwards can produce an empty image.

    Args:
        df (pd.DataFrame): Filtered benchmark dataframe.
        config (VisualizationsConfig): Configuration object for the visualizations script.
    """
    xlabel = df["x_label"].iloc[0]
    ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
    # Sort by "kernel_provider" to ensure consistent color assignment
    df = df.sort_values(by="kernel_provider")

    fig = plt.figure(figsize=(10, 6))
    try:
        sns.set(style="whitegrid")
        ax = sns.lineplot(
            data=df,
            x="x_value",
            y="y_value_50",
            hue="kernel_provider",
            marker="o",
            palette="tab10",
            errorbar=("ci", None),
        )

        # Seaborn can't plot pre-computed error bars, so we need to do it
        # manually. The colors are captured before any error bar is drawn so
        # the list only reflects what lineplot() put on the axes.
        colors = [line.get_color() for line in ax.get_lines()]

        for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
            # Asymmetric bars: distance from the median down to the 20th
            # percentile and up to the 80th percentile.
            y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
            y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]

            ax.errorbar(
                group_data["x_value"],
                group_data["y_value_50"],
                yerr=[y_error_lower, y_error_upper],
                fmt="o",
                color=color,
                capsize=5,
            )
        ax.legend(title="Kernel Provider")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.tight_layout()

        # NOTE(review): the file name does not include the kernel operation
        # mode, so plots for different modes of the same kernel/metric share
        # one path -- confirm this is intended.
        out_path = os.path.join(
            VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
        )

        # Save the plot if it doesn't exist or if we want to overwrite it
        if config.overwrite or not os.path.exists(out_path):
            os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
            fig.savefig(out_path)

        if config.display:
            plt.show()
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def main() -> None:
    """Parse the command line, load the matching benchmark rows and plot them."""
    config = parse_args()
    df = load_data(config)
    plot_data(df, config)


if __name__ == "__main__":
    main()