/ src / evidently / legacy / renderers / html_widgets.py
html_widgets.py
  1  import dataclasses
  2  from enum import Enum
  3  from typing import Dict
  4  from typing import Iterable
  5  from typing import List
  6  from typing import Optional
  7  from typing import Tuple
  8  from typing import Union
  9  
 10  import numpy as np
 11  import pandas as pd
 12  from plotly import graph_objs as go
 13  from plotly.subplots import make_subplots
 14  from uuid6 import uuid7
 15  
 16  from evidently.legacy.metric_results import Distribution
 17  from evidently.legacy.metric_results import HistogramData
 18  from evidently.legacy.metric_results import Label
 19  from evidently.legacy.metric_results import LiftCurve
 20  from evidently.legacy.metric_results import PRCurve
 21  from evidently.legacy.metric_results import ROCCurve
 22  from evidently.legacy.model.widget import BaseWidgetInfo
 23  from evidently.legacy.model.widget import PlotlyGraphInfo
 24  from evidently.legacy.model.widget import TabInfo
 25  from evidently.legacy.model.widget import WidgetType
 26  from evidently.legacy.options import ColorOptions
 27  
 28  
class WidgetSize(int, Enum):
    """Size class of a rendered widget.

    Widget builders in this module forward ``size.value`` (the integer below) as the
    widget's ``size`` field. The mapping of each value to an on-screen width is decided
    by the frontend. NOTE(review): presumably SMALL < HALF < FULL in width; confirm
    against the UI code.
    """

    SMALL = 0
    HALF = 1
    FULL = 2
 33  
 34  
 35  class GraphData:
 36      title: str
 37      data: dict
 38      layout: dict
 39  
 40      def __init__(self, title: str, data: dict, layout: dict):
 41          """
 42          create GraphData object for usage in plotly_graph_tabs or plotly_data.
 43  
 44          Args:
 45              title: title of graph
 46              data: plotly figure data
 47              layout: plotly figure layout
 48          """
 49          self.title = title
 50          self.data = data
 51          self.layout = layout
 52  
 53      @staticmethod
 54      def figure(title: str, figure: go.Figure):
 55          """
 56          create GraphData from plotly figure itself
 57          Args:
 58              title: title of graph
 59              figure: plotly figure for getting data from
 60          """
 61          data = figure.to_plotly_json()
 62          return GraphData(title, data["data"], data["layout"])
 63  
 64  
 65  def plotly_graph(*, graph_data: GraphData, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
 66      """
 67      generate plotly plot with given GraphData object.
 68  
 69      Args:
 70          graph_data: plot data for widget
 71          size: size of widget to render
 72  
 73      Example:
 74          >>> figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
 75          >>> f_dict = figure.to_plotly_json()
 76          >>> bar_graph_data = GraphData(title="Some plot title", data=f_dict["data"], layout=f_dict["layout"])
 77          >>> widget_info = plotly_graph(graph_data=bar_graph_data, size=WidgetSize.FULL)
 78      """
 79      return BaseWidgetInfo(
 80          title=graph_data.title,
 81          type=WidgetType.BIG_GRAPH.value,
 82          size=size.value,
 83          params={"data": graph_data.data, "layout": graph_data.layout},
 84      )
 85  
 86  
 87  def plotly_data(*, title: str, data: dict, layout: dict, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
 88      """
 89      generate plotly plot with given data and layout (can be generated from plotly).
 90  
 91      Args:
 92          title: widget title
 93          data: plotly figure data
 94          layout: plotly figure layout
 95          size: widget size
 96  
 97      Example:
 98          >>> figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
 99          >>> f_dict = figure.to_plotly_json()
100          >>> widget_info = plotly_data(title="Some plot title", data=f_dict["data"], layout=f_dict["layout"])
101      """
102      return plotly_graph(graph_data=GraphData(title, data, layout), size=size)
103  
104  
def plotly_figure(*, title: str, figure: go.Figure, size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    Build a big-graph widget directly from a plotly figure object.

    Args:
        title: widget title
        figure: plotly figure to render; it is serialized via ``GraphData.figure``
        size: widget size, ``WidgetSize.FULL`` by default

    Example:
        >>> bar_figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> widget_info = plotly_figure(title="Bar plot widget", figure=bar_figure, size=WidgetSize.FULL)
    """
    graph = GraphData.figure(title=title, figure=figure)
    return plotly_graph(graph_data=graph, size=size)
119  
120  
def plotly_graph_tabs(*, title: str, figures: List[GraphData], size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    Build a tabbed widget holding one plotly graph per tab.

    Args:
        title: widget title
        figures: graphs to show; each GraphData's title becomes its tab title
        size: widget size

    Example:
        >>> bar_figure = go.Figure(go.Bar(name="Bar plot", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> line_figure = go.Figure(go.Scatter(name="Line plot", mode="lines", x=[1, 2, 3, 4], y=[10, 11, 20, 11]))
        >>> widget_info = plotly_graph_tabs(
        ...     title="Tabbed widget",
        ...     figures=[GraphData.figure("Bar", bar_figure), GraphData.figure("Line", line_figure)],
        ... )
    """
    tab_graphs = []
    for graph in figures:
        tab_graphs.append(
            {
                # a newly generated id for every tab entry
                "id": str(uuid7()),
                "title": graph.title,
                "graph": {"data": graph.data, "layout": graph.layout},
            }
        )
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABBED_GRAPH.value,
        size=size.value,
        params={"graphs": tab_graphs},
    )
156  
157  
class CounterData:
    """A single counter entry (label plus pre-formatted value text) for the ``counter`` widget."""

    label: str
    value: str

    def __init__(self, label: str, value: str):
        """
        creates CounterData for counter widget with given label and value.

        Args:
            label: counter label
            value: counter value, already formatted as the text to display
        """
        self.label = label
        self.value = value

    @staticmethod
    def float(label: str, value: float, precision: int) -> "CounterData":
        """
        create CounterData for a float value rounded to a fixed number of decimal places.

        Args:
            label: counter label
            value: float value of counter
            precision: number of digits after the decimal point

        Example:
            >>> CounterData.float("mae", 123.456, 2).value
            '123.46'
        """
        # The "f" presentation type makes `precision` count decimal places. Without it the
        # general format is used: precision then means significant digits, and larger
        # values switch to scientific notation (123.456 with precision 2 -> "1.2e+02").
        return CounterData(label, f"{value:.{precision}f}")

    @staticmethod
    def string(label: str, value: str) -> "CounterData":
        """
        create CounterData for a string value (shown as-is).

        Args:
            label: counter label
            value: string value of counter
        """
        return CounterData(label, f"{value}")

    @staticmethod
    def int(label: str, value: int) -> "CounterData":
        """
        create CounterData for int value.

        Args:
            label: counter label
            value: int value
        """
        return CounterData(label, f"{value}")
206  
207  
def counter(*, counters: List[CounterData], title: str = "", size: WidgetSize = WidgetSize.FULL) -> BaseWidgetInfo:
    """
    Build a counter widget showing the given label/value pairs in order.

    Args:
        title: widget title
        counters: counters to display
        size: widget size

    Example:
        >>> display_counters = [CounterData("value1", "some value"), CounterData.float("float", 0.111, 2)]
        >>> widget_info = counter(counters=display_counters, title="counters example")
    """
    counter_entries = [dict(value=entry.value, label=entry.label) for entry in counters]
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.COUNTER.value,
        size=size.value,
        params=dict(counters=counter_entries),
    )
227  
228  
def pie_chart(
    *,
    title: str,
    data: Union[Dict[str, float], Tuple[List[str], List[float]]],
    size: WidgetSize = WidgetSize.FULL,
    colors: Union[Dict[str, str], List[str], None] = None,
) -> BaseWidgetInfo:
    """
    Build a donut-style pie chart widget (hole of 0.4, black slice outlines).

    Args:
        title: widget title
        data: either a mapping ``label -> value`` or a ``(labels, values)`` pair
        size: widget size
        colors: per-label mapping (looked up for every label, so it must contain all of
            them), a list of colors in label order, or None to set no explicit colors
    """
    if isinstance(data, dict):
        labels, values = list(data.keys()), list(data.values())
    else:
        labels, values = data

    slice_colors = [colors[name] for name in labels] if isinstance(colors, dict) else colors
    donut = go.Pie(
        labels=labels,
        values=values,
        hole=0.4,
        marker={"colors": slice_colors, "line": {"color": "#000000", "width": 1}},
    )
    return plotly_figure(title=title, figure=go.Figure(data=[donut]), size=size)
254  
255  
def header_text(*, label: str, title: str = "", size: WidgetSize = WidgetSize.FULL):
    """
    Build a header-like widget showing a piece of text.

    It is rendered as a counter widget with one counter whose value is empty and whose
    label carries the text.

    Args:
        label: text to display
        title: widget title
        size: widget size
    """
    header_counter = {"value": "", "label": label}
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.COUNTER.value,
        size=size.value,
        params={"counters": [header_counter]},
    )
