data_drift_tests.py
from abc import ABC
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd

from evidently.legacy.base_metric import ColumnName
from evidently.legacy.calculations.data_drift import ColumnDataDriftMetrics
from evidently.legacy.calculations.stattests import PossibleStatTestType
from evidently.legacy.core import ColumnType
from evidently.legacy.metric_results import HistogramData
from evidently.legacy.metrics import ColumnDriftMetric
from evidently.legacy.metrics import DataDriftTable
from evidently.legacy.metrics import EmbeddingsDriftMetric
from evidently.legacy.metrics.data_drift.base import WithDriftOptionsFields
from evidently.legacy.metrics.data_drift.data_drift_table import DataDriftTableResults
from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.renderers.base_renderer import DetailsInfo
from evidently.legacy.renderers.base_renderer import TestHtmlInfo
from evidently.legacy.renderers.base_renderer import TestRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import plotly_figure
from evidently.legacy.renderers.html_widgets import table_data
from evidently.legacy.tests.base_test import BaseCheckValueTest
from evidently.legacy.tests.base_test import ConditionTestParameters
from evidently.legacy.tests.base_test import ExcludeNoneMixin
from evidently.legacy.tests.base_test import GroupData
from evidently.legacy.tests.base_test import GroupingTypes
from evidently.legacy.tests.base_test import Test
from evidently.legacy.tests.base_test import TestParameters
from evidently.legacy.tests.base_test import TestResult
from evidently.legacy.tests.base_test import TestStatus
from evidently.legacy.tests.base_test import TestValueCondition
from evidently.legacy.utils.data_drift_utils import resolve_stattest_threshold
from evidently.legacy.utils.data_preprocessing import DataDefinition
from evidently.legacy.utils.generators import BaseGenerator
from evidently.legacy.utils.types import Numeric
from evidently.legacy.utils.visualizations import plot_contour_single
from evidently.legacy.utils.visualizations import plot_distr_with_cond_perc_button

# Test group shared by every test in this module; registered once at import time.
DATA_DRIFT_GROUP = GroupData(id="data_drift", title="Data Drift", description="")
GroupingTypes.TestGroup.add_value(DATA_DRIFT_GROUP)


class ColumnDriftParameter(ExcludeNoneMixin, TestParameters):  # type: ignore[misc] # pydantic Config
    """Drift outcome for a single column, attached to a test result as its parameters."""

    class Config:
        type_alias = "evidently:test_parameters:ColumnDriftParameter"

    stattest: str
    score: float
    threshold: float
    detected: bool
    column_name: Optional[str] = None

    @classmethod
    def from_metric(cls, data: ColumnDataDriftMetrics, column_name: Optional[str] = None):
        """Build the parameters from a column drift metric result.

        The drift score is rounded to three decimals; the other fields are copied as is.
        """
        return ColumnDriftParameter(
            stattest=data.stattest_name,
            score=np.round(data.drift_score, 3),
            threshold=data.stattest_threshold,
            detected=data.drift_detected,
            column_name=column_name,
        )


