/ src / evidently / legacy / metrics / data_drift / embeddings_drift.py
embeddings_drift.py
  1  from typing import List
  2  from typing import Optional
  3  
  4  import numpy as np
  5  import pandas as pd
  6  from sklearn.manifold import TSNE
  7  
  8  from evidently.legacy.base_metric import InputData
  9  from evidently.legacy.base_metric import Metric
 10  from evidently.legacy.base_metric import MetricResult
 11  from evidently.legacy.core import IncludeTags
 12  from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
 13  from evidently.legacy.metrics.data_drift.embedding_drift_methods import model
 14  from evidently.legacy.model.widget import BaseWidgetInfo
 15  from evidently.legacy.options.base import AnyOptions
 16  from evidently.legacy.renderers.base_renderer import MetricRenderer
 17  from evidently.legacy.renderers.base_renderer import default_renderer
 18  from evidently.legacy.renderers.html_widgets import CounterData
 19  from evidently.legacy.renderers.html_widgets import WidgetSize
 20  from evidently.legacy.renderers.html_widgets import counter
 21  from evidently.legacy.renderers.html_widgets import plotly_figure
 22  from evidently.legacy.utils.visualizations import get_gaussian_kde
 23  from evidently.legacy.utils.visualizations import plot_contour_single
 24  
# Upper bound on the number of rows sampled from each dataset (reference and
# current) before projecting them with t-SNE for the visualisation built in
# EmbeddingsDriftMetric.calculate.
SAMPLE_CONSTANT = 2500
 26  
 27  
class EmbeddingsDriftMetricResults(MetricResult):
    """Result of ``EmbeddingsDriftMetric`` for one named group of embedding columns.

    Holds the drift verdict produced by the drift method plus the density
    arrays (``reference`` / ``current``) used by the renderer's contour plot.
    """

    class Config:
        type_alias = "evidently:metric_result:EmbeddingsDriftMetricResults"
        # The density arrays are only needed for rendering; presumably the
        # MetricResult base uses this set to omit them from dict output — confirm.
        dict_exclude_fields = {
            "reference",
            "current",
        }

        field_tags = {
            "current": {IncludeTags.Current, IncludeTags.Render},
            "reference": {IncludeTags.Reference, IncludeTags.Render},
            "embeddings_name": {IncludeTags.Parameter},
            "method_name": {IncludeTags.Parameter},
        }

    # Key of the embeddings group (in the data definition) that was checked.
    embeddings_name: str
    # Score returned by the drift method.
    drift_score: float
    # Drift verdict returned by the drift method.
    drift_detected: bool
    # Name of the drift method, as returned by the drift method itself.
    method_name: str
    # First value returned by get_gaussian_kde over the reference sample's
    # 2D t-SNE projection.
    reference: np.ndarray
    # First value returned by get_gaussian_kde over the current sample's
    # 2D t-SNE projection.
    current: np.ndarray
 49  
 50  
 51  class EmbeddingsDriftMetric(Metric[EmbeddingsDriftMetricResults]):
 52      class Config:
 53          type_alias = "evidently:metric:EmbeddingsDriftMetric"
 54  
 55      embeddings_name: str
 56      drift_method: Optional[DriftMethod]
 57  
 58      def __init__(self, embeddings_name: str, drift_method: Optional[DriftMethod] = None, options: AnyOptions = None):
 59          self.embeddings_name = embeddings_name
 60          self.drift_method = drift_method
 61          super().__init__(options=options)
 62  
 63      def calculate(self, data: InputData) -> EmbeddingsDriftMetricResults:
 64          if data.reference_data is None:
 65              raise ValueError("Reference dataset should be present")
 66          drift_method = self.drift_method or model(bootstrap=data.reference_data.shape[0] < 1000)
 67          emb_dict = data.data_definition.embeddings
 68          if emb_dict is None:
 69              raise ValueError("Embeddings should be defined in column mapping")
 70          if self.embeddings_name not in emb_dict.keys():
 71              raise ValueError(f"{self.embeddings_name} not in column_mapping.embeddings")
 72          emb_list = emb_dict[self.embeddings_name]
 73          drift_score, drift_detected, method_name = drift_method(
 74              data.current_data[emb_list], data.reference_data[emb_list]
 75          )
 76          # visualisation
 77          ref_sample_size = min(SAMPLE_CONSTANT, data.reference_data.shape[0])
 78          curr_sample_size = min(SAMPLE_CONSTANT, data.current_data.shape[0])
 79          ref_sample = data.reference_data[emb_list].sample(ref_sample_size, random_state=24)
 80          curr_sample = data.current_data[emb_list].sample(curr_sample_size, random_state=24)
 81          data_2d = TSNE(n_components=2).fit_transform(pd.concat([ref_sample, curr_sample]))
 82          reference, _, _ = get_gaussian_kde(data_2d[:ref_sample_size, 0], data_2d[:ref_sample_size, 1])
 83          current, _, _ = get_gaussian_kde(data_2d[ref_sample_size:, 0], data_2d[ref_sample_size:, 1])
 84  
 85          return EmbeddingsDriftMetricResults(
 86              embeddings_name=self.embeddings_name,
 87              drift_score=drift_score,
 88              drift_detected=drift_detected,
 89              method_name=method_name,
 90              reference=reference,
 91              current=current,
 92          )
 93  
 94  
 95  @default_renderer(wrap_type=EmbeddingsDriftMetric)
 96  class EmbeddingsDriftMetricRenderer(MetricRenderer):
 97      def render_html(self, obj: EmbeddingsDriftMetric) -> List[BaseWidgetInfo]:
 98          result = obj.get_result()
 99          if result.drift_detected:
100              drift = "detected"
101  
102          else:
103              drift = "not detected"
104          drift_score = round(result.drift_score, 3)
105          fig = plot_contour_single(result.current, result.reference, "component 1", "component 2")
106          return [
107              counter(
108                  counters=[
109                      CounterData(
110                          (
111                              f"Data drift {drift}. "
112                              f"Drift detection method: {result.method_name}. "
113                              f"Drift score: {drift_score}"
114                          ),
115                          f"Drift in embeddings '{result.embeddings_name}'",
116                      )
117                  ],
118                  title="",
119              ),
120              plotly_figure(title="", figure=fig, size=WidgetSize.FULL),
121          ]