271  
272  
def text_widget(*, text: str, title: str = "", size: WidgetSize = WidgetSize.FULL):
    """
    Build a widget of type ``"text"`` carrying markdown text.

    Args:
        text: markdown formatted text
        title: widget title
        size: widget size
    """
    text_params = {"text": text}
    return BaseWidgetInfo(
        title=title,
        type="text",
        size=size.value,
        params=text_params,
    )
282  
283  
def table_data(
    *, column_names: Iterable[str], data: Iterable[Iterable], title: str = "", size: WidgetSize = WidgetSize.FULL
) -> BaseWidgetInfo:
    """
    generate simple table with given columns and data

    Args:
        column_names: column names in display order; any iterable is accepted and
            materialized into a list
        data: rows of cell values in column order; every cell is converted with ``str``
        title: widget title
        size: widget size

    Example:
        >>> columns = ["Column A", "Column B"]
        >>> in_table_data = [[1, 2], [3, 4]]
        >>> widget_info = table_data(column_names=columns, data=in_table_data, title="Table")
    """
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABLE.value,
        params={
            # Materialize the header so a one-shot iterable (generator, map, dict view) is
            # stored as its values, just like "data" below, instead of as the iterator object.
            "header": list(column_names),
            "data": [[str(item) for item in row] for row in data],
        },
        size=size.value,
    )
