# embeddings_drift.py
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import pandas as pd
from sklearn.manifold import TSNE

from evidently.legacy.base_metric import InputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.core import IncludeTags
from evidently.legacy.metrics.data_drift.embedding_drift_methods import DriftMethod
from evidently.legacy.metrics.data_drift.embedding_drift_methods import model
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.options.base import AnyOptions
from evidently.legacy.renderers.base_renderer import MetricRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import CounterData
from evidently.legacy.renderers.html_widgets import WidgetSize
from evidently.legacy.renderers.html_widgets import counter
from evidently.legacy.renderers.html_widgets import plotly_figure
from evidently.legacy.utils.visualizations import get_gaussian_kde
from evidently.legacy.utils.visualizations import plot_contour_single

# Maximum number of rows sampled from each dataset for the 2D visualisation.
SAMPLE_CONSTANT = 2500
# Seed shared by the row sampling and the t-SNE projection so that the stored
# density arrays (and therefore the rendered plot) are reproducible.
_RANDOM_STATE = 24
# scikit-learn's default TSNE perplexity. TSNE requires perplexity < n_samples,
# so it is lowered only when very few rows are available.
_TSNE_DEFAULT_PERPLEXITY = 30.0


def _embeddings_density_2d(
    reference_emb: pd.DataFrame, current_emb: pd.DataFrame
) -> Tuple[np.ndarray, np.ndarray]:
    """Project sampled embeddings to 2D with t-SNE and estimate a KDE per dataset.

    Up to ``SAMPLE_CONSTANT`` rows are sampled from each frame. Both samples are
    projected jointly, so they share one 2D space. ``get_gaussian_kde`` is then
    applied separately to the reference and current points.

    Returns:
        ``(reference, current)``: the first value returned by
        ``get_gaussian_kde`` for each dataset. This is presumably a density grid
        consumed by ``plot_contour_single`` — TODO confirm.
    """
    ref_sample_size = min(SAMPLE_CONSTANT, reference_emb.shape[0])
    curr_sample_size = min(SAMPLE_CONSTANT, current_emb.shape[0])
    ref_sample = reference_emb.sample(ref_sample_size, random_state=_RANDOM_STATE)
    curr_sample = current_emb.sample(curr_sample_size, random_state=_RANDOM_STATE)
    combined = pd.concat([ref_sample, curr_sample])
    # Keep perplexity strictly below the number of points; otherwise TSNE raises
    # for tiny datasets. For normal sizes this equals the default (30.0).
    perplexity = float(min(_TSNE_DEFAULT_PERPLEXITY, max(combined.shape[0] - 1, 1)))
    data_2d = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=_RANDOM_STATE,
    ).fit_transform(combined)
    # Reference rows come first in `combined`, so split the projection at ref_sample_size.
    reference, _, _ = get_gaussian_kde(data_2d[:ref_sample_size, 0], data_2d[:ref_sample_size, 1])
    current, _, _ = get_gaussian_kde(data_2d[ref_sample_size:, 0], data_2d[ref_sample_size:, 1])
    return reference, current


class EmbeddingsDriftMetricResults(MetricResult):
    """Result of an embeddings drift check, plus the arrays needed to plot it."""

    class Config:
        type_alias = "evidently:metric_result:EmbeddingsDriftMetricResults"
        # The density arrays are only needed for rendering and are excluded
        # from dict output.
        dict_exclude_fields = {
            "reference",
            "current",
        }

        field_tags = {
            "current": {IncludeTags.Current, IncludeTags.Render},
            "reference": {IncludeTags.Reference, IncludeTags.Render},
            "embeddings_name": {IncludeTags.Parameter},
            "method_name": {IncludeTags.Parameter},
        }

    embeddings_name: str  # key into data_definition.embeddings
    drift_score: float  # score reported by the drift method
    drift_detected: bool  # verdict reported by the drift method
    method_name: str  # human-readable name reported by the drift method
    reference: np.ndarray  # first output of get_gaussian_kde for the reference sample
    current: np.ndarray  # first output of get_gaussian_kde for the current sample


class EmbeddingsDriftMetric(Metric[EmbeddingsDriftMetricResults]):
    """Detect drift between reference and current values of one embeddings group.

    Args:
        embeddings_name: name of the embeddings group declared in the column
            mapping (``data_definition.embeddings``).
        drift_method: drift detection method. When ``None``, ``model(...)`` is
            used, with bootstrapping enabled when the reference data has fewer
            than 1000 rows.
        options: metric options forwarded to ``Metric``.
    """

    class Config:
        type_alias = "evidently:metric:EmbeddingsDriftMetric"

    embeddings_name: str
    drift_method: Optional[DriftMethod]

    def __init__(self, embeddings_name: str, drift_method: Optional[DriftMethod] = None, options: AnyOptions = None):
        self.embeddings_name = embeddings_name
        self.drift_method = drift_method
        super().__init__(options=options)

    def calculate(self, data: InputData) -> EmbeddingsDriftMetricResults:
        """Run the drift method on the embedding columns and build plot data.

        Raises:
            ValueError: if reference data is missing, no embeddings are defined,
                or ``embeddings_name`` is not among the defined embeddings.
        """
        if data.reference_data is None:
            raise ValueError("Reference dataset should be present")
        # The default method bootstraps only when the reference data is small.
        drift_method = self.drift_method or model(bootstrap=data.reference_data.shape[0] < 1000)
        emb_dict = data.data_definition.embeddings
        if emb_dict is None:
            raise ValueError("Embeddings should be defined in column mapping")
        if self.embeddings_name not in emb_dict:
            raise ValueError(f"{self.embeddings_name} not in column_mapping.embeddings")
        emb_list = emb_dict[self.embeddings_name]
        # Drift methods take (current, reference) in that order.
        drift_score, drift_detected, method_name = drift_method(
            data.current_data[emb_list], data.reference_data[emb_list]
        )
        # Visualisation: fresh column selections, so the plot never depends on
        # anything the drift method might have done to its inputs.
        reference, current = _embeddings_density_2d(
            data.reference_data[emb_list], data.current_data[emb_list]
        )

        return EmbeddingsDriftMetricResults(
            embeddings_name=self.embeddings_name,
            drift_score=drift_score,
            drift_detected=drift_detected,
            method_name=method_name,
            reference=reference,
            current=current,
        )


@default_renderer(wrap_type=EmbeddingsDriftMetric)
class EmbeddingsDriftMetricRenderer(MetricRenderer):
    """HTML renderer: a summary counter followed by a 2D density contour plot."""

    def render_html(self, obj: EmbeddingsDriftMetric) -> List[BaseWidgetInfo]:
        result = obj.get_result()
        drift = "detected" if result.drift_detected else "not detected"
        drift_score = round(result.drift_score, 3)
        summary = (
            f"Data drift {drift}. "
            f"Drift detection method: {result.method_name}. "
            f"Drift score: {drift_score}"
        )
        fig = plot_contour_single(result.current, result.reference, "component 1", "component 2")
        return [
            counter(
                counters=[
                    CounterData(
                        summary,
                        f"Drift in embeddings '{result.embeddings_name}'",
                    )
                ],
                title="",
            ),
            plotly_figure(title="", figure=fig, size=WidgetSize.FULL),
        ]