# compare.py
1 from collections import defaultdict 2 from typing import Any 3 from typing import Callable 4 from typing import Dict 5 from typing import List 6 from typing import Union 7 8 import pandas as pd 9 10 from evidently.core.report import Snapshot 11 12 CompareIndex = Union[str, List[str], Callable[[Snapshot], Any]] 13 14 15 def _get_index(index: CompareIndex, run: Snapshot, i: int) -> Any: 16 if isinstance(index, list): 17 return index[i] 18 if callable(index): 19 return index(run) 20 if index == "timestamp": 21 return run._timestamp 22 if index.startswith("metadata."): 23 key = index[len("metadata.") :] 24 return run._metadata[key] 25 raise ValueError(f"Invalid index: {index}") 26 27 28 def compare( 29 *runs: Snapshot, index: CompareIndex = "timestamp", all_metrics: bool = False, use_tests: bool = False 30 ) -> pd.DataFrame: 31 """Compare multiple `Report` snapshots side-by-side in a `pandas.DataFrame`. 32 33 If you computed multiple snapshots, you can quickly compare the resulting metrics 34 side-by-side in a dataframe. This is useful for comparing: 35 - Different time periods 36 - Different model/prompt versions 37 - Different datasets 38 39 Args: 40 * `*runs`: One or more `evidently.core.report.Snapshot` objects to compare 41 * `index`: How to index the comparison. Can be: 42 - "timestamp": Use snapshot timestamp (default) 43 - List of strings: Custom index values for each run 44 - Callable: Function that takes a `Snapshot` and returns index value 45 - "metadata.<key>": Use metadata value as index 46 * `all_metrics`: If True, include all metrics from all runs. If False (default), 47 only include metrics present in all runs. 
48 * `use_tests`: If True, include test results instead of metric values 49 50 Returns: 51 * `pandas.DataFrame` with metrics as rows and runs as columns, indexed by the specified index 52 53 Example: 54 ```python 55 compare_dataframe = compare(my_eval_1, my_eval_2, my_eval_3) 56 compare_dataframe = compare(run1, run2, index="metadata.model_version") 57 ``` 58 """ 59 if isinstance(index, list) and len(index) != len(runs): 60 raise ValueError("Index and runs must have same length") 61 62 common_metrics = set.intersection(*[set(r._top_level_metrics) for r in runs]) 63 result: Dict[str, Dict[int, Union[float, str]]] = defaultdict(dict) 64 for i, run in enumerate(runs): 65 result["index"][i] = _get_index(index, run, i) 66 for metric_id in run._top_level_metrics: 67 if not all_metrics and metric_id not in common_metrics: 68 continue 69 metric_result = run._metrics[metric_id] 70 if use_tests: 71 for test in metric_result.tests: 72 result[test.name][i] = test.status.value 73 else: 74 for key, value in metric_result.itervalues(): 75 col = f"{metric_result.display_name}.{key}" 76 result[col][i] = value 77 return ( 78 pd.DataFrame({col: [val.get(i, None) for i in range(len(runs))] for col, val in result.items()}) 79 .set_index("index") 80 .T 81 )