src/evidently/legacy/renderers/base_renderer.py
base_renderer.py
  1  import dataclasses
  2  import warnings
  3  from typing import TYPE_CHECKING
  4  from typing import Any
  5  from typing import Dict
  6  from typing import Generic
  7  from typing import List
  8  from typing import Optional
  9  from typing import TypeVar
 10  from typing import Union
 11  
 12  import pandas as pd
 13  import uuid6
 14  
 15  from evidently._pydantic_compat import BaseModel
 16  from evidently._pydantic_compat import Field
 17  from evidently.legacy.model.widget import AdditionalGraphInfo
 18  from evidently.legacy.model.widget import BaseWidgetInfo
 19  from evidently.legacy.model.widget import PlotlyGraphInfo
 20  from evidently.legacy.options import ColorOptions
 21  
 22  if TYPE_CHECKING:
 23      from evidently.legacy.base_metric import Metric
 24      from evidently.legacy.core import IncludeOptions
 25      from evidently.legacy.tests.base_test import Test
 26  
 27  
 28  class BaseRenderer:
 29      """Base class for all renderers"""
 30  
 31      color_options: ColorOptions
 32  
 33      def __init__(self, color_options: Optional[ColorOptions] = None) -> None:
 34          if color_options is None:
 35              self.color_options = ColorOptions()
 36  
 37          else:
 38              self.color_options = color_options
 39  
 40  
 41  TMetric = TypeVar("TMetric", bound="Metric")
 42  
 43  
 44  class MetricRenderer(Generic[TMetric], BaseRenderer):
 45      def render_pandas(self, obj: TMetric) -> pd.DataFrame:
 46          result = obj.get_result()
 47          if not result.__config__.pd_include:
 48              warnings.warn(
 49                  f"{obj.get_id()} metric does not support as_dataframe yet. Please submit an issue to https://github.com/evidentlyai/evidently/issues"
 50              )
 51              return pd.DataFrame()
 52          return result.get_pandas()
 53  
 54      def render_json(
 55          self,
 56          obj: TMetric,
 57          include_render: bool = False,
 58          include: "IncludeOptions" = None,
 59          exclude: "IncludeOptions" = None,
 60      ) -> dict:
 61          result = obj.get_result()
 62          return result.get_dict(include_render=include_render, include=include, exclude=exclude)
 63  
 64      def render_html(self, obj: TMetric) -> List[BaseWidgetInfo]:
 65          raise NotImplementedError()
 66  
 67  
class DetailsInfo(BaseModel):
    """One titled detail entry attached to a rendered test (see ``TestHtmlInfo.details``)."""

    # Title of this detail entry.
    title: str
    # Payload of the entry. Annotated as a widget, but the ``Any`` member of the
    # Union means values of any type are accepted.
    info: Union[BaseWidgetInfo, Any]
    # Identifier; defaults to a fresh UUIDv7 string. ``replace_test_widget_ids``
    # overwrites it with a generator-produced id.
    id: str = Field(default_factory=lambda: str(uuid6.uuid7()))
 72  
 73  
 74  class TestHtmlInfo(BaseModel):
 75      name: str
 76      description: str
 77      test_fingerprint: str
 78      status: str
 79      details: List[DetailsInfo]
 80      groups: Dict[str, str]
 81  
 82      def with_details(self, title: str, info: BaseWidgetInfo):
 83          self.details.append(DetailsInfo(title=title, info=info))
 84          return self
 85  
 86  
 87  TTest = TypeVar("TTest", bound="Test")
 88  
 89  
 90  class TestRenderer(Generic[TTest], BaseRenderer):
 91      def html_description(self, obj: TTest):
 92          return obj.get_result().description
 93  
 94      def json_description(self, obj: TTest):
 95          return obj.get_result().description
 96  
 97      def render_html(self, obj: TTest) -> TestHtmlInfo:
 98          result = obj.get_result()
 99          return TestHtmlInfo(
100              name=result.name,
101              description=self.html_description(obj),
102              test_fingerprint=obj.get_fingerprint(),
103              status=result.status.value,
104              details=[],
105              groups=obj.get_groups(),
106          )
107  
108      def render_json(
109          self,
110          obj: TTest,
111          include_render: bool = False,
112          include: "IncludeOptions" = None,
113          exclude: "IncludeOptions" = None,
114      ) -> dict:
115          return obj.get_result().get_dict(include_render=include_render, include=include, exclude=exclude)
116  
117  
@dataclasses.dataclass
class RenderersDefinitions:
    """Container for renderer instances.

    Holds a ``typed_renderers`` mapping plus optional default HTML renderers
    for tests and metrics.
    """

    # Key -> renderer instance; ``default_renderer(wrap_type)`` stores entries here
    # keyed by ``wrap_type`` (presumably a Metric/Test class — confirm with callers).
    typed_renderers: dict = dataclasses.field(default_factory=dict)
    # NOTE(review): presumably the fallbacks used when ``typed_renderers`` has no
    # entry for a type — confirm against the lookup code.
    default_html_test_renderer: Optional[TestRenderer] = None
    default_html_metric_renderer: Optional[MetricRenderer] = None