class ColumnsDriftParameters(ConditionTestParameters):
    """Per-column drift outcomes plus the condition the dataset-level test was checked against."""

    # todo: rename to columns?
    class Config:
        type_alias = "evidently:test_parameters:ColumnsDriftParameters"

    features: Dict[str, ColumnDriftParameter]

    @classmethod
    def from_data_drift_table(cls, table: DataDriftTableResults, condition: TestValueCondition):
        """Collect one ``ColumnDriftParameter`` per entry of ``table.drift_by_columns``."""
        return ColumnsDriftParameters(
            features={
                feature: ColumnDriftParameter.from_metric(data) for feature, data in table.drift_by_columns.items()
            },
            condition=condition,
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Return one row per feature with its stattest, score, threshold and drift verdict."""
        # The column list is spelled out so that an empty ``features`` mapping still yields a
        # frame with the expected columns; otherwise callers sorting by "Data Drift" hit a KeyError.
        return pd.DataFrame(
            [
                {
                    "Feature name": feature,
                    "Stattest": data.stattest,
                    "Drift score": data.score,
                    "Threshold": data.threshold,
                    "Data Drift": "Detected" if data.detected else "Not detected",
                }
                for feature, data in self.features.items()
            ],
            columns=["Feature name", "Stattest", "Drift score", "Threshold", "Data Drift"],
        )


class BaseDataDriftMetricsTest(BaseCheckValueTest, WithDriftOptionsFields, ABC):
    """Common base for dataset-level drift tests backed by a ``DataDriftTable`` metric."""

    group: ClassVar = DATA_DRIFT_GROUP.id
    _metric: DataDriftTable
    columns: Optional[List[str]]
    feature_importance: Optional[bool]

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
        feature_importance: Optional[bool] = False,
    ):
        """Store the condition and drift options, then create the backing ``DataDriftTable``."""
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
            columns=columns,
            stattest=stattest,
            cat_stattest=cat_stattest,
            num_stattest=num_stattest,
            text_stattest=text_stattest,
            per_column_stattest=per_column_stattest,
            stattest_threshold=stattest_threshold,
            cat_stattest_threshold=cat_stattest_threshold,
            num_stattest_threshold=num_stattest_threshold,
            text_stattest_threshold=text_stattest_threshold,
            per_column_stattest_threshold=per_column_stattest_threshold,
            feature_importance=feature_importance,
        )
        # The metric is configured from the fields as stored on ``self`` by the base initializer,
        # not from the raw arguments.
        self._metric = DataDriftTable(
            columns=self.columns,
            stattest=self.stattest,
            cat_stattest=self.cat_stattest,
            num_stattest=self.num_stattest,
            text_stattest=self.text_stattest,
            per_column_stattest=self.per_column_stattest,
            stattest_threshold=self.stattest_threshold,
            cat_stattest_threshold=self.cat_stattest_threshold,
            num_stattest_threshold=self.num_stattest_threshold,
            text_stattest_threshold=self.text_stattest_threshold,
            per_column_stattest_threshold=self.per_column_stattest_threshold,
            feature_importance=self.feature_importance,
        )

    @property
    def metric(self):
        """The ``DataDriftTable`` metric this test reads its values from."""
        return self._metric

    def check(self):
        """Run the base value check and attach per-column drift parameters to the result."""
        result = super().check()
        metrics = self.metric.get_result()

        return TestResult(
            name=result.name,
            description=result.description,
            status=TestStatus(result.status),
            group=self.group,
            parameters=ColumnsDriftParameters.from_data_drift_table(metrics, self.get_condition()),
        )


class TestNumberOfDriftedColumns(BaseDataDriftMetricsTest):
    """Check the absolute number of drifted columns against a condition."""

    class Config:
        type_alias = "evidently:test:TestNumberOfDriftedColumns"

    name: ClassVar = "Number of Drifted Features"

    def get_condition(self) -> TestValueCondition:
        """Return the user condition, or ``lt`` a third of the column count (integer division)."""
        if self.condition.has_condition():
            return self.condition
        # NOTE(review): for fewer than three columns this evaluates to ``lt=0``, which a
        # non-negative drifted-column count presumably can never satisfy - confirm this is intended.
        return TestValueCondition(lt=max(0, self.metric.get_result().number_of_columns // 3))

    def calculate_value_for_test(self) -> Numeric:
        """Return the number of drifted columns reported by the metric."""
        return self.metric.get_result().number_of_drifted_columns

    def get_description(self, value: Numeric) -> str:
        """Describe the drifted-column count together with the active condition."""
        n_features = self.metric.get_result().number_of_columns
        return (
            f"The drift is detected for {value} out of {n_features} features. "
            f"The test threshold is {self.get_condition()}."
        )


class TestShareOfDriftedColumns(BaseDataDriftMetricsTest):
    """Check the share of drifted columns against a condition."""

    class Config:
        type_alias = "evidently:test:TestShareOfDriftedColumns"

    name: ClassVar = "Share of Drifted Columns"

    def get_condition(self) -> TestValueCondition:
        """Return the user condition, or the default ``lt=0.3``."""
        if self.condition.has_condition():
            return self.condition
        return TestValueCondition(lt=0.3)

    def calculate_value_for_test(self) -> Numeric:
        """Return the share of drifted columns reported by the metric."""
        return self.metric.get_result().share_of_drifted_columns

    def get_description(self, value: Numeric) -> str:
        """Describe the drifted share as a percentage together with the active condition."""
        metric_result = self.metric.get_result()
        n_drifted_features = metric_result.number_of_drifted_columns
        n_features = metric_result.number_of_columns
        return (
            f"The drift is detected for {value * 100:.3g}% features "
            f"({n_drifted_features} out of {n_features}). The test threshold is {self.get_condition()}"
        )


class TestColumnDrift(Test):
    """Drift test for one column: succeeds when the column drift metric reports no drift."""

    class Config:
        type_alias = "evidently:test:TestColumnDrift"

    name: ClassVar = "Drift per Column"
    group: ClassVar = DATA_DRIFT_GROUP.id
    _metric: ColumnDriftMetric
    column_name: ColumnName
    stattest: Optional[PossibleStatTestType] = None
    stattest_threshold: Optional[float] = None

    def __init__(
        self,
        column_name: Union[str, ColumnName],
        stattest: Optional[PossibleStatTestType] = None,
        stattest_threshold: Optional[float] = None,
        is_critical: bool = True,
    ):
        """Normalize ``column_name`` to a ``ColumnName`` and create the backing metric."""
        self.column_name = ColumnName.from_any(column_name)
        self.stattest = stattest
        self.stattest_threshold = stattest_threshold

        super().__init__(is_critical=is_critical)
        self._metric = ColumnDriftMetric(
            column_name=self.column_name,
            stattest=self.stattest,
            stattest_threshold=self.stattest_threshold,
        )

    @property
    def metric(self):
        """The ``ColumnDriftMetric`` this test reads its result from."""
        return self._metric

    def check(self):
        """Turn the metric's drift verdict into a SUCCESS / FAIL test result."""
        drift_info = self.metric.get_result()

        drift_score = np.round(drift_info.drift_score, 3)
        stattest_name = drift_info.stattest_name
        threshold = drift_info.stattest_threshold
        description = (
            f"The drift score for the feature **{self.column_name.display_name}** is {drift_score:.3g}. "
            f"The drift detection method is {stattest_name}. "
            f"The drift detection threshold is {threshold}."
        )

        result_status = TestStatus.FAIL if drift_info.drift_detected else TestStatus.SUCCESS

        return TestResult(
            name=self.name,
            description=description,
            status=result_status,
            group=self.group,
            parameters=ColumnDriftParameter.from_metric(drift_info, column_name=self.column_name.display_name),
        )

    def groups(self) -> Dict[str, str]:
        """Group this test under its column's display name."""
        return {
            GroupingTypes.ByFeature.id: self.column_name.display_name,
        }


class TestAllFeaturesValueDrift(BaseGenerator):
    """Create value drift tests for categorical, numerical and text features"""

    columns: Optional[List[str]]
    stattest: Optional[PossibleStatTestType]
    cat_stattest: Optional[PossibleStatTestType]
    num_stattest: Optional[PossibleStatTestType]
    text_stattest: Optional[PossibleStatTestType]
    per_column_stattest: Optional[Dict[str, PossibleStatTestType]]
    stattest_threshold: Optional[float]
    cat_stattest_threshold: Optional[float]
    num_stattest_threshold: Optional[float]
    text_stattest_threshold: Optional[float]
    per_column_stattest_threshold: Optional[Dict[str, float]]

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
    ):
        """Remember the column filter and drift options used later by ``generate``."""
        self.is_critical = is_critical
        self.columns = columns
        self.stattest = stattest
        self.cat_stattest = cat_stattest
        self.num_stattest = num_stattest
        self.text_stattest = text_stattest
        self.per_column_stattest = per_column_stattest
        self.stattest_threshold = stattest_threshold
        self.cat_stattest_threshold = cat_stattest_threshold
        self.num_stattest_threshold = num_stattest_threshold
        self.text_stattest_threshold = text_stattest_threshold
        self.per_column_stattest_threshold = per_column_stattest_threshold

    def generate(self, data_definition: DataDefinition) -> List[TestColumnDrift]:
        """Return one ``TestColumnDrift`` per feature column, optionally limited to ``self.columns``.

        Columns are emitted grouped by type: categorical first, then numerical, then text.
        """
        # Pairs of (column type to query, kind label handed to ``resolve_stattest_threshold``).
        typed_kinds = (
            (ColumnType.Categorical, "cat"),
            (ColumnType.Numerical, "num"),
            (ColumnType.Text, "text"),
        )
        results = []
        for column_type, kind in typed_kinds:
            for column in data_definition.get_columns(column_type, features_only=True):
                # An empty / missing column filter means "all features".
                if self.columns and column.column_name not in self.columns:
                    continue
                stattest, threshold = resolve_stattest_threshold(
                    column.column_name,
                    kind,
                    self.stattest,
                    self.cat_stattest,
                    self.num_stattest,
                    self.text_stattest,
                    self.per_column_stattest,
                    self.stattest_threshold,
                    self.cat_stattest_threshold,
                    self.num_stattest_threshold,
                    self.text_stattest_threshold,
                    self.per_column_stattest_threshold,
                )
                results.append(
                    TestColumnDrift(
                        column_name=column.column_name,
                        stattest=stattest,
                        stattest_threshold=threshold,
                        is_critical=self.is_critical,
                    )
                )
        return results


