# test_base_metric.py
from typing import ClassVar
from typing import Dict
from typing import Optional

import pandas as pd
import pytest

from evidently.legacy.base_metric import ColumnName
from evidently.legacy.base_metric import DatasetType
from evidently.legacy.base_metric import InputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.core import ColumnType
from evidently.legacy.features.generated_features import GeneratedFeature
from evidently.legacy.metrics import ColumnValueRangeMetric
from evidently.legacy.metrics.base_metric import generate_column_metrics
from evidently.legacy.options.base import Options
from evidently.legacy.options.option import Option
from evidently.legacy.pipeline.column_mapping import ColumnMapping
from evidently.legacy.report import Report
from evidently.legacy.utils.data_preprocessing import DataDefinition
from evidently.pydantic_utils import FingerprintPart
from evidently.pydantic_utils import get_value_fingerprint


def test_metric_generator():
    """Run ``generate_column_metrics`` inside a report and check the report renders.

    Exercised twice: once without a ``columns`` argument and once restricted
    to an explicit ``columns`` subset.
    """
    test_data = pd.DataFrame({"col1": [3, 2, 3], "col2": [4, 5, 6], "col3": [4, 5, 6]})

    report = Report(metrics=[generate_column_metrics(ColumnValueRangeMetric, parameters={"left": 0, "right": 10})])
    report.run(
        current_data=test_data,
        reference_data=None,
        column_mapping=ColumnMapping(numerical_features=["col1", "col2", "col3"]),
    )
    assert report.show()

    report = Report(
        metrics=[
            generate_column_metrics(
                metric_class=ColumnValueRangeMetric, columns=["col2", "col3"], parameters={"left": 0, "right": 10}
            )
        ]
    )
    report.run(
        current_data=test_data,
        reference_data=None,
        column_mapping=ColumnMapping(numerical_features=["col1", "col2", "col3"]),
    )
    assert report.show()


class SimpleMetric(Metric[int]):
    """Test metric: sum of the current-data values of ``column_name``."""

    class Config:
        alias_required = False

    column_name: ColumnName

    def __init__(self, column_name: ColumnName):
        self.column_name = column_name
        super().__init__()

    def calculate(self, data: InputData) -> int:
        return data.get_current_column(self.column_name).sum()


class SimpleMetric2(Metric[int]):
    """Test metric: sum of the current-data values of ``column_name``, plus one.

    Kept as a distinct class from ``SimpleMetric`` so two different metrics can
    share a report in the multi-metric test.
    """

    class Config:
        alias_required = False

    column_name: ColumnName

    def __init__(self, column_name: ColumnName):
        self.column_name = column_name
        super().__init__()

    def calculate(self, data: InputData) -> int:
        return data.get_current_column(self.column_name).sum() + 1


class SimpleMetricWithFeatures(Metric[int]):
    """Test metric that sums a column, requesting a generated feature for categorical columns.

    For a categorical ``column_name``, the sum is taken over the generated
    ``LengthFeature`` column (string lengths). Otherwise the raw column is summed.
    """

    class Config:
        alias_required = False

    column_name: str
    _feature: Optional[GeneratedFeature]

    def __init__(self, column_name: str):
        self.column_name = column_name
        self._feature = None
        super().__init__()

    def calculate(self, data: InputData) -> int:
        if data.data_definition.get_column(self.column_name).column_type == ColumnType.Categorical:
            # _feature is set in required_features, which is assumed to run before calculate.
            return data.get_current_column(self._feature.as_column()).sum()
        return data.get_current_column(self.column_name).sum()

    def required_features(self, data_definition: DataDefinition):
        column_type = data_definition.get_column(self.column_name).column_type
        # The feature object is always stored, but only requested for categorical columns.
        self._feature = LengthFeature(self.column_name)
        if column_type == ColumnType.Categorical:
            return [self._feature]
        return []


class MetricWithAllTextFeatures(Metric[Dict[str, int]]):
    """Test metric mapping each text feature column to the sum of its value lengths."""

    class Config:
        alias_required = False

    # Forward reference: LengthFeature is defined later in this module.
    _features: Dict[str, "LengthFeature"]

    def calculate(self, data: InputData):
        # _features is populated by required_features, assumed to be called before calculate.
        return {k: data.get_current_column(v.as_column()).sum() for k, v in self._features.items()}

    def required_features(self, data_definition: DataDefinition):
        self._features = {
            column.column_name: LengthFeature(column.column_name)
            for column in data_definition.get_columns(ColumnType.Text, features_only=True)
        }
        return list(self._features.values())


