column_interaction_plot.py
import json
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd

from evidently.legacy.base_metric import InputData
from evidently.legacy.base_metric import Metric
from evidently.legacy.base_metric import MetricResult
from evidently.legacy.base_metric import UsesRawDataMixin
from evidently.legacy.calculations.utils import get_data_for_cat_cat_plot
from evidently.legacy.calculations.utils import get_data_for_num_num_plot
from evidently.legacy.calculations.utils import prepare_box_data
from evidently.legacy.calculations.utils import prepare_data_for_date_cat
from evidently.legacy.calculations.utils import prepare_data_for_date_num
from evidently.legacy.calculations.utils import relabel_data
from evidently.legacy.core import ColumnType
from evidently.legacy.core import IncludeTags
from evidently.legacy.metric_results import ColumnScatter
from evidently.legacy.metric_results import ContourData
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.options.base import AnyOptions
from evidently.legacy.renderers.base_renderer import MetricRenderer
from evidently.legacy.renderers.base_renderer import default_renderer
from evidently.legacy.renderers.html_widgets import header_text
from evidently.legacy.utils.visualizations import plot_boxes
from evidently.legacy.utils.visualizations import plot_cat_cat_rel
from evidently.legacy.utils.visualizations import plot_cat_feature_in_time
from evidently.legacy.utils.visualizations import plot_contour
from evidently.legacy.utils.visualizations import plot_num_feature_in_time
from evidently.legacy.utils.visualizations import plot_num_num_rel


def _replace_infinities(column: Optional[pd.Series]) -> Optional[pd.Series]:
    """Return ``column`` with ``+inf``/``-inf`` replaced by ``NaN``.

    ``None`` is passed through unchanged. The replacement is NOT done in
    place, so the data object handed in by the caller is left untouched.
    """
    if column is None:
        return None
    return column.replace(to_replace=[np.inf, -np.inf], value=np.nan)


class ColumnInteractionPlotResults(MetricResult):
    """Plot data produced by :class:`ColumnInteractionPlot`.

    Which of the optional fields is populated depends on the pair of column
    types (see ``ColumnInteractionPlot.calculate``):

    * numerical / numerical     -> ``*_scatter`` (raw data) or ``*_contour``
    * categorical / categorical -> ``current`` / ``reference``
    * categorical / numerical   -> ``*_boxes``
    * datetime / numerical or datetime / categorical
                                -> ``current`` / ``reference`` and ``prefix``
    """

    class Config:
        type_alias = "evidently:metric_result:ColumnInteractionPlotResults"
        dict_include = False
        pd_include = False
        tags = {IncludeTags.Render}
        field_tags = {
            "current": {IncludeTags.Current},
            "reference": {IncludeTags.Reference},
            "current_scatter": {IncludeTags.Current},
            "current_contour": {IncludeTags.Current},
            "current_boxes": {IncludeTags.Current},
            "reference_scatter": {IncludeTags.Reference},
            "reference_contour": {IncludeTags.Reference},
            "reference_boxes": {IncludeTags.Reference},
        }

    y_type: ColumnType
    x_type: ColumnType
    current_scatter: Optional[ColumnScatter]
    current_contour: Optional[ContourData]
    current_boxes: Optional[Dict[str, Union[list, np.ndarray]]]
    current: Optional[pd.DataFrame]
    reference_scatter: Optional[ColumnScatter]
    reference_contour: Optional[ContourData]
    reference_boxes: Optional[Dict[str, Union[list, np.ndarray]]]
    reference: Optional[pd.DataFrame]
    prefix: Optional[str] = None