class TestCustomFeaturesValueDrift(BaseGenerator):
    """Create value drift tests for specified features"""

    features: List[str]
    stattest: Optional[PossibleStatTestType] = None
    cat_stattest: Optional[PossibleStatTestType] = None
    num_stattest: Optional[PossibleStatTestType] = None
    text_stattest: Optional[PossibleStatTestType] = None
    per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None
    stattest_threshold: Optional[float] = None
    cat_stattest_threshold: Optional[float] = None
    num_stattest_threshold: Optional[float] = None
    text_stattest_threshold: Optional[float] = None
    per_column_stattest_threshold: Optional[Dict[str, float]] = None

    def __init__(
        self,
        features: List[str],
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
    ):
        """Remember the feature list and drift options used later by ``generate``."""
        self.is_critical = is_critical
        self.features = features
        self.stattest = stattest
        self.cat_stattest = cat_stattest
        self.num_stattest = num_stattest
        self.text_stattest = text_stattest
        self.per_column_stattest = per_column_stattest
        self.stattest_threshold = stattest_threshold
        self.cat_stattest_threshold = cat_stattest_threshold
        self.num_stattest_threshold = num_stattest_threshold
        self.text_stattest_threshold = text_stattest_threshold
        # ``generate`` reads ``per_column_stattest_threshold``; previously the value was stored only
        # under ``per_feature_threshold``, so per-column thresholds never reached the resolver.
        self.per_column_stattest_threshold = per_column_stattest_threshold
        # Kept as an alias for code that may still read the old attribute name.
        self.per_feature_threshold = per_column_stattest_threshold

    def generate(self, data_definition: DataDefinition) -> List[TestColumnDrift]:
        """Return one ``TestColumnDrift`` per name in ``self.features``, in the given order."""
        result = []
        for name in self.features:
            column = data_definition.get_column(name)
            # Map the column type to the kind label expected by the resolver; anything that is
            # not categorical, numerical or text is labelled "datetime".
            if column.column_type == ColumnType.Categorical:
                kind = "cat"
            elif column.column_type == ColumnType.Numerical:
                kind = "num"
            elif column.column_type == ColumnType.Text:
                kind = "text"
            else:
                kind = "datetime"
            stattest, threshold = resolve_stattest_threshold(
                name,
                kind,
                self.stattest,
                self.cat_stattest,
                self.num_stattest,
                self.text_stattest,
                self.per_column_stattest,
                self.stattest_threshold,
                self.cat_stattest_threshold,
                self.num_stattest_threshold,
                self.text_stattest_threshold,
                self.per_column_stattest_threshold,
            )
            result.append(
                TestColumnDrift(
                    column_name=name,
                    stattest=stattest,
                    stattest_threshold=threshold,
                    is_critical=self.is_critical,
                )
            )
        return result