310  
311  
class ColumnType(Enum):
    """Rendering type of a rich-table column.

    ``ColumnDefinition.as_dict`` emits ``value`` under the ``"type"`` key for every member
    except STRING, which is the default and is left out of the column dict.
    """

    STRING = "string"
    LINE = "line"
    SCATTER = "scatter"
    HISTOGRAM = "histogram"
317  
318  
class SortDirection(Enum):
    """Sort order of a rich-table column; ``ColumnDefinition.as_dict`` emits ``value`` under ``"sort"``."""

    ASC = "asc"
    DESC = "desc"
322  
323  
@dataclasses.dataclass
class ColumnDefinition:
    """Description of one column of a ``rich_table_data`` widget."""

    title: str  # column header text
    field_name: str  # key looked up in each row's ``fields`` dict
    type: ColumnType = ColumnType.STRING  # rendering type; STRING is the default
    sort: Optional[SortDirection] = None  # optional sort direction
    options: Optional[dict] = None  # type-specific options, e.g. xField/yField/color for histograms

    def as_dict(self) -> dict:
        """Serialize to the column dict: always ``title``/``field``; ``type``/``sort``/``options`` only when set."""
        column: dict = dict(title=self.title, field=self.field_name)
        uses_default_type = self.type == ColumnType.STRING
        if not uses_default_type:
            column["type"] = self.type.value
        if self.sort is not None:
            column["sort"] = self.sort.value
        if self.options is not None:
            column["options"] = self.options
        return column
341  
342  
@dataclasses.dataclass
class TabData:
    """One tab for ``widget_tabs``: a title and the widget placed inside the tab."""

    title: str  # passed to TabInfo as the tab title
    widget: BaseWidgetInfo  # widget rendered in this tab
347  
348  
def widget_tabs(*, title: str = "", size: WidgetSize = WidgetSize.FULL, tabs: List[TabData]) -> BaseWidgetInfo:
    """
    Build a tabs widget in which every tab holds an arbitrary nested widget.

    Args:
        title: widget title
        size: widget size
        tabs: tabs to include, in display order

    Example:
        >>> columns = ["Column A", "Column B"]
        >>> in_table_data = [[1, 2], [3, 4]]
        >>> tab_data = [
        ...     TabData("Counters", counter(counters=[CounterData("counter", "value")], title="Counter")),
        ...     TabData("Table", table_data(column_names=columns, data=in_table_data, title="Table")),
        ... ]
        >>> widget_info = widget_tabs(title="Tabs", tabs=tab_data)
    """
    tab_infos = [TabInfo(id=str(uuid7()), title=tab.title, widget=tab.widget) for tab in tabs]
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.TABS.value,
        size=size.value,
        tabs=tab_infos,
    )
373  
374  
def widget_tabs_for_more_than_one(
    *, title: str = "", size: WidgetSize = WidgetSize.FULL, tabs: List[TabData]
) -> Optional[BaseWidgetInfo]:
    """
    Wrap widgets in tabs only when there is something to switch between.

    Returns None for an empty ``tabs`` list, the lone tab's widget itself (its tab title
    is dropped) for exactly one tab, and a ``widget_tabs`` widget otherwise.
    """
    tab_count = len(tabs)
    if tab_count == 0:
        return None
    if tab_count == 1:
        return tabs[0].widget
    return widget_tabs(title=title, size=size, tabs=tabs)