class ColumnInteractionPlot(UsesRawDataMixin, Metric[ColumnInteractionPlotResults]):
    """Metric that prepares the data for an interaction plot of two columns."""

    class Config:
        type_alias = "evidently:metric:ColumnInteractionPlot"

    x_column: str
    y_column: str

    def __init__(self, x_column: str, y_column: str, options: AnyOptions = None):
        """Store the names of the two columns to plot against each other."""
        self.x_column = x_column
        self.y_column = y_column
        super().__init__(options=options)

    def calculate(self, data: InputData) -> ColumnInteractionPlotResults:
        """Build the plot data for the ``x_column`` / ``y_column`` pair.

        Args:
            data: input data; both columns must be present in it.

        Returns:
            A result object with the fields matching the column-type pair
            filled in.

        Raises:
            ValueError: if either column is missing from ``data`` or the
                combination of column types is not supported.
        """
        for col in (self.x_column, self.y_column):
            if not data.has_column(col):
                raise ValueError(f"Column '{col}' not found in dataset.")

        x_type, x_curr, x_ref = data.get_data(self.x_column)
        y_type, y_curr, y_ref = data.get_data(self.y_column)

        # Infinite values are treated as missing. Work on replaced copies
        # instead of mutating the series owned by ``data``.
        x_curr, x_ref = _replace_infinities(x_curr), _replace_infinities(x_ref)
        y_curr, y_ref = _replace_infinities(y_curr), _replace_infinities(y_ref)

        if x_type == ColumnType.Categorical:
            x_curr, x_ref = relabel_data(x_curr, x_ref)
        if y_type == ColumnType.Categorical:
            y_curr, y_ref = relabel_data(y_curr, y_ref)

        # Aggregated plot data is the default; the raw_data render option
        # switches to raw points.
        agg_data = not self.get_options().render_options.raw_data

        numerical = ColumnType.Numerical
        categorical = ColumnType.Categorical
        datetime_type = ColumnType.Datetime
        type_pair = (x_type, y_type)

        if type_pair == (numerical, numerical):
            raw_plot, agg_plot = get_data_for_num_num_plot(
                agg_data,
                self.x_column,
                self.y_column,
                x_curr,
                y_curr,
                x_ref,
                y_ref,
            )
            if raw_plot is not None:
                return ColumnInteractionPlotResults(
                    x_type=x_type,
                    y_type=y_type,
                    current_scatter=raw_plot["current"],
                    reference_scatter=raw_plot.get("reference"),
                )
            return ColumnInteractionPlotResults(
                x_type=x_type,
                y_type=y_type,
                current_contour=agg_plot["current"],
                reference_contour=agg_plot.get("reference"),
            )

        if type_pair == (categorical, categorical):
            result = get_data_for_cat_cat_plot(
                self.x_column,
                self.y_column,
                x_curr,
                y_curr,
                x_ref,
                y_ref,
            )
            return ColumnInteractionPlotResults(
                x_type=x_type,
                y_type=y_type,
                current=result["current"],
                reference=result.get("reference"),
            )

        if type_pair in ((categorical, numerical), (numerical, categorical)):
            curr_df = pd.DataFrame({self.x_column: x_curr, self.y_column: y_curr})
            ref_df = None
            if x_ref is not None and y_ref is not None:
                ref_df = pd.DataFrame({self.x_column: x_ref, self.y_column: y_ref})
            if x_type == categorical:
                cat_name, num_name = self.x_column, self.y_column
            else:
                cat_name, num_name = self.y_column, self.x_column
            result = prepare_box_data(curr_df, ref_df, cat_name, num_name)
            return ColumnInteractionPlotResults(
                x_type=x_type,
                y_type=y_type,
                current_boxes=result["current"],
                reference_boxes=result.get("reference"),
            )

        if type_pair in ((numerical, datetime_type), (datetime_type, numerical)):
            if x_type == numerical:
                date_name, date_curr, date_ref = self.y_column, y_curr, y_ref
                num_name, num_curr, num_ref = self.x_column, x_curr, x_ref
            else:
                date_name, date_curr, date_ref = self.x_column, x_curr, x_ref
                num_name, num_curr, num_ref = self.y_column, y_curr, y_ref
            curr_res, ref_res, prefix = prepare_data_for_date_num(
                date_curr, date_ref, date_name, num_name, num_curr, num_ref
            )
            return ColumnInteractionPlotResults(
                x_type=x_type,
                y_type=y_type,
                current=curr_res,
                reference=ref_res,
                prefix=prefix,
            )

        if type_pair in ((categorical, datetime_type), (datetime_type, categorical)):
            if x_type == categorical:
                date_name, date_curr, date_ref = self.y_column, y_curr, y_ref
                cat_name, cat_curr, cat_ref = self.x_column, x_curr, x_ref
            else:
                date_name, date_curr, date_ref = self.x_column, x_curr, x_ref
                cat_name, cat_curr, cat_ref = self.y_column, y_curr, y_ref
            curr_res, ref_res, prefix = prepare_data_for_date_cat(
                date_curr, date_ref, date_name, cat_name, cat_curr, cat_ref
            )
            return ColumnInteractionPlotResults(
                x_type=x_type,
                y_type=y_type,
                current=curr_res,
                reference=ref_res,
                prefix=prefix,
            )

        raise ValueError(f"Combination of types {x_type} and {y_type} is not supported.")