123  
124  
def default_renderer(wrap_type):
    """Class decorator that registers a renderer as the default for ``wrap_type``.

    The decorated class is instantiated once, with no arguments, at decoration
    time. The instance is stored in ``DEFAULT_RENDERERS.typed_renderers[wrap_type]``
    and the class itself is returned unchanged.
    """

    def register(renderer_cls):
        instance = renderer_cls()
        DEFAULT_RENDERERS.typed_renderers[wrap_type] = instance
        return renderer_cls

    return register
131  
132  
# Module-wide renderer registry. ``default_renderer`` adds entries to
# ``typed_renderers``; a plain ``TestRenderer`` is the default HTML test renderer,
# and no default HTML metric renderer is set.
DEFAULT_RENDERERS = RenderersDefinitions(
    typed_renderers={},
    default_html_test_renderer=TestRenderer(),
)
134  
135  
class WidgetIdGenerator:
    """Produce sequential ids of the form ``"<base_id>-<n>"`` or ``"<base_id>-<n>-<postfix>"``.

    ``n`` starts at 0 and increases by one on every :meth:`get_id` call.
    """

    def __init__(self, base_id: str) -> None:
        self.base_id = base_id
        self.counter = 0

    def get_id(self, postfix: Optional[str] = None) -> str:
        """Return the next id and advance the counter.

        Args:
            postfix: suffix appended after a ``-``; ``None`` omits it
                (an empty string still appends the trailing ``-``).

        Returns:
            ``"<base_id>-<counter>"``, or ``"<base_id>-<counter>-<postfix>"``
            when ``postfix`` is not ``None``.
        """
        val = f"{self.base_id}-{self.counter}"
        if postfix is not None:
            val = f"{val}-{postfix}"
        self.counter += 1
        return val
147  
148  
def replace_widgets_ids(widgets: List[BaseWidgetInfo], generator: WidgetIdGenerator) -> None:
    """Reassign ids in place for every widget in ``widgets``, in list order.

    Each widget is passed to ``replace_widget_ids`` with the shared ``generator``,
    so ids keep increasing across the whole list.
    """
    for single_widget in widgets:
        replace_widget_ids(single_widget, generator)
152  
153  
def replace_test_widget_ids(widget: TestHtmlInfo, generator: WidgetIdGenerator) -> None:
    """Reassign ids in place for every detail entry of a rendered test.

    For each entry of ``widget.details`` (in order), the entry first receives a new
    id, then the widget stored in its ``info`` is renumbered via ``replace_widget_ids``.
    """
    for entry in widget.details:
        entry.id = generator.get_id()
        replace_widget_ids(entry.info, generator)
158  
159  
def _collect_detail_parts(params: Any) -> List[Any]:
    """Gather the ``parts`` entries referenced from a widget's ``params``.

    Looks at ``params["data"][i]["details"]["parts"]`` for every item of ``data``,
    then at ``params["details"]["parts"]``. Non-dict ``params`` yield an empty list.
    """
    collected: List[Any] = []
    if not isinstance(params, dict):
        return collected

    if "data" in params:
        for item in params["data"]:
            if "details" in item and "parts" in item["details"]:
                collected.extend(item["details"]["parts"])

    if "details" in params:
        details = params["details"]
        if "parts" in details:
            collected.extend(details["parts"])

    return collected


def replace_widget_ids(widget: BaseWidgetInfo, generator: WidgetIdGenerator):
    """Reassign ids in place for ``widget`` and everything nested inside it.

    Ids are drawn from ``generator`` in this order:

    1. the widget itself;
    2. each entry of ``additionalGraphs`` — nested widgets recursively, while
       ``AdditionalGraphInfo``/``PlotlyGraphInfo`` use their old id (spaces
       replaced by ``-``) as the postfix;
    3. the child ``widgets``, recursively.

    Any ``parts`` entry in ``params`` whose ``id`` named an additional graph's old
    id is updated to that graph's new id.

    Raises:
        ValueError: if an additional graph has an unsupported type.
    """
    widget.id = generator.get_id()

    # Index each additional graph by its ORIGINAL id. The graph object is then
    # renamed in place, so ``graphs_by_old_id[old].id`` later yields the new id.
    graphs_by_old_id: Dict[str, Union[BaseWidgetInfo, AdditionalGraphInfo, PlotlyGraphInfo]] = {}
    for graph in widget.additionalGraphs:
        if isinstance(graph, BaseWidgetInfo):
            graphs_by_old_id[graph.id] = graph
            replace_widget_ids(graph, generator)
            continue
        if not isinstance(graph, (AdditionalGraphInfo, PlotlyGraphInfo)):
            raise ValueError(f"Unknown add graph type {graph.__class__.__name__}")
        graphs_by_old_id[graph.id] = graph
        graph.id = generator.get_id(graph.id.replace(" ", "-"))

    # Re-point every part that referenced an additional graph at its new id.
    for part in _collect_detail_parts(widget.params):
        if "id" in part and part["id"] in graphs_by_old_id:
            part["id"] = graphs_by_old_id[part["id"]].id

    for child in widget.widgets:
        replace_widget_ids(child, generator)