# html_widgets.py
import dataclasses
from enum import Enum
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd
from plotly import graph_objs as go
from plotly.subplots import make_subplots
from uuid6 import uuid7

from evidently.legacy.metric_results import Distribution
from evidently.legacy.metric_results import HistogramData
from evidently.legacy.metric_results import Label
from evidently.legacy.metric_results import LiftCurve
from evidently.legacy.metric_results import PRCurve
from evidently.legacy.metric_results import ROCCurve
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.model.widget import PlotlyGraphInfo
from evidently.legacy.model.widget import TabInfo
from evidently.legacy.model.widget import WidgetType
from evidently.legacy.options import ColorOptions


class WidgetSize(int, Enum):
    """Size codes; the integer value is what gets passed as ``BaseWidgetInfo.size``."""

    SMALL = 0
    HALF = 1
    FULL = 2


class GraphData:
    """Title plus plotly ``data``/``layout`` dictionaries describing one graph."""

    title: str
    data: dict
    layout: dict

    def __init__(self, title: str, data: dict, layout: dict):
        """
        create GraphData object for usage in plotly_graph_tabs or plotly_data.

        Args:
            title: title of graph
            data: plotly figure data
            layout: plotly figure layout
        """
        self.title = title
        self.data = data
        self.layout = layout

    @staticmethod
    def figure(title: str, figure: go.Figure) -> "GraphData":
        """
        create GraphData from plotly figure itself

        Args:
            title: title of graph
            figure: plotly figure for getting data from

        Returns:
            GraphData built from the ``"data"`` and ``"layout"`` entries of ``figure.to_plotly_json()``
        """
        data = figure.to_plotly_json()
        return GraphData(title, data["data"], data["layout"])


