/ src / evidently / legacy / metrics / data_drift / column_interaction_plot.py
column_interaction_plot.py
  1  import json
  2  from typing import Any
  3  from typing import Dict
  4  from typing import List
  5  from typing import Optional
  6  from typing import Union
  7  
  8  import numpy as np
  9  import pandas as pd
 10  
 11  from evidently.legacy.base_metric import InputData
 12  from evidently.legacy.base_metric import Metric
 13  from evidently.legacy.base_metric import MetricResult
 14  from evidently.legacy.base_metric import UsesRawDataMixin
 15  from evidently.legacy.calculations.utils import get_data_for_cat_cat_plot
 16  from evidently.legacy.calculations.utils import get_data_for_num_num_plot
 17  from evidently.legacy.calculations.utils import prepare_box_data
 18  from evidently.legacy.calculations.utils import prepare_data_for_date_cat
 19  from evidently.legacy.calculations.utils import prepare_data_for_date_num
 20  from evidently.legacy.calculations.utils import relabel_data
 21  from evidently.legacy.core import ColumnType
 22  from evidently.legacy.core import IncludeTags
 23  from evidently.legacy.metric_results import ColumnScatter
 24  from evidently.legacy.metric_results import ContourData
 25  from evidently.legacy.model.widget import BaseWidgetInfo
 26  from evidently.legacy.options.base import AnyOptions
 27  from evidently.legacy.renderers.base_renderer import MetricRenderer
 28  from evidently.legacy.renderers.base_renderer import default_renderer
 29  from evidently.legacy.renderers.html_widgets import header_text
 30  from evidently.legacy.utils.visualizations import plot_boxes
 31  from evidently.legacy.utils.visualizations import plot_cat_cat_rel
 32  from evidently.legacy.utils.visualizations import plot_cat_feature_in_time
 33  from evidently.legacy.utils.visualizations import plot_contour
 34  from evidently.legacy.utils.visualizations import plot_num_feature_in_time
 35  from evidently.legacy.utils.visualizations import plot_num_num_rel
 36  
 37  
 38  class ColumnInteractionPlotResults(MetricResult):
 39      class Config:
 40          type_alias = "evidently:metric_result:ColumnInteractionPlotResults"
 41          dict_include = False
 42          pd_include = False
 43          tags = {IncludeTags.Render}
 44          field_tags = {
 45              "current": {IncludeTags.Current},
 46              "reference": {IncludeTags.Reference},
 47              "current_scatter": {IncludeTags.Current},
 48              "current_contour": {IncludeTags.Current},
 49              "current_boxes": {IncludeTags.Current},
 50              "reference_scatter": {IncludeTags.Reference},
 51              "reference_contour": {IncludeTags.Reference},
 52              "reference_boxes": {IncludeTags.Reference},
 53          }
 54  
 55      y_type: ColumnType
 56      x_type: ColumnType
 57      current_scatter: Optional[ColumnScatter]
 58      current_contour: Optional[ContourData]
 59      current_boxes: Optional[Dict[str, Union[list, np.ndarray]]]
 60      current: Optional[pd.DataFrame]
 61      reference_scatter: Optional[ColumnScatter]
 62      reference_contour: Optional[ContourData]
 63      reference_boxes: Optional[Dict[str, Union[list, np.ndarray]]]
 64      reference: Optional[pd.DataFrame]
 65      prefix: Optional[str] = None
 66  
 67  
 68  class ColumnInteractionPlot(UsesRawDataMixin, Metric[ColumnInteractionPlotResults]):
 69      class Config:
 70          type_alias = "evidently:metric:ColumnInteractionPlot"
 71  
 72      x_column: str
 73      y_column: str
 74  
 75      def __init__(self, x_column: str, y_column: str, options: AnyOptions = None):
 76          self.x_column = x_column
 77          self.y_column = y_column
 78          super().__init__(options=options)
 79  
 80      def calculate(self, data: InputData) -> ColumnInteractionPlotResults:
 81          for col in [self.x_column, self.y_column]:
 82              if not data.has_column(col):
 83                  raise ValueError(f"Column '{col}' not found in dataset.")
 84  
 85          x_type, x_curr, x_ref = data.get_data(self.x_column)
 86          y_type, y_curr, y_ref = data.get_data(self.y_column)
 87          for column in [x_curr, x_ref, y_curr, y_ref]:
 88              if column is not None:
 89                  column.replace(to_replace=[np.inf, -np.inf], value=np.nan, inplace=True)
 90          if x_type == ColumnType.Categorical:
 91              x_curr, x_ref = relabel_data(x_curr, x_ref)
 92          if y_type == ColumnType.Categorical:
 93              y_curr, y_ref = relabel_data(y_curr, y_ref)
 94  
 95          agg_data = True
 96          if self.get_options().render_options.raw_data:
 97              agg_data = False
 98          if x_type == ColumnType.Numerical and y_type == ColumnType.Numerical:
 99              raw_plot, agg_plot = get_data_for_num_num_plot(
100                  agg_data,
101                  self.x_column,
102                  self.y_column,
103                  x_curr,
104                  y_curr,
105                  x_ref if x_ref is not None else None,
106                  y_ref if y_ref is not None else None,
107              )
108              if raw_plot is not None:
109                  return ColumnInteractionPlotResults(
110                      x_type=x_type,
111                      y_type=y_type,
112                      current_scatter=raw_plot["current"],
113                      reference_scatter=raw_plot.get("reference"),
114                  )
115              return ColumnInteractionPlotResults(
116                  x_type=x_type,
117                  y_type=y_type,
118                  current_contour=agg_plot["current"],
119                  reference_contour=agg_plot.get("reference"),
120              )
121          if x_type == ColumnType.Categorical and y_type == ColumnType.Categorical:
122              result = get_data_for_cat_cat_plot(
123                  self.x_column,
124                  self.y_column,
125                  x_curr,
126                  y_curr,
127                  x_ref if x_ref is not None else None,
128                  y_ref if y_ref is not None else None,
129              )
130              return ColumnInteractionPlotResults(
131                  x_type=x_type,
132                  y_type=y_type,
133                  current=result["current"],
134                  reference=result.get("reference"),
135              )
136          if (x_type == ColumnType.Categorical and y_type == ColumnType.Numerical) or (
137              x_type == ColumnType.Numerical and y_type == ColumnType.Categorical
138          ):
139              curr_df = pd.DataFrame({self.x_column: x_curr, self.y_column: y_curr})
140              ref_df = None
141              if x_ref is not None and y_ref is not None:
142                  ref_df = pd.DataFrame({self.x_column: x_ref, self.y_column: y_ref})
143              if x_type == ColumnType.Categorical:
144                  cat_name, num_name = self.x_column, self.y_column
145              else:
146                  cat_name, num_name = self.y_column, self.x_column
147              result = prepare_box_data(curr_df, ref_df, cat_name, num_name)
148              return ColumnInteractionPlotResults(
149                  x_type=x_type,
150                  y_type=y_type,
151                  current_boxes=result["current"],
152                  reference_boxes=result.get("reference"),
153              )
154          if (x_type == ColumnType.Numerical and y_type == ColumnType.Datetime) or (
155              x_type == ColumnType.Datetime and y_type == ColumnType.Numerical
156          ):
157              if x_type == ColumnType.Numerical:
158                  date_name, date_curr, date_ref = self.y_column, y_curr, y_ref
159                  num_name, num_curr, num_ref = self.x_column, x_curr, x_ref
160              else:
161                  date_name, date_curr, date_ref = self.x_column, x_curr, x_ref
162                  num_name, num_curr, num_ref = self.y_column, y_curr, y_ref
163              curr_res, ref_res, prefix = prepare_data_for_date_num(
164                  date_curr, date_ref, date_name, num_name, num_curr, num_ref
165              )
166              return ColumnInteractionPlotResults(
167                  x_type=x_type,
168                  y_type=y_type,
169                  current=curr_res,
170                  reference=ref_res,
171                  prefix=prefix,
172              )
173          if (x_type == ColumnType.Categorical and y_type == ColumnType.Datetime) or (
174              x_type == ColumnType.Datetime and y_type == ColumnType.Categorical
175          ):
176              if x_type == ColumnType.Categorical:
177                  date_name, date_curr, date_ref = self.y_column, y_curr, y_ref
178                  cat_name, cat_curr, cat_ref = self.x_column, x_curr, x_ref
179              else:
180                  date_name, date_curr, date_ref = self.x_column, x_curr, x_ref
181                  cat_name, cat_curr, cat_ref = self.y_column, y_curr, y_ref
182              curr_res, ref_res, prefix = prepare_data_for_date_cat(
183                  date_curr, date_ref, date_name, cat_name, cat_curr, cat_ref
184              )
185              return ColumnInteractionPlotResults(
186                  x_type=x_type,
187                  y_type=y_type,
188                  current=curr_res,
189                  reference=ref_res,
190                  prefix=prefix,
191              )
192          raise ValueError(f"Combination of types {x_type} and {y_type} is not supported.")
193  
194  
@default_renderer(wrap_type=ColumnInteractionPlot)
class ColumnInteractionPlotRenderer(MetricRenderer):
    """HTML renderer for ``ColumnInteractionPlot`` results."""

    def render_html(self, obj: ColumnInteractionPlot) -> List[BaseWidgetInfo]:
        """Render a header plus one large graph widget for the column interaction.

        Raises:
            ValueError: if the metric result holds no plot data for its pair of
                column types.
        """
        fig = self._build_figure(obj)
        return [
            header_text(label=f"Interactions between '{obj.x_column}' and '{obj.y_column}'"),
            BaseWidgetInfo(
                title="",
                size=2,
                type="big_graph",
                params={"data": fig["data"], "layout": fig["layout"]},
            ),
        ]

    def _build_figure(self, obj: ColumnInteractionPlot) -> Dict[str, Any]:
        """Pick the plot matching the result's column types and return the figure mapping."""
        result = obj.get_result()
        x_type, y_type = result.x_type, result.y_type
        x_name, y_name = obj.x_column, obj.y_column

        if x_type == ColumnType.Numerical and y_type == ColumnType.Numerical:
            # The metric stores either scatter (raw) or contour (aggregated)
            # data, so presence alone decides the plot. A previous
            # ``isinstance(..., Dict[str, List[Any]])`` check raised TypeError,
            # because subscripted generics cannot be used in instance checks.
            if result.current_scatter is not None:
                # reference_scatter may be None (no reference data); it is passed
                # through as-is, like the reference in every other branch.
                return plot_num_num_rel(
                    result.current_scatter,
                    result.reference_scatter,
                    y_name,
                    x_name,
                    self.color_options,
                )
            if result.current_contour is not None:
                contour = plot_contour(
                    result.current_contour,
                    result.reference_contour,
                    x_name,
                    y_name,
                )
                # The contour figure is converted through its JSON form so it can
                # be indexed by "data" / "layout" like the other figures.
                return json.loads(contour.to_json())

        elif x_type == ColumnType.Categorical and y_type == ColumnType.Categorical:
            if result.current is not None:
                return plot_cat_cat_rel(
                    result.current,
                    result.reference,
                    y_name,
                    x_name,
                    self.color_options,
                )

        elif x_type == ColumnType.Categorical and y_type == ColumnType.Numerical:
            if result.current_boxes is not None:
                return plot_boxes(
                    result.current_boxes,
                    result.reference_boxes,
                    y_name,
                    x_name,
                    self.color_options,
                )

        elif x_type == ColumnType.Numerical and y_type == ColumnType.Categorical:
            if result.current_boxes is not None:
                # Swapped orientation: names are swapped and the extra positional flag is set.
                return plot_boxes(
                    result.current_boxes,
                    result.reference_boxes,
                    x_name,
                    y_name,
                    self.color_options,
                    True,
                )

        elif x_type == ColumnType.Datetime and y_type == ColumnType.Numerical:
            if result.current is not None and result.prefix is not None:
                return plot_num_feature_in_time(
                    result.current,
                    result.reference,
                    y_name,
                    x_name,
                    result.prefix,
                    self.color_options,
                )

        elif x_type == ColumnType.Numerical and y_type == ColumnType.Datetime:
            if result.current is not None and result.prefix is not None:
                return plot_num_feature_in_time(
                    result.current,
                    result.reference,
                    x_name,
                    y_name,
                    result.prefix,
                    self.color_options,
                    True,
                )

        elif x_type == ColumnType.Datetime and y_type == ColumnType.Categorical:
            if result.current is not None and result.prefix is not None:
                return plot_cat_feature_in_time(
                    result.current,
                    result.reference,
                    y_name,
                    x_name,
                    result.prefix,
                    self.color_options,
                )

        elif x_type == ColumnType.Categorical and y_type == ColumnType.Datetime:
            if result.current is not None and result.prefix is not None:
                return plot_cat_feature_in_time(
                    result.current,
                    result.reference,
                    x_name,
                    y_name,
                    result.prefix,
                    self.color_options,
                    True,
                )

        # Nothing matched: fail with a clear message rather than an unbound-variable error.
        raise ValueError(
            f"Cannot render interaction plot for column types {x_type} and {y_type}: no plot data available."
        )
328          ]