@default_renderer(wrap_type=ColumnInteractionPlot)
class ColumnInteractionPlotRenderer(MetricRenderer):
    """HTML renderer for :class:`ColumnInteractionPlot`."""

    def render_html(self, obj: ColumnInteractionPlot) -> List[BaseWidgetInfo]:
        """Render the metric result as a header plus one ``big_graph`` widget.

        Args:
            obj: the calculated metric to render.

        Returns:
            A two-element list: the header text and the graph widget.

        Raises:
            ValueError: if the result holds no plot data matching its pair of
                column types, so no figure can be built.
        """
        metric_result = obj.get_result()
        x_type = metric_result.x_type
        y_type = metric_result.y_type
        has_time_data = metric_result.current is not None and metric_result.prefix is not None

        fig: Any
        if (
            x_type == ColumnType.Numerical
            and y_type == ColumnType.Numerical
            and (metric_result.current_scatter is not None or metric_result.current_contour is not None)
        ):
            # Pick whichever representation the result carries. A plain
            # ``is not None`` test is used because ``isinstance`` cannot be
            # called with subscripted generics such as ``Dict[str, List[Any]]``
            # (it raises ``TypeError``), and the reference data is optional.
            if metric_result.current_scatter is not None:
                fig = plot_num_num_rel(
                    metric_result.current_scatter,
                    metric_result.reference_scatter,
                    obj.y_column,
                    obj.x_column,
                    self.color_options,
                )
            else:
                fig = plot_contour(
                    metric_result.current_contour,
                    metric_result.reference_contour,
                    obj.x_column,
                    obj.y_column,
                )
            fig = json.loads(fig.to_json())
        elif (
            x_type == ColumnType.Categorical
            and y_type == ColumnType.Categorical
            and metric_result.current is not None
        ):
            fig = plot_cat_cat_rel(
                metric_result.current,
                metric_result.reference,
                obj.y_column,
                obj.x_column,
                self.color_options,
            )
        elif (
            x_type == ColumnType.Categorical
            and y_type == ColumnType.Numerical
            and metric_result.current_boxes is not None
        ):
            fig = plot_boxes(
                metric_result.current_boxes,
                metric_result.reference_boxes,
                obj.y_column,
                obj.x_column,
                self.color_options,
            )
        elif (
            x_type == ColumnType.Numerical
            and y_type == ColumnType.Categorical
            and metric_result.current_boxes is not None
        ):
            fig = plot_boxes(
                metric_result.current_boxes,
                metric_result.reference_boxes,
                obj.x_column,
                obj.y_column,
                self.color_options,
                True,
            )
        elif x_type == ColumnType.Datetime and y_type == ColumnType.Numerical and has_time_data:
            fig = plot_num_feature_in_time(
                metric_result.current,
                metric_result.reference,
                obj.y_column,
                obj.x_column,
                metric_result.prefix,
                self.color_options,
            )
        elif y_type == ColumnType.Datetime and x_type == ColumnType.Numerical and has_time_data:
            fig = plot_num_feature_in_time(
                metric_result.current,
                metric_result.reference,
                obj.x_column,
                obj.y_column,
                metric_result.prefix,
                self.color_options,
                True,
            )
        elif x_type == ColumnType.Datetime and y_type == ColumnType.Categorical and has_time_data:
            fig = plot_cat_feature_in_time(
                metric_result.current,
                metric_result.reference,
                obj.y_column,
                obj.x_column,
                metric_result.prefix,
                self.color_options,
            )
        elif y_type == ColumnType.Datetime and x_type == ColumnType.Categorical and has_time_data:
            fig = plot_cat_feature_in_time(
                metric_result.current,
                metric_result.reference,
                obj.x_column,
                obj.y_column,
                metric_result.prefix,
                self.color_options,
                True,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``fig`` further down.
            raise ValueError(
                f"Cannot render interaction plot for column types {x_type} and {y_type}: "
                "the metric result contains no matching plot data."
            )

        return [
            header_text(label=f"Interactions between '{obj.x_column}' and '{obj.y_column}'"),
            BaseWidgetInfo(
                title="",
                size=2,
                type="big_graph",
                params={"data": fig["data"], "layout": fig["layout"]},
            ),
        ]