387  
388  
class DetailsPartInfo:
    """One titled part of a rich-table row's details.

    ``rich_table_data`` references ``info`` by its ``id`` and ships the object itself in
    ``additionalGraphs``; a ``BaseWidgetInfo`` is tagged ``"widget"``, anything else ``"graph"``.
    """

    title: str  # title of this details part
    info: Union[BaseWidgetInfo, PlotlyGraphInfo]  # widget or graph shown in this part

    def __init__(self, title: str, info: Union[BaseWidgetInfo, PlotlyGraphInfo]):
        self.title = title
        self.info = info
396  
397  
class RowDetails:
    """Ordered list of detail parts attached to one rich-table row; built fluently via ``with_part``."""

    parts: List[DetailsPartInfo]

    def __init__(self, parts: Optional[List[DetailsPartInfo]] = None):
        # A new list per instance when none is given; a caller-supplied list is kept (not copied).
        self.parts = [] if parts is None else parts

    def with_part(self, title: str, info: Union[BaseWidgetInfo, PlotlyGraphInfo]):
        """Append a titled part and return ``self`` so calls can be chained."""
        new_part = DetailsPartInfo(title, info)
        self.parts.append(new_part)
        return self
409  
410  
class RichTableDataRow:
    """One row of a ``rich_table_data`` widget: column values plus optional details."""

    details: Optional[RowDetails]  # extra parts attached to the row, or None
    fields: dict  # column values keyed by ColumnDefinition.field_name

    def __init__(self, fields: dict, details: Optional[RowDetails] = None):
        self.fields = fields
        self.details = details
418  
419  
def rich_table_data(
    *,
    title: str = "",
    size: WidgetSize = WidgetSize.FULL,
    rows_per_page: int = 10,
    columns: List[ColumnDefinition],
    data: List[RichTableDataRow],
) -> BaseWidgetInfo:
    """
    Build a rich table widget: typed columns plus optional per-row details.

    Each row's detail parts are referenced from the row by id; the part objects themselves
    are collected into the widget's ``additionalGraphs``.

    Args:
        title: widget title
        size: widget size
        rows_per_page: maximum number of rows per page (capped at the number of rows)
        columns: column definitions of the table
        data: table rows; each row's ``fields`` are keyed by ``ColumnDefinition.field_name``

    Example:
        >>> columns_def = [
        ...     ColumnDefinition("Column A", "field_1"),
        ...     ColumnDefinition("Column B", "field_2", ColumnType.HISTOGRAM,
        ...                      options={"xField": "x", "yField": "y", "color": "#ed0400"}),
        ...     ColumnDefinition("Column C", "field_3", sort=SortDirection.ASC),
        ... ]
        >>> in_table_data = [
        ...     RichTableDataRow(fields=dict(field_1="a", field_2=dict(x=[1, 2, 3], y=[10, 11, 3]), field_3="2")),
        ...     RichTableDataRow(
        ...         fields=dict(field_1="b", field_2=dict(x=[1, 2, 3], y=[10, 11, 3]), field_3="1"),
        ...         details=RowDetails()
        ...             .with_part("Some details", counter(counters=[CounterData("counter 1", "value")])
        ...         )
        ...     )
        ... ]
        >>> widget_info = rich_table_data(title="Rich table", rows_per_page=10, columns=columns_def, data=in_table_data)
    """
    detail_widgets: list = []
    table_rows = []
    for row in data:
        row_details = row.details
        has_parts = row_details is not None and row_details.parts is not None and len(row_details.parts) > 0
        if not has_parts:
            table_rows.append(dict(**row.fields))
            continue
        part_refs = [
            dict(
                title=part.title,
                id=part.info.id,
                type="widget" if isinstance(part.info, BaseWidgetInfo) else "graph",
            )
            for part in row_details.parts
        ]
        detail_widgets.extend(part.info for part in row_details.parts)
        # `dict(..., **fields)` deliberately raises if a row field is also named "details".
        table_rows.append(dict(details={"parts": part_refs}, **row.fields))

    table_params = {
        "rowsPerPage": min(len(data), rows_per_page),
        "columns": [column.as_dict() for column in columns],
        "data": table_rows,
    }
    return BaseWidgetInfo(
        title=title,
        type=WidgetType.BIG_TABLE.value,
        details="",
        alerts=[],
        alertsPosition="row",
        insights=[],
        size=size.value,
        params=table_params,
        additionalGraphs=detail_widgets,
    )
