/ src / evidently / legacy / tests / data_drift_tests.py
data_drift_tests.py
  1  from abc import ABC
  2  from typing import ClassVar
  3  from typing import Dict
  4  from typing import List
  5  from typing import Optional
  6  from typing import Union
  7  
  8  import numpy as np
  9  import pandas as pd
 10  
 11  from evidently.legacy.base_metric import ColumnName
 12  from evidently.legacy.calculations.data_drift import ColumnDataDriftMetrics
 13  from evidently.legacy.calculations.stattests import PossibleStatTestType
 14  from evidently.legacy.core import ColumnType
 15  from evidently.legacy.metric_results import HistogramData
 16  from evidently.legacy.metrics import ColumnDriftMetric
 17  from evidently.legacy.metrics import DataDriftTable
 18  from evidently.legacy.metrics import EmbeddingsDriftMetric
 19  from evidently.legacy.metrics.data_drift.base import WithDriftOptionsFields
 20  from evidently.legacy.metrics.data_drift.data_drift_table import DataDriftTableResults
 21  from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
 22  from evidently.legacy.model.widget import BaseWidgetInfo
 23  from evidently.legacy.renderers.base_renderer import DetailsInfo
 24  from evidently.legacy.renderers.base_renderer import TestHtmlInfo
 25  from evidently.legacy.renderers.base_renderer import TestRenderer
 26  from evidently.legacy.renderers.base_renderer import default_renderer
 27  from evidently.legacy.renderers.html_widgets import plotly_figure
 28  from evidently.legacy.renderers.html_widgets import table_data
 29  from evidently.legacy.tests.base_test import BaseCheckValueTest
 30  from evidently.legacy.tests.base_test import ConditionTestParameters
 31  from evidently.legacy.tests.base_test import ExcludeNoneMixin
 32  from evidently.legacy.tests.base_test import GroupData
 33  from evidently.legacy.tests.base_test import GroupingTypes
 34  from evidently.legacy.tests.base_test import Test
 35  from evidently.legacy.tests.base_test import TestParameters
 36  from evidently.legacy.tests.base_test import TestResult
 37  from evidently.legacy.tests.base_test import TestStatus
 38  from evidently.legacy.tests.base_test import TestValueCondition
 39  from evidently.legacy.utils.data_drift_utils import resolve_stattest_threshold
 40  from evidently.legacy.utils.data_preprocessing import DataDefinition
 41  from evidently.legacy.utils.generators import BaseGenerator
 42  from evidently.legacy.utils.types import Numeric
 43  from evidently.legacy.utils.visualizations import plot_contour_single
 44  from evidently.legacy.utils.visualizations import plot_distr_with_cond_perc_button
 45  
# Shared grouping for all data-drift tests; registering it makes the group
# available for report/test-suite grouping by `GroupingTypes.TestGroup`.
DATA_DRIFT_GROUP = GroupData(id="data_drift", title="Data Drift", description="")
GroupingTypes.TestGroup.add_value(DATA_DRIFT_GROUP)
 48  
 49  
 50  class ColumnDriftParameter(ExcludeNoneMixin, TestParameters):  # type: ignore[misc] # pydantic Config
 51      class Config:
 52          type_alias = "evidently:test_parameters:ColumnDriftParameter"
 53  
 54      stattest: str
 55      score: float
 56      threshold: float
 57      detected: bool
 58      column_name: Optional[str] = None
 59  
 60      @classmethod
 61      def from_metric(cls, data: ColumnDataDriftMetrics, column_name: str = None):
 62          return ColumnDriftParameter(
 63              stattest=data.stattest_name,
 64              score=np.round(data.drift_score, 3),
 65              threshold=data.stattest_threshold,
 66              detected=data.drift_detected,
 67              column_name=column_name,
 68          )
 69  
 70  
 71  class ColumnsDriftParameters(ConditionTestParameters):
 72      # todo: rename to columns?
 73      class Config:
 74          type_alias = "evidently:test_parameters:ColumnsDriftParameters"
 75  
 76      features: Dict[str, ColumnDriftParameter]
 77  
 78      @classmethod
 79      def from_data_drift_table(cls, table: DataDriftTableResults, condition: TestValueCondition):
 80          return ColumnsDriftParameters(
 81              features={
 82                  feature: ColumnDriftParameter.from_metric(data) for feature, data in table.drift_by_columns.items()
 83              },
 84              condition=condition,
 85          )
 86  
 87      def to_dataframe(self) -> pd.DataFrame:
 88          return pd.DataFrame(
 89              [
 90                  {
 91                      "Feature name": feature,
 92                      "Stattest": data.stattest,
 93                      "Drift score": data.score,
 94                      "Threshold": data.threshold,
 95                      "Data Drift": "Detected" if data.detected else "Not detected",
 96                  }
 97                  for feature, data in self.features.items()
 98              ],
 99          )