def _build_drift_table(
    parameters: ColumnsDriftParameters,
    current_fi: Optional[Dict],
    reference_fi: Optional[Dict],
) -> pd.DataFrame:
    """Build the per-column drift table shown by the dataset-level drift test renderers.

    Rows are sorted by the "Data Drift" verdict. A feature-importance column is inserted after
    the feature name for each importance mapping that is not ``None``; features missing from a
    mapping get an empty string.
    """
    df = parameters.to_dataframe()
    df = df.sort_values("Data Drift")
    columns = ["Feature name"]
    if current_fi is not None:
        df["current_feature_importance"] = df["Feature name"].apply(lambda x: current_fi.get(x, ""))
        columns.append("current_feature_importance")
    if reference_fi is not None:
        df["reference_feature_importance"] = df["Feature name"].apply(lambda x: reference_fi.get(x, ""))
        columns.append("reference_feature_importance")
    columns += ["Stattest", "Drift score", "Threshold", "Data Drift"]
    return df[columns]


@default_renderer(wrap_type=TestNumberOfDriftedColumns)
class TestNumberOfDriftedColumnsRenderer(TestRenderer):
    """HTML renderer adding a "Drift Table" detail to ``TestNumberOfDriftedColumns`` results."""

    def render_html(self, obj: TestNumberOfDriftedColumns) -> TestHtmlInfo:
        """Render the base info and, unless the test errored, append the drift table."""
        info = super().render_html(obj)
        result = obj.get_result()
        if result.status == TestStatus.ERROR:
            return info
        parameters = result.parameters
        assert isinstance(parameters, ColumnsDriftParameters)
        metric_result = obj.metric.get_result()
        df = _build_drift_table(parameters, metric_result.current_fi, metric_result.reference_fi)
        info.with_details(
            title="Drift Table",
            info=table_data(column_names=df.columns.to_list(), data=df.values),
        )
        return info