490  
491  
def get_histogram_figure(
    *,
    primary_hist: HistogramData,
    secondary_hist: Optional[HistogramData] = None,
    color_options: ColorOptions,
    orientation: str = "v",
) -> go.Figure:
    """
    Build a bar figure with one trace per histogram.

    The primary histogram uses the current-data color; the optional secondary histogram is
    added after it with the reference-data color.
    """
    layers = [(primary_hist, color_options.get_current_data_color)]
    if secondary_hist is not None:
        layers.append((secondary_hist, color_options.get_reference_data_color))

    figure = go.Figure()
    for hist, pick_color in layers:
        figure.add_trace(
            go.Bar(
                name=hist.name,
                x=hist.x,
                y=hist.count,
                marker_color=pick_color(),
                orientation=orientation,
            )
        )
    return figure
520  
521  
def histogram(
    *,
    title: str,
    primary_hist: HistogramData,
    secondary_hist: Optional[HistogramData] = None,
    color_options: ColorOptions,
    orientation: str = "v",
    size: WidgetSize = WidgetSize.FULL,
    xaxis_title: Optional[str] = None,
    yaxis_title: Optional[str] = None,
) -> BaseWidgetInfo:
    """
    generate widget with one or two histograms

    Args:
        title: widget title
        primary_hist: first histogram, drawn with the current-data color
        secondary_hist: optional second histogram, drawn with the reference-data color
        color_options: color options to use for widgets
        orientation: bars orientation in histograms
        size: widget size
        xaxis_title: title for x-axis (left unset when None)
        yaxis_title: title for y-axis (left unset when None)

    Example:
        Given a ``color_options: ColorOptions`` instance:

        >>> ref_hist = HistogramData(name="Histogram 1", x=pd.Series(["a", "b", "c"]), count=pd.Series([1, 2, 3]))
        >>> curr_hist = HistogramData(name="Histogram 2", x=pd.Series(["a", "b", "c"]), count=pd.Series([3, 2, 1]))
        >>> widget_info = histogram(
        ...     title="Histogram example",
        ...     primary_hist=curr_hist,
        ...     secondary_hist=ref_hist,
        ...     color_options=color_options,
        ... )
    """
    figure = get_histogram_figure(
        primary_hist=primary_hist,
        secondary_hist=secondary_hist,
        color_options=color_options,
        orientation=orientation,
    )
    if xaxis_title is not None:
        figure.update_layout(
            xaxis_title=xaxis_title,
        )

    if yaxis_title is not None:
        figure.update_layout(
            yaxis_title=yaxis_title,
        )

    return plotly_figure(title=title, figure=figure, size=size)
572  
573  
def get_histogram_for_distribution(
    *,
    current_distribution: Distribution,
    reference_distribution: Optional[Distribution] = None,
    title: str = "",
    xaxis_title: Optional[str] = None,
    yaxis_title: Optional[str] = None,
    color_options: ColorOptions,
):
    """
    Render a histogram widget from Distribution objects.

    The current distribution becomes the primary histogram named "current"; the optional
    reference distribution becomes the secondary histogram named "reference".
    """

    def as_histogram(name: str, distribution: Distribution) -> HistogramData:
        # Distribution.x -> bar positions, Distribution.y -> bar heights
        return HistogramData(name=name, x=pd.Series(distribution.x), count=pd.Series(distribution.y))

    return histogram(
        title=title,
        primary_hist=as_histogram("current", current_distribution),
        secondary_hist=(
            None if reference_distribution is None else as_histogram("reference", reference_distribution)
        ),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        color_options=color_options,
    )
607  
608  
@dataclasses.dataclass
class HeatmapData:
    """One heatmap panel for ``get_heatmaps_widget``."""

    name: str  # subplot title, used only when two heatmaps are drawn side by side
    # Cell values. Its columns label BOTH axes in get_heatmaps_widget, so presumably the
    # matrix is square with index == columns (e.g. a correlation matrix) — TODO confirm.
    matrix: pd.DataFrame