class SimpleGeneratedFeature(GeneratedFeature):
    """Generated numerical feature that doubles the values of ``column_name``."""

    class Config:
        alias_required = False

    __feature_type__: ClassVar = ColumnType.Numerical
    column_name: str

    def __init__(self, column_name: str, display_name: str = ""):
        self.column_name = column_name
        self.display_name = display_name
        super().__init__()

    def generate_feature(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.DataFrame:
        return pd.DataFrame({self.column_name: data[self.column_name] * 2})

    def _as_column(self) -> ColumnName:
        # f-string so the default display name contains the actual column name
        # (previously the literal text "{self.column_name}" was used).
        return self._create_column(subcolumn=self.column_name, default_display_name=f"SGF: {self.column_name}")


class LengthFeature(GeneratedFeature):
    """Generated numerical feature holding ``len()`` of each value in ``column_name``.

    ``max_length`` is stored as a field but is not used by ``generate_feature``.
    """

    class Config:
        alias_required = False

    __feature_type__: ClassVar = ColumnType.Numerical
    column_name: str
    max_length: Optional[int] = None

    def __init__(self, column_name: str, max_length: Optional[int] = None):
        self.column_name = column_name
        self.max_length = max_length
        super().__init__()

    def generate_feature(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.DataFrame:
        return pd.DataFrame({self.column_name: data[self.column_name].apply(len)})

    def _as_column(self) -> ColumnName:
        return self._create_column(self.column_name, default_display_name=f"Length of {self.column_name}")


def _feature_test_data() -> pd.DataFrame:
    """Return the small mixed-type frame shared by the generated-feature tests."""
    return pd.DataFrame(
        {
            "col1": [1.0, 2.0, 3.0],
            "col2": ["11", "111", "1111"],
            "col3": ["11", "111", "1111"],
            "col4": ["111", "1111", "11111"],
        }
    )


def _feature_column_mapping() -> ColumnMapping:
    """Return the mapping for ``_feature_test_data``: col1 numerical, col2 categorical, col3/col4 text."""
    return ColumnMapping(
        numerical_features=["col1"],
        categorical_features=["col2"],
        text_features=["col3", "col4"],
    )


def _run_feature_report(metrics) -> Report:
    """Run ``metrics`` over the shared feature test data and re-raise any recorded error."""
    report = Report(metrics=metrics)
    report.run(
        current_data=_feature_test_data(),
        reference_data=None,
        column_mapping=_feature_column_mapping(),
    )
    # Presumably surfaces a metric calculation failure instead of leaving an empty result.
    report._inner_suite.raise_for_error()
    return report


@pytest.mark.parametrize(
    "metric,result",
    [
        (SimpleMetric(ColumnName("col1", "col1", DatasetType.MAIN, None)), 6),
        (SimpleMetric(SimpleGeneratedFeature("col1").as_column()), 12),
        (SimpleMetricWithFeatures("col1"), 6),
        (SimpleMetricWithFeatures("col2"), 9),
        (MetricWithAllTextFeatures(), {"col3": 9, "col4": 12}),
    ],
)
def test_additional_features(metric, result):
    """A single metric, with or without generated features, computes the expected value."""
    _run_feature_report([metric])
    assert metric.get_result() == result


@pytest.mark.parametrize(
    "metrics,result",
    [
        (
            [
                SimpleMetric(SimpleGeneratedFeature("col1", "d1").as_column()),
                SimpleMetric2(SimpleGeneratedFeature("col1", "d2").as_column()),
            ],
            (12, 13),
        ),
    ],
)
def test_additional_features_multi_metrics(metrics, result):
    """Several metrics over generated features in one report each get their own result."""
    _run_feature_report(metrics)
    assert len(metrics) == len(result)
    for metric, expected in zip(metrics, result):
        assert metric.get_result() == expected


def test_options_fingerprint_not_specified():
    """Different option values leave the fingerprint unchanged when the metric does not opt in."""

    class MyOption(Option):
        field: str

    class MockMetric(Metric[MetricResult]):
        class Config:
            alias_required = False

        def calculate(self, data: InputData):
            return MetricResult()

    m1 = MockMetric(options=[MyOption(field="a")])
    m2 = MockMetric(options=[MyOption(field="b")])

    assert m1.get_fingerprint() == m2.get_fingerprint()


def test_options_fingerprint_specified_type():
    """Overriding ``get_options_fingerprint`` makes option values change the fingerprint."""

    class MyOption(Option):
        field: str

    class UsesMyOptionMixin:
        options: Options

        def get_options_fingerprint(self) -> FingerprintPart:
            return get_value_fingerprint(self.options.get(MyOption).field)

    class MockMetricWithOption(UsesMyOptionMixin, Metric[MetricResult]):
        class Config:
            alias_required = False

        def calculate(self, data: InputData):
            return MetricResult()

    m3 = MockMetricWithOption(options=[MyOption(field="a")])
    m4 = MockMetricWithOption(options=[MyOption(field="b")])

    assert m3.get_fingerprint() != m4.get_fingerprint()