def plotly_graph(*, graph_data: GraphData, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate plotly plot with given GraphData object.

    Args:
        graph_data: plot data for widget
        size: size of widget to render

    Example:
        >>> figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> f_dict = figure.to_plotly_json()
        >>> bar_graph_data = GraphData(title="Some plot title", data=f_dict["data"], layout=f_dict["layout"])
        >>> widget_info = plotly_graph(graph_data=bar_graph_data, size=WidgetSize.FULL)
    """
    return BaseWidgetInfo(
        title=graph_data.title,
        type=WidgetType.BIG_GRAPH.value,
        size=size.value,
        params={"data": graph_data.data, "layout": graph_data.layout},
    )


def plotly_data(*, title: str, data: dict, layout: dict, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate plotly plot with given data and layout (can be generated from plotly).

    Args:
        title: widget title
        data: plotly figure data
        layout: plotly figure layout
        size: widget size

    Example:
        >>> figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> f_dict = figure.to_plotly_json()
        >>> widget_info = plotly_data(title="Some plot title", data=f_dict["data"], layout=f_dict["layout"])
    """
    return plotly_graph(graph_data=GraphData(title, data, layout), size=size)


def plotly_figure(*, title: str, figure: go.Figure, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate plotly plot based on given plotly figure object.

    Args:
        title: title of widget
        figure: plotly figure which should be rendered as widget
        size: size of widget, default to WidgetSize.FULL

    Example:
        >>> bar_figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> widget_info = plotly_figure(title="Bar plot widget", figure=bar_figure, size=WidgetSize.FULL)
    """
    return plotly_graph(graph_data=GraphData.figure(title=title, figure=figure), size=size)


def plotly_graph_tabs(*, title: str, figures: List[GraphData], size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate Tab widget with multiple graphs

    Args:
        title: widget title
        figures: list of graphs with tab titles
        size: widget size

    Example:
        >>> bar_figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> line_figure = go.Figure(go.Scatter(name="Line plot", mode="lines", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> widget_info = plotly_graph_tabs(
        ...     title="Tabbed widget",
        ...     figures=[GraphData.figure("Bar", bar_figure), GraphData.figure("Line", line_figure)],
        ... )
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABBED_GRAPH.value,
        size=size.value,
        params={
            "graphs": [
                {
                    # a fresh id is generated for every tab on each call
                    "id": str(uuid7()),
                    "title": graph.title,
                    "graph": {
                        "data": graph.data,
                        "layout": graph.layout,
                    },
                }
                for graph in figures
            ]
        },
    )


class CounterData:
    """Label/value pair shown by the counter widget; the value is stored as a string."""

    label: str
    value: str

    def __init__(self, label: str, value: str):
        """
        creates CounterData for counter widget with given label and value.

        Args:
            label: counter label
            value: counter value
        """
        self.label = label
        self.value = value

    @staticmethod
    def float(label: str, value: float, precision: int) -> "CounterData":
        """
        create CounterData for float value with given precision.

        Args:
            label: counter label
            value: float value of counter
            precision: number used as the ``.N`` precision of the format spec.
                NOTE(review): no presentation type (such as ``f``) is given, so the value is
                formatted in Python's general style and ``precision`` acts as a count of
                significant digits rather than decimal places (large values may be rendered in
                exponent notation) - confirm whether fixed-point output was intended.
        """
        return CounterData(label, f"{value:.{precision}}")

    @staticmethod
    def string(label: str, value: str) -> "CounterData":
        """
        create CounterData for string value.

        Args:
            label: counter label
            value: string value of counter
        """
        return CounterData(label, f"{value}")

    @staticmethod
    def int(label: str, value: int) -> "CounterData":
        """
        create CounterData for int value.

        Args:
            label: counter label
            value: int value
        """
        return CounterData(label, f"{value}")


def counter(*, counters: List[CounterData], title: str = "", size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate widget with given counters

    Args:
        title: widget title
        counters: list of counters in widget
        size: widget size

    Example:
        >>> display_counters = [CounterData("value1", "some value"), CounterData.float("float", 0.111, 2)]
        >>> widget_info = counter(counters=display_counters, title="counters example")
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.COUNTER.value,
        size=size.value,
        params={"counters": [{"value": item.value, "label": item.label} for item in counters]},
    )


def pie_chart(
    *,
    title: str,
    data: Union[Dict[str, float], Tuple[List[str], List[float]]],
    size: WidgetSize = WidgetSize.FULL,
    colors: Union[Dict[str, str], List[str], None] = None,
) -> BaseWidgetInfo:
    """
    generate widget with a pie chart (drawn with a hole of 0.4)

    Args:
        title: widget title
        data: either a dict mapping label to value, or a ``(labels, values)`` pair
        size: widget size
        colors: slice colors; a dict is looked up by label (every label must be present as a key),
            any other value is handed to plotly as given
    """
    if isinstance(data, dict):
        labels = list(data.keys())
        values = list(data.values())
    else:
        labels, values = data
    if isinstance(colors, dict):
        # reorder the per-label colors so that they line up with ``labels``
        colors = [colors[lab] for lab in labels]
    fig = go.Figure(
        data=[
            go.Pie(
                labels=labels,
                values=values,
                hole=0.4,
                marker={"colors": colors, "line": dict(color="#000000", width=1)},
            )
        ]
    )
    return plotly_figure(title=title, figure=fig, size=size)


def header_text(*, label: str, title: str = "", size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate widget with some text as header

    The header is emitted as a counter widget with a single counter whose value is empty.

    Args:
        label: text to display
        title: widget title
        size: widget size
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.COUNTER.value,
        size=size.value,
        params={"counters": [{"value": "", "label": label}]},
    )


def text_widget(*, text: str, title: str = "", size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    generate widget with markdown text

    Args:
        text: markdown formatted text
        title: widget title
        size: widget size
    """
    return BaseWidgetInfo(title=title, type="text", size=size.value, params={"text": text})


def table_data(
    *, column_names: Iterable[str], data: Iterable[Iterable], title: str = "", size: WidgetSize = WidgetSize.FULL
) -> BaseWidgetInfo:
    """
    generate simple table with given columns and data

    Args:
        column_names: list of column names in display order
        data: list of data rows (lists of object to show in table in order of columns), object will be converted to str
        title: widget title
        size: widget size

    Example:
        >>> columns = ["Column A", "Column B"]
        >>> in_table_data = [[1, 2], [3, 4]]
        >>> widget_info = table_data(column_names=columns, data=in_table_data, title="Table")
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABLE.value,
        params={
            "header": column_names,
            "data": [[str(item) for item in row] for row in data],
        },
        size=size.value,
    )


class ColumnType(Enum):
    """Column kinds accepted by ``ColumnDefinition`` for the rich table."""

    STRING = "string"
    LINE = "line"
    SCATTER = "scatter"
    HISTOGRAM = "histogram"


class SortDirection(Enum):
    """Sort direction of a rich table column."""

    ASC = "asc"
    DESC = "desc"


@dataclasses.dataclass
class ColumnDefinition:
    """Description of a single column of the rich table."""

    title: str
    field_name: str
    type: ColumnType = ColumnType.STRING
    sort: Optional[SortDirection] = None
    options: Optional[dict] = None

    def as_dict(self) -> dict:
        """
        convert the definition to a plain dict

        ``title`` and ``field`` are always present; ``type`` is added only for non-STRING columns,
        ``sort`` and ``options`` only when they are set.
        """
        result: dict = {"title": self.title, "field": self.field_name}
        if self.type != ColumnType.STRING:
            result["type"] = self.type.value
        if self.sort is not None:
            result["sort"] = self.sort.value
        if self.options is not None:
            result["options"] = self.options
        return result


@dataclasses.dataclass
class TabData:
    """Tab title together with the widget shown inside the tab."""

    title: str
    widget: BaseWidgetInfo


def widget_tabs(*, title: str = "", size: WidgetSize = WidgetSize.FULL, tabs: List[TabData]) -> BaseWidgetInfo:
    """
    generate widget with tabs which can contain any other widget.

    Args:
        title: widget title
        size: widget size
        tabs: list of TabData with widgets to include

    Example:
        >>> columns = ["Column A", "Column B"]
        >>> in_table_data = [[1, 2], [3, 4]]
        >>> tab_data = [
        ...     TabData("Counters", counter(counters=[CounterData("counter", "value")], title="Counter")),
        ...     TabData("Table", table_data(column_names=columns, data=in_table_data, title="Table")),
        ... ]
        >>> widget_info = widget_tabs(title="Tabs", tabs=tab_data)
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABS.value,
        size=size.value,
        tabs=[TabInfo(id=str(uuid7()), title=tab.title, widget=tab.widget) for tab in tabs],
    )


def widget_tabs_for_more_than_one(
    *, title: str = "", size: WidgetSize = WidgetSize.FULL, tabs: List[TabData]
) -> Optional[BaseWidgetInfo]:
    """
    Draw tabs widget only if there is more than one tab, otherwise just draw one widget

    Args:
        title: widget title (used only when the tabs widget is created)
        size: widget size (used only when the tabs widget is created)
        tabs: list of TabData with widgets to include

    Returns:
        None for an empty list, the only widget itself for a single tab, a tabs widget otherwise
    """
    if not tabs:
        return None
    if len(tabs) == 1:
        return tabs[0].widget
    return widget_tabs(title=title, size=size, tabs=tabs)


class DetailsPartInfo:
    """Titled widget or graph attached to a rich table row as one part of its details."""

    title: str
    info: Union[BaseWidgetInfo, PlotlyGraphInfo]

    def __init__(self, title: str, info: Union[BaseWidgetInfo, PlotlyGraphInfo]):
        self.title = title
        self.info = info


class RowDetails:
    """Collection of detail parts of one rich table row."""

    parts: List[DetailsPartInfo]

    def __init__(self, parts: Optional[List[DetailsPartInfo]] = None):
        if parts is None:
            parts = []
        self.parts = parts

    def with_part(self, title: str, info: Union[BaseWidgetInfo, PlotlyGraphInfo]) -> "RowDetails":
        """Append a new part and return ``self`` so that calls can be chained."""
        self.parts.append(DetailsPartInfo(title, info))
        return self


class RichTableDataRow:
    """Field values of one rich table row and, optionally, its details."""

    details: Optional[RowDetails]
    fields: dict

    def __init__(self, fields: dict, details: Optional[RowDetails] = None):
        self.fields = fields
        self.details = details


def rich_table_data(
    *,
    title: str = "",
    size: WidgetSize = WidgetSize.FULL,
    rows_per_page: int = 10,
    columns: List[ColumnDefinition],
    data: List[RichTableDataRow],
) -> BaseWidgetInfo:
    """
    generate widget with rich table: with additional column types and details for rows

    Args:
        title: widget title
        size: widget size
        rows_per_page: maximum number per page to show (capped by the number of rows in data)
        columns: list of columns in table
        data: list of rows; keys of RichTableDataRow.fields are according to ColumnDefinition.field_name

    Example:
        >>> columns_def = [
        ...     ColumnDefinition("Column A", "field_1"),
        ...     ColumnDefinition("Column B", "field_2", ColumnType.HISTOGRAM,
        ...                      options={"xField": "x", "yField": "y", "color": "#ed0400"}),
        ...     ColumnDefinition("Column C", "field_3", sort=SortDirection.ASC),
        ... ]
        >>> in_table_data = [
        ...     RichTableDataRow(fields=dict(field_1="a", field_2=dict(x=[1, 2, 3], y=[10, 11, 3]), field_3="2")),
        ...     RichTableDataRow(
        ...         fields=dict(field_1="b", field_2=dict(x=[1, 2, 3], y=[10, 11, 3]), field_3="1"),
        ...         details=RowDetails()
        ...             .with_part("Some details", counter(counters=[CounterData("counter 1", "value")])
        ...         )
        ...     )
        ... ]
        >>> widget_info = rich_table_data(title="Rich table", rows_per_page=10, columns=columns_def, data=in_table_data)
    """
    additional_graphs = []

    converted_data = []
    for row in data:
        # rows without any detail parts are emitted as their plain fields
        if row.details is None or not row.details.parts:
            converted_data.append(dict(**row.fields))
            continue
        parts = []
        for part in row.details.parts:
            parts.append(
                dict(
                    title=part.title,
                    id=part.info.id,
                    type="widget" if isinstance(part.info, BaseWidgetInfo) else "graph",
                )
            )
            # the row keeps only the id of the part; the object itself goes to additionalGraphs
            additional_graphs.append(part.info)
        converted_data.append(dict(details={"parts": parts}, **row.fields))

    return BaseWidgetInfo(
        title=title,
        type=WidgetType.BIG_TABLE.value,
        details="",
        alerts=[],
        alertsPosition="row",
        insights=[],
        size=size.value,
        params={
            "rowsPerPage": min(len(data), rows_per_page),
            "columns": [column.as_dict() for column in columns],
            "data": converted_data,
        },
        additionalGraphs=additional_graphs,
    )


def get_histogram_figure(
    *,
    primary_hist: HistogramData,
    secondary_hist: Optional[HistogramData] = None,
    color_options: ColorOptions,
    orientation: str = "v",
) -> go.Figure:
    """
    build plotly figure with one or two bar traces

    Args:
        primary_hist: histogram drawn first, with the current data color
        secondary_hist: optional second histogram, drawn with the reference data color
        color_options: color options to use for the bars
        orientation: bars orientation passed to both traces
    """
    figure = go.Figure()
    curr_bar = go.Bar(
        name=primary_hist.name,
        x=primary_hist.x,
        y=primary_hist.count,
        marker_color=color_options.get_current_data_color(),
        orientation=orientation,
    )
    figure.add_trace(curr_bar)

    if secondary_hist is not None:
        ref_bar = go.Bar(
            name=secondary_hist.name,
            x=secondary_hist.x,
            y=secondary_hist.count,
            marker_color=color_options.get_reference_data_color(),
            orientation=orientation,
        )
        figure.add_trace(ref_bar)

    return figure


def histogram(
    *,
    title: str,
    primary_hist: HistogramData,
    secondary_hist: Optional[HistogramData] = None,
    color_options: ColorOptions,
    orientation: str = "v",
    size: WidgetSize = WidgetSize.FULL,
    xaxis_title: Optional[str] = None,
    yaxis_title: Optional[str] = None,
) -> BaseWidgetInfo:
    """
    generate widget with one or two histogram

    Args:
        title: widget title
        primary_hist: first histogram to show in widget
        secondary_hist: optional second histogram to show in widget
        orientation: bars orientation in histograms
        color_options: color options to use for widgets
        size: widget size
        xaxis_title: title for x-axis
        yaxis_title: title for y-axis

    Example:
        >>> ref_hist = HistogramData(name="Histogram 1", x=pd.Series(["a", "b", "c"]), count=pd.Series([1, 2, 3]))
        >>> curr_hist = HistogramData(name="Histogram 2", x=pd.Series(["a", "b", "c"]), count=pd.Series([3, 2 ,1]))
        >>> widget_info = histogram(
        ...     title="Histogram example",
        ...     primary_hist=ref_hist,
        ...     secondary_hist=curr_hist,
        ...     color_options=color_options
        ... )
    """
    figure = get_histogram_figure(
        primary_hist=primary_hist,
        secondary_hist=secondary_hist,
        color_options=color_options,
        orientation=orientation,
    )
    if xaxis_title is not None:
        figure.update_layout(
            xaxis_title=xaxis_title,
        )

    if yaxis_title is not None:
        figure.update_layout(
            yaxis_title=yaxis_title,
        )

    return plotly_figure(title=title, figure=figure, size=size)


def get_histogram_for_distribution(
    *,
    current_distribution: Distribution,
    reference_distribution: Optional[Distribution] = None,
    title: str = "",
    xaxis_title: Optional[str] = None,
    yaxis_title: Optional[str] = None,
    color_options: ColorOptions,
) -> BaseWidgetInfo:
    """
    generate histogram widget from distribution(s)

    Args:
        current_distribution: distribution shown as histogram named "current"
        reference_distribution: optional distribution shown as histogram named "reference"
        title: widget title
        xaxis_title: title for x-axis
        yaxis_title: title for y-axis
        color_options: color options to use for widgets
    """
    current_histogram = HistogramData(
        name="current",
        x=pd.Series(current_distribution.x),
        count=pd.Series(current_distribution.y),
    )

    if reference_distribution is not None:
        reference_histogram: Optional[HistogramData] = HistogramData(
            name="reference",
            x=pd.Series(reference_distribution.x),
            count=pd.Series(reference_distribution.y),
        )

    else:
        reference_histogram = None

    return histogram(
        title=title,
        primary_hist=current_histogram,
        secondary_hist=reference_histogram,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        color_options=color_options,
    )


@dataclasses.dataclass
class HeatmapData:
    """Named matrix to draw as a heatmap."""

    name: str
    matrix: pd.DataFrame


def get_heatmaps_widget(
    *,
    title: str = "",
    primary_data: HeatmapData,
    secondary_data: Optional[HeatmapData] = None,
    size: WidgetSize = WidgetSize.FULL,
    color_options: ColorOptions,
) -> BaseWidgetInfo:
    """
    Create a widget with heatmap(s)

    One subplot is created per given HeatmapData; all heatmaps share one color axis.
    Subplot titles (the data names) are shown only when both heatmaps are given.

    Args:
        title: widget title
        primary_data: heatmap for the first subplot
        secondary_data: optional heatmap for the second subplot
        size: widget size
        color_options: source of the heatmap color scale
    """

    if secondary_data is not None:
        subplot_titles = [primary_data.name, secondary_data.name]
        heatmaps_count = 2

    else:
        subplot_titles = [""]
        heatmaps_count = 1

    figure = make_subplots(rows=1, cols=heatmaps_count, subplot_titles=subplot_titles, shared_yaxes=True)

    for idx, heatmap_data in enumerate([primary_data, secondary_data]):
        if heatmap_data is None:
            continue
        data = heatmap_data.matrix
        columns = heatmap_data.matrix.columns

        # show values if the heatmap is small
        if len(columns) < 15:
            heatmap_text: Optional[pd.DataFrame] = np.round(data, 2).astype(str)  # type: ignore[assignment]
            heatmap_text_template: Optional[str] = "%{text}"

        else:
            heatmap_text = None
            heatmap_text_template = None

        figure.add_trace(
            go.Heatmap(
                z=data,
                # the column labels are used for both axes
                x=columns,
                y=columns,
                text=heatmap_text,
                texttemplate=heatmap_text_template,
                coloraxis="coloraxis",
            ),
            1,
            idx + 1,
        )

    figure.update_layout(coloraxis={"colorscale": color_options.heatmap})
    figure.update_yaxes(type="category")
    figure.update_xaxes(tickangle=-45)
    return plotly_figure(title=title, figure=figure, size=size)


def get_roc_auc_tab_data(
    curr_roc_curve: ROCCurve, ref_roc_curve: Optional[ROCCurve], color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one ROC curve plot per label of the current ROC curve data

    Args:
        curr_roc_curve: per-label ROC data (``fpr``/``tpr``) for the current sample
        ref_roc_curve: optional per-label ROC data for the reference sample (drawn in a second subplot)
        color_options: color options to use for the plots

    Returns:
        list of ``(str(label), widget)`` pairs
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if ref_roc_curve is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in curr_roc_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        trace = go.Scatter(
            x=curr_roc_curve[label].fpr,
            y=curr_roc_curve[label].tpr,
            mode="lines",
            name="ROC",
            legendgroup="ROC",
            marker=dict(
                size=6,
                color=color_options.get_current_data_color(),
            ),
        )
        fig.add_trace(trace, 1, 1)
        fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
        if ref_roc_curve is not None:
            # NOTE(review): the reference trace uses the current-data color; it shares the legend
            # group with the current trace and hides its own legend entry - confirm this is intended.
            trace = go.Scatter(
                x=ref_roc_curve[label].fpr,
                y=ref_roc_curve[label].tpr,
                mode="lines",
                name="ROC",
                legendgroup="ROC",
                showlegend=False,
                marker=dict(
                    size=6,
                    color=color_options.get_current_data_color(),
                ),
            )
            fig.add_trace(trace, 1, 2)
            fig.update_xaxes(title_text="False Positive Rate", row=1, col=2)
        fig.update_layout(yaxis_title="True Positive Rate", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots


def get_pr_rec_plot_data(
    current_pr_curve: PRCurve, reference_pr_curve: Optional[PRCurve], color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one precision-recall plot per label of the current PR curve data

    Args:
        current_pr_curve: per-label PR data (``rcl``/``pr``) for the current sample
        reference_pr_curve: optional per-label PR data for the reference sample (drawn in a second subplot)
        color_options: color options to use for the plots

    Returns:
        list of ``(str(label), widget)`` pairs
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if reference_pr_curve is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in current_pr_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        trace = go.Scatter(
            x=current_pr_curve[label].rcl,
            y=current_pr_curve[label].pr,
            mode="lines",
            name="PR",
            legendgroup="PR",
            marker=dict(
                size=6,
                color=color_options.get_current_data_color(),
            ),
        )
        fig.add_trace(trace, 1, 1)
        fig.update_xaxes(title_text="Recall", row=1, col=1)
        if reference_pr_curve is not None:
            # NOTE(review): the reference trace uses the current-data color; it shares the legend
            # group with the current trace and hides its own legend entry - confirm this is intended.
            trace = go.Scatter(
                x=reference_pr_curve[label].rcl,
                y=reference_pr_curve[label].pr,
                mode="lines",
                name="PR",
                legendgroup="PR",
                showlegend=False,
                marker=dict(
                    size=6,
                    color=color_options.get_current_data_color(),
                ),
            )
            fig.add_trace(trace, 1, 2)
            fig.update_xaxes(title_text="Recall", row=1, col=2)
        fig.update_layout(yaxis_title="Precision", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots


def get_lift_plot_data(
    current_lift_curve: LiftCurve,
    reference_lift_curve: Optional[LiftCurve],
    color_options: ColorOptions,
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Forms plot data for lift metric visualization

    Parameters
    ----------
    current_lift_curve: dict
        Calculated lift table data for current sample
    reference_lift_curve: Optional[dict]
        Calculated lift table data for reference sample
    color_options: ColorOptions
        Standard Evidently class-collection of colors for data visualization

    Return values
    -------------
    additional_plots: List[Tuple[str, BaseWidgetInfo]]
        Plot objects within List
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if reference_lift_curve is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in current_lift_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        current = current_lift_curve[label]
        trace = go.Scatter(
            x=current.top,
            y=current.lift,
            mode="lines+markers",
            name="Lift",
            hoverinfo="text",
            # one hover text per plotted point, whatever the number of points is
            text=[
                f"top: {str(int(top))}, " f"lift={str(lift)}" for top, lift in zip(current.top, current.lift)
            ],
            legendgroup="Lift",
            marker=dict(
                size=6,
                color=color_options.get_current_data_color(),
            ),
        )
        fig.add_trace(trace, 1, 1)
        fig.update_xaxes(title_text="Top", row=1, col=1)
        if reference_lift_curve is not None:
            reference = reference_lift_curve[label]
            # NOTE(review): the reference trace uses the current-data color; it shares the legend
            # group with the current trace and hides its own legend entry - confirm this is intended.
            trace = go.Scatter(
                x=reference.top,
                y=reference.lift,
                mode="lines+markers",
                name="Lift",
                hoverinfo="text",
                text=[
                    f"top: {str(int(top))}, " f"lift={str(lift)}"
                    for top, lift in zip(reference.top, reference.lift)
                ],
                legendgroup="Lift",
                showlegend=False,
                marker=dict(
                    size=6,
                    color=color_options.get_current_data_color(),
                ),
            )
            fig.add_trace(trace, 1, 2)
            fig.update_xaxes(title_text="Top", row=1, col=2)
        fig.update_layout(yaxis_title="Lift", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots


def class_separation_traces_raw(df, label, target_name, color_options):
    """
    Build two marker traces for the ``label`` column: rows of class ``label`` and all other rows

    Args:
        df: data frame containing the ``target_name`` column and a column named ``label``
        label: class label; also the name of the column plotted on the y-axis
        target_name: name of the column compared with ``label`` to split the rows
        color_options: source of the primary (class) and secondary (other) colors

    Returns:
        list with the trace of the class followed by the trace named "other"
    """
    label_rows = df[df[target_name] == label]
    other_rows = df[df[target_name] != label]
    return [
        go.Scatter(
            # x positions are random, so repeated calls give different plots
            x=np.random.random(label_rows.shape[0]),
            y=label_rows[label],
            mode="markers",
            name=str(label),
            legendgroup=str(label),
            marker=dict(size=6, color=color_options.primary_color),
        ),
        go.Scatter(
            x=np.random.random(other_rows.shape[0]),
            y=other_rows[label],
            mode="markers",
            name="other",
            legendgroup="other",
            marker=dict(size=6, color=color_options.secondary_color),
        ),
    ]


def class_separation_traces_agg(df, label, color_options):
    """
    Build two box traces from aggregated statistics: one for ``label`` and one for "others"

    Args:
        df: data frame with columns "values", "mins", "lowers", "means", "uppers" and "maxs";
            rows are selected by comparing "values" with ``label`` and with "others"
        label: class label to select the first group of rows
        color_options: source of the current (class) and reference (others) data colors

    Returns:
        list with the box trace of the class followed by the box trace of "others"
    """
    traces = []
    df_name = df[df["values"] == label]
    traces.append(
        go.Box(
            lowerfence=df_name["mins"],
            q1=df_name["lowers"],
            q3=df_name["uppers"],
            # the "means" column is what gets passed as the median of the box
            median=df_name["means"],
            upperfence=df_name["maxs"],
            x=df_name["values"].astype(str),
            marker_color=color_options.get_current_data_color(),
        )
    )
    df_name = df[df["values"] == "others"]
    traces.append(
        go.Box(
            lowerfence=df_name["mins"],
            q1=df_name["lowers"],
            q3=df_name["uppers"],
            median=df_name["means"],
            upperfence=df_name["maxs"],
            x=df_name["values"],
            marker_color=color_options.get_reference_data_color(),
        )
    )
    return traces


def get_class_separation_plot_data(
    current_plot: pd.DataFrame, reference_plot: Optional[pd.DataFrame], target_name: str, color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one class separation plot (raw points) per column of ``current_plot`` except the target

    Args:
        current_plot: current data; every column but ``target_name`` is treated as a label
        reference_plot: optional reference data of the same layout (drawn in a second subplot)
        target_name: name of the target column
        color_options: color options to use for the plots

    Returns:
        list of ``(str(label), widget)`` pairs
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if reference_plot is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in current_plot.columns.drop(target_name):
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        traces = class_separation_traces_raw(current_plot, label, target_name, color_options)
        for trace in traces:
            fig.add_trace(trace, 1, 1)
        fig.update_xaxes(dict(range=(-2, 3), showticklabels=False), row=1, col=1)

        if reference_plot is not None:
            traces = class_separation_traces_raw(reference_plot, label, target_name, color_options)
            for trace in traces:
                fig.add_trace(trace, 1, 2)
            fig.update_xaxes(dict(range=(-2, 3), showticklabels=False), row=1, col=2)

        fig.update_layout(yaxis_title="Probability", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots


def get_class_separation_plot_data_agg(
    current_plot: Dict[Label, pd.DataFrame],
    reference_plot: Optional[Dict[Label, pd.DataFrame]],
    target_name: str,
    color_options: ColorOptions,
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one class separation plot (aggregated boxes) per label of ``current_plot``

    Args:
        current_plot: aggregated statistics of current data by label
        reference_plot: optional aggregated statistics of reference data by label (drawn in a second subplot)
        target_name: not used by this function
        color_options: color options to use for the plots

    Returns:
        list of ``(str(label), widget)`` pairs
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if reference_plot is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in current_plot.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        traces = class_separation_traces_agg(current_plot[label], label, color_options)
        for trace in traces:
            fig.add_trace(trace, 1, 1)

        if reference_plot is not None:
            traces = class_separation_traces_agg(reference_plot[label], label, color_options)
            for trace in traces:
                fig.add_trace(trace, 1, 2)

        fig.update_layout(yaxis_title="Probability", showlegend=False)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots


def group_widget(
    *,
    title: str,
    widgets: List[BaseWidgetInfo],
) -> BaseWidgetInfo:
    """
    generate full-size widget which groups other widgets

    Args:
        title: widget title
        widgets: widgets to put into the group
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.GROUP.value,
        widgets=widgets,
        size=WidgetSize.FULL.value,
    )


def rich_data(
    *,
    title: str,
    description: str,
    header: List[str],
    metrics: List[dict],
    graph: Optional[PlotlyGraphInfo],
) -> BaseWidgetInfo:
    """
    generate full-size rich data widget

    The widget's own title is left empty; ``title`` is passed as the "header" parameter.

    Args:
        title: text for the "header" parameter
        description: text for the "description" parameter
        header: values for the "metricsValuesHeaders" parameter
        metrics: values for the "metrics" parameter
        graph: optional graph for the "graph" parameter
    """
    return BaseWidgetInfo(
        type=WidgetType.RICH_DATA.value,
        title="",
        size=WidgetSize.FULL.value,
        params={
            "header": title,
            "description": description,
            "metricsValuesHeaders": header,
            "metrics": metrics,
            "graph": graph,
            "details": {"parts": [], "insights": []},
        },
        additionalGraphs=[],
    )