613  
614  
def get_heatmaps_widget(
    *,
    title: str = "",
    primary_data: HeatmapData,
    secondary_data: Optional[HeatmapData] = None,
    size: WidgetSize = WidgetSize.FULL,
    color_options: ColorOptions,
) -> BaseWidgetInfo:
    """
    Create a widget with one heatmap, or two side by side sharing the y axis.

    Both heatmaps use a shared color axis with ``color_options.heatmap`` as colorscale.
    Cells are annotated with their values (rounded to 2 decimals) when a matrix has
    fewer than 15 columns.
    """
    if secondary_data is None:
        titles = [""]
    else:
        titles = [primary_data.name, secondary_data.name]

    figure = make_subplots(rows=1, cols=len(titles), subplot_titles=titles, shared_yaxes=True)

    for column_index, heatmap in enumerate((primary_data, secondary_data), start=1):
        if heatmap is None:
            continue
        matrix = heatmap.matrix
        labels = matrix.columns
        # print values inside the cells only while the heatmap is small enough to read them
        annotate = len(labels) < 15
        figure.add_trace(
            go.Heatmap(
                z=matrix,
                x=labels,
                y=labels,
                text=np.round(matrix, 2).astype(str) if annotate else None,
                texttemplate="%{text}" if annotate else None,
                coloraxis="coloraxis",
            ),
            1,
            column_index,
        )

    figure.update_layout(coloraxis={"colorscale": color_options.heatmap})
    figure.update_yaxes(type="category")
    figure.update_xaxes(tickangle=-45)
    return plotly_figure(title=title, figure=figure, size=size)