100  
101  
class BaseDataDriftMetricsTest(BaseCheckValueTest, WithDriftOptionsFields, ABC):
    """Common base for dataset-level drift tests backed by a DataDriftTable metric."""

    group: ClassVar = DATA_DRIFT_GROUP.id
    _metric: DataDriftTable
    columns: Optional[List[str]]
    feature_importance: Optional[bool]

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        eq: Optional[Numeric] = None,
        gt: Optional[Numeric] = None,
        gte: Optional[Numeric] = None,
        is_in: Optional[List[Union[Numeric, str, bool]]] = None,
        lt: Optional[Numeric] = None,
        lte: Optional[Numeric] = None,
        not_eq: Optional[Numeric] = None,
        not_in: Optional[List[Union[Numeric, str, bool]]] = None,
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
        feature_importance: Optional[bool] = False,
    ):
        # The drift-configuration knobs are forwarded to the base __init__ (which
        # stores them as fields) and then read back from `self` to configure the
        # underlying metric — exactly mirroring the original flow.
        drift_options = dict(
            columns=columns,
            stattest=stattest,
            cat_stattest=cat_stattest,
            num_stattest=num_stattest,
            text_stattest=text_stattest,
            per_column_stattest=per_column_stattest,
            stattest_threshold=stattest_threshold,
            cat_stattest_threshold=cat_stattest_threshold,
            num_stattest_threshold=num_stattest_threshold,
            text_stattest_threshold=text_stattest_threshold,
            per_column_stattest_threshold=per_column_stattest_threshold,
            feature_importance=feature_importance,
        )
        super().__init__(
            eq=eq,
            gt=gt,
            gte=gte,
            is_in=is_in,
            lt=lt,
            lte=lte,
            not_eq=not_eq,
            not_in=not_in,
            is_critical=is_critical,
            **drift_options,
        )
        # Build the metric from the stored field values (not the raw arguments).
        self._metric = DataDriftTable(**{name: getattr(self, name) for name in drift_options})

    @property
    def metric(self):
        # Created once in __init__, exposed read-only.
        return self._metric

    def check(self):
        base_result = super().check()
        table_result = self.metric.get_result()
        # Re-wrap the base result, attaching per-column drift parameters so that
        # renderers can show the full drift table.
        return TestResult(
            name=base_result.name,
            description=base_result.description,
            status=TestStatus(base_result.status),
            group=self.group,
            parameters=ColumnsDriftParameters.from_data_drift_table(table_result, self.get_condition()),
        )