@default_renderer(wrap_type=TestShareOfDriftedColumns)
class TestShareOfDriftedColumnsRenderer(TestRenderer):
    """HTML renderer replacing the details of ``TestShareOfDriftedColumns`` with a drift table."""

    def render_html(self, obj: TestShareOfDriftedColumns) -> TestHtmlInfo:
        """Render the base info and, unless the test errored, set the drift table as sole detail."""
        info = super().render_html(obj)
        result = obj.get_result()
        if result.status == TestStatus.ERROR:
            return info
        parameters = result.parameters
        assert isinstance(parameters, ColumnsDriftParameters)
        metric_result = obj.metric.get_result()
        df = _build_drift_table(parameters, metric_result.current_fi, metric_result.reference_fi)
        info.details = [
            DetailsInfo(
                id="drift_table",
                title="",
                info=BaseWidgetInfo(
                    title="",
                    type="table",
                    params={"header": df.columns.to_list(), "data": df.values},
                    size=2,
                ),
            ),
        ]
        return info


@default_renderer(wrap_type=TestColumnDrift)
class TestColumnDriftRenderer(TestRenderer):
    """HTML renderer for ``TestColumnDrift``: word tables for text columns, histograms otherwise."""

    def render_html(self, obj: TestColumnDrift) -> TestHtmlInfo:
        """Render the base info plus column-type specific details.

        Raises:
            ValueError: for a non-text column whose current or reference distribution is missing.
        """
        info = super().render_html(obj)
        result = obj.metric.get_result()
        column_name = obj.column_name
        if result.column_type == "text":
            # Word tables are only shown when both datasets provide characteristic words.
            if result.current.characteristic_words is not None and result.reference.characteristic_words is not None:
                info.details = [
                    DetailsInfo(
                        id=f"{column_name} dritf curr",
                        title="current: characteristic words",
                        info=BaseWidgetInfo(
                            title="",
                            type="table",
                            params={
                                "header": ["", ""],
                                "data": [[el, ""] for el in result.current.characteristic_words],
                            },
                            size=2,
                        ),
                    ),
                    DetailsInfo(
                        id=f"{column_name} dritf ref",
                        title="reference: characteristic words",
                        info=BaseWidgetInfo(
                            title="",
                            type="table",
                            params={
                                "header": ["", ""],
                                "data": [[el, ""] for el in result.reference.characteristic_words],
                            },
                            size=2,
                        ),
                    ),
                ]
            return info

        # Both distributions are consumed below, so a missing one on either side is reported
        # with the same explicit error instead of failing inside the histogram conversion.
        if result.current.distribution is None or result.reference.distribution is None:
            raise ValueError("Expected data is missing")
        fig = plot_distr_with_cond_perc_button(
            hist_curr=HistogramData.from_distribution(result.current.distribution),
            hist_ref=HistogramData.from_distribution(result.reference.distribution),
            xaxis_name="",
            yaxis_name="count",
            yaxis_name_perc="percent",
            color_options=self.color_options,
            to_json=False,
            condition=None,
        )
        info.with_details(f"{column_name}", plotly_figure(title="", figure=fig))
        return info