669  
670  
def get_roc_auc_tab_data(
    curr_roc_curve: ROCCurve, ref_roc_curve: Optional[ROCCurve], color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one ROC-curve figure per label.

    For every key of ``curr_roc_curve`` a figure is produced with the current curve in the
    first column and, when ``ref_roc_curve`` is given, the reference curve for the same key
    in a second column sharing the y axis.

    Args:
        curr_roc_curve: ROC curves of the current data, keyed by label
        ref_roc_curve: ROC curves of the reference data, or None
        color_options: color settings; both curves use the current-data color

    Returns:
        (label as str, plotly widget) pairs in the key order of ``curr_roc_curve``
    """
    if ref_roc_curve is None:
        cols, subplot_titles = 1, [""]
    else:
        cols, subplot_titles = 2, ["current", "reference"]

    tabs = []
    for label in curr_roc_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        panels = [(curr_roc_curve[label], 1)]
        if ref_roc_curve is not None:
            panels.append((ref_roc_curve[label], 2))
        for curve, col in panels:
            # both traces share one legend group; only the first (current) one adds a legend entry
            legend_kwargs: Dict[str, bool] = {} if col == 1 else {"showlegend": False}
            fig.add_trace(
                go.Scatter(
                    x=curve.fpr,
                    y=curve.tpr,
                    mode="lines",
                    name="ROC",
                    legendgroup="ROC",
                    marker=dict(size=6, color=color_options.get_current_data_color()),
                    **legend_kwargs,
                ),
                1,
                col,
            )
            fig.update_xaxes(title_text="False Positive Rate", row=1, col=col)
        fig.update_layout(yaxis_title="True Positive Rate", showlegend=True)
        tabs.append((str(label), plotly_figure(title="", figure=fig)))
    return tabs
714  
715  
def get_pr_rec_plot_data(
    current_pr_curve: PRCurve, reference_pr_curve: Optional[PRCurve], color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Build one precision-recall figure per label.

    For every key of ``current_pr_curve`` a figure plots precision against recall for the
    current data and, when ``reference_pr_curve`` is given, for the reference data in a
    second column sharing the y axis.

    Args:
        current_pr_curve: PR curves of the current data, keyed by label
        reference_pr_curve: PR curves of the reference data, or None
        color_options: color settings; both curves use the current-data color

    Returns:
        (label as str, plotly widget) pairs in the key order of ``current_pr_curve``
    """
    if reference_pr_curve is None:
        cols, subplot_titles = 1, [""]
    else:
        cols, subplot_titles = 2, ["current", "reference"]

    tabs = []
    for label in current_pr_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        panels = [(current_pr_curve[label], 1)]
        if reference_pr_curve is not None:
            panels.append((reference_pr_curve[label], 2))
        for curve, col in panels:
            # both traces share one legend group; only the first (current) one adds a legend entry
            legend_kwargs: Dict[str, bool] = {} if col == 1 else {"showlegend": False}
            fig.add_trace(
                go.Scatter(
                    x=curve.rcl,
                    y=curve.pr,
                    mode="lines",
                    name="PR",
                    legendgroup="PR",
                    marker=dict(size=6, color=color_options.get_current_data_color()),
                    **legend_kwargs,
                ),
                1,
                col,
            )
            fig.update_xaxes(title_text="Recall", row=1, col=col)
        fig.update_layout(yaxis_title="Precision", showlegend=True)
        tabs.append((str(label), plotly_figure(title="", figure=fig)))
    return tabs
759  
760  
def _lift_hover_labels(curve) -> List[str]:
    """Hover text for every point of one lift curve, formatted as ``top: <int>, lift=<value>``."""
    # Iterate the curve's actual points rather than a fixed count, so curves with any
    # number of points get exactly one label per point (no IndexError, no unlabeled tail).
    return [f"top: {str(int(top))}, lift={str(lift)}" for top, lift in zip(curve.top, curve.lift)]


def get_lift_plot_data(
    current_lift_curve: LiftCurve,
    reference_lift_curve: Optional[LiftCurve],
    color_options: ColorOptions,
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    Forms plot data for lift metric visualization

    For every key of ``current_lift_curve`` a figure plots lift against "top" for the
    current data and, when ``reference_lift_curve`` is given, for the reference data in a
    second column sharing the y axis. Each point's hover text shows its top and lift values.

    Parameters
    ----------
    current_lift_curve: LiftCurve
        Calculated lift table data for current sample, keyed by label
    reference_lift_curve: Optional[LiftCurve]
        Calculated lift table data for reference sample, or None
    color_options: ColorOptions
        Standard Evidently class-collection of colors for data visualization;
        both curves use the current-data color

    Return values
    -------------
    additional_plots: List[Tuple[str, BaseWidgetInfo]]
        (label as str, plot widget) pairs in the key order of ``current_lift_curve``
    """
    if reference_lift_curve is None:
        cols, subplot_titles = 1, [""]
    else:
        cols, subplot_titles = 2, ["current", "reference"]

    additional_plots = []
    for label in current_lift_curve.keys():
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        panels = [(current_lift_curve[label], 1)]
        if reference_lift_curve is not None:
            panels.append((reference_lift_curve[label], 2))
        for curve, col in panels:
            # both traces share one legend group; only the first (current) one adds a legend entry
            legend_kwargs: Dict[str, bool] = {} if col == 1 else {"showlegend": False}
            fig.add_trace(
                go.Scatter(
                    x=curve.top,
                    y=curve.lift,
                    mode="lines+markers",
                    name="Lift",
                    hoverinfo="text",
                    text=_lift_hover_labels(curve),
                    legendgroup="Lift",
                    marker=dict(size=6, color=color_options.get_current_data_color()),
                    **legend_kwargs,
                ),
                1,
                col,
            )
            fig.update_xaxes(title_text="Top", row=1, col=col)
        fig.update_layout(yaxis_title="Lift", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots
834  
835  
def class_separation_traces_raw(df, label, target_name, color_options):
    """
    Build two scatter traces separating rows of class ``label`` from all other rows.

    The y values are the ``df[label]`` column of each group; x values are uniform random
    jitter in [0, 1) so points do not overlap. The ``label`` group uses
    ``color_options.primary_color``, the "other" group ``color_options.secondary_color``.
    """
    in_class = df[df[target_name] == label]
    out_of_class = df[df[target_name] != label]
    return [
        go.Scatter(
            x=np.random.random(in_class.shape[0]),
            y=in_class[label],
            mode="markers",
            name=str(label),
            legendgroup=str(label),
            marker=dict(size=6, color=color_options.primary_color),
        ),
        go.Scatter(
            x=np.random.random(out_of_class.shape[0]),
            y=out_of_class[label],
            mode="markers",
            name="other",
            legendgroup="other",
            marker=dict(size=6, color=color_options.secondary_color),
        ),
    ]
859  
860  
def class_separation_traces_agg(df, label, color_options):
    """
    Build two precomputed box traces: one for the ``label`` group, one for the "others" group.

    ``df`` is filtered on its "values" column; box statistics are read from the
    "mins"/"lowers"/"uppers"/"means"/"maxs" columns ("means" is drawn as the box median).
    """

    def summary_box(rows, x_values, color):
        return go.Box(
            lowerfence=rows["mins"],
            q1=rows["lowers"],
            q3=rows["uppers"],
            median=rows["means"],
            upperfence=rows["maxs"],
            x=x_values,
            marker_color=color,
        )

    label_rows = df[df["values"] == label]
    label_box = summary_box(label_rows, label_rows["values"].astype(str), color_options.get_current_data_color())
    other_rows = df[df["values"] == "others"]
    other_box = summary_box(other_rows, other_rows["values"], color_options.get_reference_data_color())
    return [label_box, other_box]
888  
889  
def get_class_separation_plot_data(
    current_plot: pd.DataFrame, reference_plot: Optional[pd.DataFrame], target_name: str, color_options: ColorOptions
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    build one class-separation scatter plot per class.

    For every column of ``current_plot`` except ``target_name``, rows of that class
    are plotted against all other rows via ``class_separation_traces_raw``. When
    ``reference_plot`` is given, a second subplot with the reference data is placed
    to the right, sharing the y axis.

    Args:
        current_plot: current data; one column per class plus the ``target_name`` column
        reference_plot: reference data with the same layout, or None
        target_name: name of the column holding the class of each row
        color_options: colors forwarded to the trace builder

    Returns:
        list of ``(str(label), widget)`` pairs, one per class, in column order
    """
    additional_plots = []
    cols = 1
    subplot_titles = [""]
    if reference_plot is not None:
        cols = 2
        subplot_titles = ["current", "reference"]
    for label in current_plot.columns.drop(target_name):
        fig = make_subplots(rows=1, cols=cols, subplot_titles=subplot_titles, shared_yaxes=True)
        traces = class_separation_traces_raw(current_plot, label, target_name, color_options)
        for trace in traces:
            fig.add_trace(trace, 1, 1)
        # x values are random jitter, so hide tick labels and pad the range around it
        fig.update_xaxes(dict(range=(-2, 3), showticklabels=False), row=1, col=1)

        if reference_plot is not None:
            traces = class_separation_traces_raw(reference_plot, label, target_name, color_options)
            for trace in traces:
                # reference traces reuse the current traces' names and legend groups; hide them
                # from the legend so each class is listed once (the shared legendgroup still
                # toggles both subplots together)
                trace.showlegend = False
                fig.add_trace(trace, 1, 2)
            fig.update_xaxes(dict(range=(-2, 3), showticklabels=False), row=1, col=2)

        fig.update_layout(yaxis_title="Probability", showlegend=True)

        additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
    return additional_plots
916  
917  
def get_class_separation_plot_data_agg(
    current_plot: Dict[Label, pd.DataFrame],
    reference_plot: Optional[Dict[Label, pd.DataFrame]],
    target_name: str,
    color_options: ColorOptions,
) -> List[Tuple[str, BaseWidgetInfo]]:
    """
    build one aggregated (box-plot) class-separation figure per class.

    Args:
        current_plot: mapping of class label to its pre-aggregated dataframe
        reference_plot: same mapping for reference data, or None; when given it
            must contain every label of ``current_plot``
        target_name: not used by this function; kept for signature parity with
            ``get_class_separation_plot_data``
        color_options: colors forwarded to the trace builder

    Returns:
        list of ``(str(label), widget)`` pairs in ``current_plot`` iteration order
    """
    if reference_plot is None:
        n_cols, titles = 1, [""]
    else:
        n_cols, titles = 2, ["current", "reference"]

    plots = []
    for label, current_df in current_plot.items():
        fig = make_subplots(rows=1, cols=n_cols, subplot_titles=titles, shared_yaxes=True)
        # (subplot column, aggregated frame) pairs: current always, reference when present
        panels = [(1, current_df)]
        if reference_plot is not None:
            panels.append((2, reference_plot[label]))
        for col, frame in panels:
            for trace in class_separation_traces_agg(frame, label, color_options):
                fig.add_trace(trace, 1, col)

        fig.update_layout(yaxis_title="Probability", showlegend=False)
        plots.append((str(label), plotly_figure(title="", figure=fig)))
    return plots
945  
946  
def group_widget(
    *,
    title: str,
    widgets: List[BaseWidgetInfo],
) -> BaseWidgetInfo:
    """
    wrap several widgets into one group widget.

    Args:
        title: title of the group
        widgets: child widgets shown inside the group

    Returns:
        ``BaseWidgetInfo`` of type ``WidgetType.GROUP`` with size 2 (``WidgetSize.FULL``)
    """
    group_type = WidgetType.GROUP.value
    return BaseWidgetInfo(type=group_type, title=title, size=2, widgets=widgets)
958  
959  
def rich_data(
    *,
    title: str,
    description: str,
    header: List[str],
    metrics: List[dict],
    graph: Optional[PlotlyGraphInfo],
) -> BaseWidgetInfo:
    """
    create a rich-data widget (a header with metric values and an optional graph).

    Args:
        title: text placed in ``params["header"]``; the widget's own title is left empty
        description: description text shown with the header
        header: column headers for the metric values
        metrics: metric value records
        graph: optional plotly graph to embed, or None

    Returns:
        ``BaseWidgetInfo`` of type ``WidgetType.RICH_DATA`` with size 2 and no additional graphs
    """
    params = {
        "header": title,
        "description": description,
        "metricsValuesHeaders": header,
        "metrics": metrics,
        "graph": graph,
        # details are always emitted empty by this helper
        "details": {"parts": [], "insights": []},
    }
    return BaseWidgetInfo(
        type=WidgetType.RICH_DATA.value,
        title="",
        size=2,
        params=params,
        additionalGraphs=[],
    )