185  
186  
class TestNumberOfDriftedColumns(BaseDataDriftMetricsTest):
    """Test that the number of drifted columns stays below a threshold."""

    class Config:
        type_alias = "evidently:test:TestNumberOfDriftedColumns"

    name: ClassVar = "Number of Drifted Features"

    def get_condition(self) -> TestValueCondition:
        """A user-supplied condition wins; otherwise fail when a third or more columns drift."""
        if self.condition.has_condition():
            return self.condition
        total = self.metric.get_result().number_of_columns
        return TestValueCondition(lt=max(0, total // 3))

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().number_of_drifted_columns

    def get_description(self, value: Numeric) -> str:
        total = self.metric.get_result().number_of_columns
        return (
            f"The drift is detected for {value} out of {total} features. "
            f"The test threshold is {self.get_condition()}."
        )
208  
209  
class TestShareOfDriftedColumns(BaseDataDriftMetricsTest):
    """Test that the share of drifted columns stays below a threshold."""

    class Config:
        type_alias = "evidently:test:TestShareOfDriftedColumns"

    name: ClassVar = "Share of Drifted Columns"

    def get_condition(self) -> TestValueCondition:
        """A user-supplied condition wins; otherwise fail when 30% or more columns drift."""
        if self.condition.has_condition():
            return self.condition
        return TestValueCondition(lt=0.3)

    def calculate_value_for_test(self) -> Numeric:
        return self.metric.get_result().share_of_drifted_columns

    def get_description(self, value: Numeric) -> str:
        table_result = self.metric.get_result()
        drifted = table_result.number_of_drifted_columns
        total = table_result.number_of_columns
        return (
            f"The drift is detected for {value * 100:.3g}% features "
            f"({drifted} out of {total}). The test threshold is {self.get_condition()}"
        )
232  
233  
class TestColumnDrift(Test):
    """Per-column drift test built on top of ColumnDriftMetric."""

    class Config:
        type_alias = "evidently:test:TestColumnDrift"

    name: ClassVar = "Drift per Column"
    group: ClassVar = DATA_DRIFT_GROUP.id
    _metric: ColumnDriftMetric
    column_name: ColumnName
    stattest: Optional[PossibleStatTestType] = None
    stattest_threshold: Optional[float] = None

    def __init__(
        self,
        column_name: Union[str, ColumnName],
        stattest: Optional[PossibleStatTestType] = None,
        stattest_threshold: Optional[float] = None,
        is_critical: bool = True,
    ):
        # Normalize plain strings into a ColumnName before the base init runs.
        self.column_name = ColumnName.from_any(column_name)
        self.stattest = stattest
        self.stattest_threshold = stattest_threshold

        super().__init__(is_critical=is_critical)
        self._metric = ColumnDriftMetric(
            column_name=self.column_name,
            stattest=self.stattest,
            stattest_threshold=self.stattest_threshold,
        )

    @property
    def metric(self):
        return self._metric

    def check(self):
        """Report FAIL when drift is detected for the column, SUCCESS otherwise."""
        drift_info = self.metric.get_result()

        score = np.round(drift_info.drift_score, 3)
        description = (
            f"The drift score for the feature **{self.column_name.display_name}** is {score:.3g}. "
            f"The drift detection method is {drift_info.stattest_name}. "
            f"The drift detection threshold is {drift_info.stattest_threshold}."
        )
        status = TestStatus.FAIL if drift_info.drift_detected else TestStatus.SUCCESS

        return TestResult(
            name=self.name,
            description=description,
            status=status,
            group=self.group,
            parameters=ColumnDriftParameter.from_metric(drift_info, column_name=self.column_name.display_name),
        )

    def groups(self) -> Dict[str, str]:
        return {GroupingTypes.ByFeature.id: self.column_name.display_name}
297  
298  
class TestAllFeaturesValueDrift(BaseGenerator):
    """Create value drift tests for categorical, numerical and text features.

    Refactor: the original repeated an identical loop body three times (one per
    column type); the per-column test construction is now a single helper.
    """

    columns: Optional[List[str]]
    stattest: Optional[PossibleStatTestType]
    cat_stattest: Optional[PossibleStatTestType]
    num_stattest: Optional[PossibleStatTestType]
    text_stattest: Optional[PossibleStatTestType]
    per_column_stattest: Optional[Dict[str, PossibleStatTestType]]
    stattest_threshold: Optional[float]
    cat_stattest_threshold: Optional[float]
    num_stattest_threshold: Optional[float]
    text_stattest_threshold: Optional[float]
    per_column_stattest_threshold: Optional[Dict[str, float]]

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
    ):
        self.is_critical = is_critical
        self.columns = columns
        self.stattest = stattest
        self.cat_stattest = cat_stattest
        self.num_stattest = num_stattest
        self.text_stattest = text_stattest
        self.per_column_stattest = per_column_stattest
        self.stattest_threshold = stattest_threshold
        self.cat_stattest_threshold = cat_stattest_threshold
        self.num_stattest_threshold = num_stattest_threshold
        self.text_stattest_threshold = text_stattest_threshold
        self.per_column_stattest_threshold = per_column_stattest_threshold

    def _make_test(self, column_name: str, column_kind: str) -> TestColumnDrift:
        # Resolve the effective stattest/threshold for this column from the
        # global, per-type and per-column overrides, then build the test.
        stattest, threshold = resolve_stattest_threshold(
            column_name,
            column_kind,
            self.stattest,
            self.cat_stattest,
            self.num_stattest,
            self.text_stattest,
            self.per_column_stattest,
            self.stattest_threshold,
            self.cat_stattest_threshold,
            self.num_stattest_threshold,
            self.text_stattest_threshold,
            self.per_column_stattest_threshold,
        )
        return TestColumnDrift(
            column_name=column_name,
            stattest=stattest,
            stattest_threshold=threshold,
            is_critical=self.is_critical,
        )

    def generate(self, data_definition: DataDefinition) -> List[TestColumnDrift]:
        """Generate one TestColumnDrift per feature, optionally filtered by `self.columns`."""
        results: List[TestColumnDrift] = []
        # Same traversal order as the original: categorical, then numerical, then text.
        for column_type, kind in (
            (ColumnType.Categorical, "cat"),
            (ColumnType.Numerical, "num"),
            (ColumnType.Text, "text"),
        ):
            for column in data_definition.get_columns(column_type, features_only=True):
                if self.columns and column.column_name not in self.columns:
                    continue
                results.append(self._make_test(column.column_name, kind))
        return results
420  
421  
class TestCustomFeaturesValueDrift(BaseGenerator):
    """Create value drift tests for the specified features.

    Bug fix: ``per_column_stattest_threshold`` was stored under the wrong
    attribute name (``per_feature_threshold``), so per-column thresholds were
    silently ignored by :meth:`generate` (the class-level ``None`` default
    masked the mistake).
    """

    features: List[str]
    stattest: Optional[PossibleStatTestType] = None
    cat_stattest: Optional[PossibleStatTestType] = None
    num_stattest: Optional[PossibleStatTestType] = None
    text_stattest: Optional[PossibleStatTestType] = None
    per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None
    stattest_threshold: Optional[float] = None
    cat_stattest_threshold: Optional[float] = None
    num_stattest_threshold: Optional[float] = None
    text_stattest_threshold: Optional[float] = None
    per_column_stattest_threshold: Optional[Dict[str, float]] = None

    def __init__(
        self,
        features: List[str],
        stattest: Optional[PossibleStatTestType] = None,
        cat_stattest: Optional[PossibleStatTestType] = None,
        num_stattest: Optional[PossibleStatTestType] = None,
        text_stattest: Optional[PossibleStatTestType] = None,
        per_column_stattest: Optional[Dict[str, PossibleStatTestType]] = None,
        stattest_threshold: Optional[float] = None,
        cat_stattest_threshold: Optional[float] = None,
        num_stattest_threshold: Optional[float] = None,
        text_stattest_threshold: Optional[float] = None,
        per_column_stattest_threshold: Optional[Dict[str, float]] = None,
        is_critical: bool = True,
    ):
        self.is_critical = is_critical
        self.features = features
        self.stattest = stattest
        self.cat_stattest = cat_stattest
        self.num_stattest = num_stattest
        self.text_stattest = text_stattest
        self.per_column_stattest = per_column_stattest
        self.stattest_threshold = stattest_threshold
        self.cat_stattest_threshold = cat_stattest_threshold
        self.num_stattest_threshold = num_stattest_threshold
        self.text_stattest_threshold = text_stattest_threshold
        # Was: self.per_feature_threshold = per_column_stattest_threshold (wrong name).
        self.per_column_stattest_threshold = per_column_stattest_threshold

    def generate(self, data_definition: DataDefinition) -> List[TestColumnDrift]:
        """Generate one TestColumnDrift for each requested feature name."""
        result = []
        for name in self.features:
            column = data_definition.get_column(name)
            # Map the column type onto the short code expected by resolve_stattest_threshold.
            if column.column_type == ColumnType.Categorical:
                kind = "cat"
            elif column.column_type == ColumnType.Numerical:
                kind = "num"
            elif column.column_type == ColumnType.Text:
                kind = "text"
            else:
                kind = "datetime"
            stattest, threshold = resolve_stattest_threshold(
                name,
                kind,
                self.stattest,
                self.cat_stattest,
                self.num_stattest,
                self.text_stattest,
                self.per_column_stattest,
                self.stattest_threshold,
                self.cat_stattest_threshold,
                self.num_stattest_threshold,
                self.text_stattest_threshold,
                self.per_column_stattest_threshold,
            )
            result.append(
                TestColumnDrift(
                    column_name=name,
                    stattest=stattest,
                    stattest_threshold=threshold,
                    is_critical=self.is_critical,
                )
            )
        return result
498  
499  
@default_renderer(wrap_type=TestNumberOfDriftedColumns)
class TestNumberOfDriftedColumnsRenderer(TestRenderer):
    def render_html(self, obj: TestNumberOfDriftedColumns) -> TestHtmlInfo:
        """Attach the per-column drift table (plus feature importances when available)."""
        info = super().render_html(obj)
        result = obj.get_result()
        if result.status == TestStatus.ERROR:
            # Nothing to render beyond the base message when the test errored.
            return info
        parameters = result.parameters
        assert isinstance(parameters, ColumnsDriftParameters)
        table = parameters.to_dataframe().sort_values("Data Drift")
        metric_result = obj.metric.get_result()
        ordered_columns = ["Feature name"]
        # Insert importance columns right after the feature name, when computed.
        for importances, col_name in (
            (metric_result.current_fi, "current_feature_importance"),
            (metric_result.reference_fi, "reference_feature_importance"),
        ):
            if importances is not None:
                table[col_name] = table["Feature name"].apply(lambda x, fi=importances: fi.get(x, ""))
                ordered_columns.append(col_name)
        ordered_columns += ["Stattest", "Drift score", "Threshold", "Data Drift"]
        table = table[ordered_columns]
        info.with_details(
            title="Drift Table",
            info=table_data(column_names=table.columns.to_list(), data=table.values),
        )
        return info
527  
528  
@default_renderer(wrap_type=TestShareOfDriftedColumns)
class TestShareOfDriftedColumnsRenderer(TestRenderer):
    def render_html(self, obj: TestShareOfDriftedColumns) -> TestHtmlInfo:
        """Replace the details with a single per-column drift table widget."""
        info = super().render_html(obj)
        result = obj.get_result()
        if result.status == TestStatus.ERROR:
            return info
        parameters = result.parameters
        assert isinstance(parameters, ColumnsDriftParameters)
        metric_result = obj.metric.get_result()
        table = parameters.to_dataframe().sort_values("Data Drift")
        ordered_columns = ["Feature name"]
        # Insert importance columns right after the feature name, when computed.
        for importances, col_name in (
            (metric_result.current_fi, "current_feature_importance"),
            (metric_result.reference_fi, "reference_feature_importance"),
        ):
            if importances is not None:
                table[col_name] = table["Feature name"].apply(lambda x, fi=importances: fi.get(x, ""))
                ordered_columns.append(col_name)
        ordered_columns += ["Stattest", "Drift score", "Threshold", "Data Drift"]
        table = table[ordered_columns]
        # NOTE: unlike the sibling renderer, this one *replaces* the details list.
        info.details = [
            DetailsInfo(
                id="drift_table",
                title="",
                info=BaseWidgetInfo(
                    title="",
                    type="table",
                    params={"header": table.columns.to_list(), "data": table.values},
                    size=2,
                ),
            ),
        ]
        return info
564  
565  
@default_renderer(wrap_type=TestColumnDrift)
class TestColumnDriftRenderer(TestRenderer):
    @staticmethod
    def _characteristic_words_table(widget_id: str, title: str, words: List[str]) -> DetailsInfo:
        # Two-column table widget listing characteristic words (second column blank).
        return DetailsInfo(
            id=widget_id,
            title=title,
            info=BaseWidgetInfo(
                title="",
                type="table",
                params={
                    "header": ["", ""],
                    "data": [[el, ""] for el in words],
                },
                size=2,
            ),
        )

    def render_html(self, obj: TestColumnDrift) -> TestHtmlInfo:
        """Render drift details: characteristic-word tables for text columns,
        a distribution plot for all other column types.

        Fix: the widget ids contained the typo "dritf"; duplicated table
        construction extracted into ``_characteristic_words_table``.
        """
        info = super().render_html(obj)
        result = obj.metric.get_result()
        column_name = obj.column_name
        if result.column_type == "text":
            # Word tables are only rendered when both sides were computed.
            if result.current.characteristic_words is not None and result.reference.characteristic_words is not None:
                info.details = [
                    self._characteristic_words_table(
                        f"{column_name} drift curr",
                        "current: characteristic words",
                        result.current.characteristic_words,
                    ),
                    self._characteristic_words_table(
                        f"{column_name} drift ref",
                        "reference: characteristic words",
                        result.reference.characteristic_words,
                    ),
                ]
            return info
        if result.current.distribution is None:
            raise ValueError("Expected data is missing")
        fig = plot_distr_with_cond_perc_button(
            hist_curr=HistogramData.from_distribution(result.current.distribution),
            hist_ref=HistogramData.from_distribution(result.reference.distribution),
            xaxis_name="",
            yaxis_name="count",
            yaxis_name_perc="percent",
            color_options=self.color_options,
            to_json=False,
            condition=None,
        )
        info.with_details(f"{column_name}", plotly_figure(title="", figure=fig))
        return info
619  
620  
class TestEmbeddingsDrift(Test):
    """Drift test for a named embedding set, backed by EmbeddingsDriftMetric."""

    class Config:
        type_alias = "evidently:test:TestEmbeddingsDrift"

    name: ClassVar = "Drift for embeddings"
    group: ClassVar = DATA_DRIFT_GROUP.id
    embeddings_name: str
    drift_method: Optional[DriftMethod]
    _metric: EmbeddingsDriftMetric

    def __init__(self, embeddings_name: str, drift_method: Optional[DriftMethod] = None, is_critical: bool = True):
        self.embeddings_name = embeddings_name
        self.drift_method = drift_method
        super().__init__(is_critical=is_critical)
        self._metric = EmbeddingsDriftMetric(embeddings_name=self.embeddings_name, drift_method=self.drift_method)

    @property
    def metric(self):
        return self._metric

    def check(self):
        """Report FAIL when drift is detected for the embedding set, SUCCESS otherwise."""
        drift_info = self.metric.get_result()
        detected = drift_info.drift_detected
        drift = "detected" if detected else "not detected"
        description = (
            f"Data drift {drift}. "
            f"The drift score for the embedding set **{drift_info.embeddings_name}** is {drift_info.drift_score:.3g}. "
            f"The drift detection method is **{drift_info.method_name}**. "
        )
        return TestResult(
            name=self.name,
            description=description,
            status=TestStatus.FAIL if detected else TestStatus.SUCCESS,
            group=self.group,
        )

    def groups(self) -> Dict[str, str]:
        # Embedding drift is not attributed to a single feature.
        return {}
669  
670  
@default_renderer(wrap_type=TestEmbeddingsDrift)
class TestEmbeddingsDriftRenderer(TestRenderer):
    def render_html(self, obj: TestEmbeddingsDrift) -> TestHtmlInfo:
        """Add a 2D contour plot comparing current and reference embeddings."""
        html_info = super().render_html(obj)
        metric_result = obj.metric.get_result()
        contour = plot_contour_single(metric_result.current, metric_result.reference, "component 1", "component 2")
        html_info.with_details(
            f"Drift in embeddings '{metric_result.embeddings_name}'",
            plotly_figure(title="", figure=contour),
        )
        return html_info