class TestEmbeddingsDrift(Test):
    """Drift test for an embeddings set: succeeds when the embeddings metric reports no drift."""

    class Config:
        type_alias = "evidently:test:TestEmbeddingsDrift"

    name: ClassVar = "Drift for embeddings"
    group: ClassVar = DATA_DRIFT_GROUP.id
    embeddings_name: str
    drift_method: Optional[DriftMethod]
    _metric: EmbeddingsDriftMetric

    def __init__(self, embeddings_name: str, drift_method: Optional[DriftMethod] = None, is_critical: bool = True):
        """Store the embeddings set name and drift method, then create the backing metric."""
        self.embeddings_name = embeddings_name
        self.drift_method = drift_method
        super().__init__(is_critical=is_critical)
        self._metric = EmbeddingsDriftMetric(embeddings_name=self.embeddings_name, drift_method=self.drift_method)

    @property
    def metric(self):
        """The ``EmbeddingsDriftMetric`` this test reads its result from."""
        return self._metric

    def check(self):
        """Turn the metric's drift verdict into a SUCCESS / FAIL test result."""
        drift_info = self.metric.get_result()
        drift_score = drift_info.drift_score
        if drift_info.drift_detected:
            drift = "detected"
            result_status = TestStatus.FAIL
        else:
            drift = "not detected"
            result_status = TestStatus.SUCCESS

        description = (
            f"Data drift {drift}. "
            f"The drift score for the embedding set **{drift_info.embeddings_name}** is {drift_score:.3g}. "
            f"The drift detection method is **{drift_info.method_name}**. "
        )
        return TestResult(
            name=self.name,
            description=description,
            status=result_status,
            group=self.group,
        )

    def groups(self) -> Dict[str, str]:
        """Embeddings tests are not assigned to any extra grouping."""
        return {}


@default_renderer(wrap_type=TestEmbeddingsDrift)
class TestEmbeddingsDriftRenderer(TestRenderer):
    """HTML renderer adding a contour plot of current vs. reference data to embeddings drift tests."""

    def render_html(self, obj: TestEmbeddingsDrift) -> TestHtmlInfo:
        """Render the base info and append the contour plot built from the metric result."""
        info = super().render_html(obj)
        result = obj.metric.get_result()
        fig = plot_contour_single(result.current, result.reference, "component 1", "component 2")
        info.with_details(f"Drift in embeddings '{result.embeddings_name}'", plotly_figure(title="", figure=fig))
        return info