# base_renderer.py
import dataclasses
import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import TypeVar
from typing import Union

import pandas as pd
import uuid6

from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import Field
from evidently.legacy.model.widget import AdditionalGraphInfo
from evidently.legacy.model.widget import BaseWidgetInfo
from evidently.legacy.model.widget import PlotlyGraphInfo
from evidently.legacy.options import ColorOptions

if TYPE_CHECKING:
    from evidently.legacy.base_metric import Metric
    from evidently.legacy.core import IncludeOptions
    from evidently.legacy.tests.base_test import Test


class BaseRenderer:
    """Base class for all renderers.

    Stores the ``ColorOptions`` used while rendering. A fresh default
    ``ColorOptions()`` is created when none is supplied.
    """

    color_options: ColorOptions

    def __init__(self, color_options: Optional[ColorOptions] = None) -> None:
        # Explicit ``is None`` check so any caller-supplied options object is kept as-is.
        self.color_options = ColorOptions() if color_options is None else color_options


TMetric = TypeVar("TMetric", bound="Metric")


class MetricRenderer(Generic[TMetric], BaseRenderer):
    """Renders a metric's result as a DataFrame, a dict, or HTML widgets."""

    def render_pandas(self, obj: TMetric) -> pd.DataFrame:
        """Return the metric result as a DataFrame.

        If the result's config does not enable ``pd_include``, a warning is
        emitted and an empty DataFrame is returned instead.
        """
        result = obj.get_result()
        if not result.__config__.pd_include:
            warnings.warn(
                f"{obj.get_id()} metric does not support as_dataframe yet. "
                "Please submit an issue to https://github.com/evidentlyai/evidently/issues"
            )
            return pd.DataFrame()
        return result.get_pandas()

    def render_json(
        self,
        obj: TMetric,
        include_render: bool = False,
        include: Optional["IncludeOptions"] = None,
        exclude: Optional["IncludeOptions"] = None,
    ) -> dict:
        """Return the metric result as a dict via ``result.get_dict``.

        Args:
            obj: metric whose result is serialized.
            include_render: forwarded to ``get_dict``.
            include: optional field selection forwarded to ``get_dict``.
            exclude: optional field exclusion forwarded to ``get_dict``.
        """
        result = obj.get_result()
        return result.get_dict(include_render=include_render, include=include, exclude=exclude)

    def render_html(self, obj: TMetric) -> List[BaseWidgetInfo]:
        """Return the HTML widgets for *obj*; subclasses must override this."""
        raise NotImplementedError()


class DetailsInfo(BaseModel):
    """One titled detail section attached to a rendered test."""

    title: str
    info: Union[BaseWidgetInfo, Any]
    # a new uuid7 string is generated for every instance
    id: str = Field(default_factory=lambda: str(uuid6.uuid7()))


class TestHtmlInfo(BaseModel):
    """HTML rendering payload for a single test result."""

    name: str
    description: str
    test_fingerprint: str
    # string value of the test result's status enum (see TestRenderer.render_html)
    status: str
    details: List[DetailsInfo]
    groups: Dict[str, str]

    def with_details(self, title: str, info: BaseWidgetInfo) -> "TestHtmlInfo":
        """Append a ``DetailsInfo`` section in place and return ``self`` for chaining."""
        self.details.append(DetailsInfo(title=title, info=info))
        return self


TTest = TypeVar("TTest", bound="Test")


class TestRenderer(Generic[TTest], BaseRenderer):
    """Renders a test's result as ``TestHtmlInfo`` or a dict."""

    def html_description(self, obj: TTest):
        """Description shown in HTML output; defaults to the result's description."""
        return obj.get_result().description

    def json_description(self, obj: TTest):
        """Description used in JSON output; defaults to the result's description."""
        return obj.get_result().description

    def render_html(self, obj: TTest) -> TestHtmlInfo:
        """Build a ``TestHtmlInfo`` with no details from the test's result."""
        result = obj.get_result()
        return TestHtmlInfo(
            name=result.name,
            description=self.html_description(obj),
            test_fingerprint=obj.get_fingerprint(),
            status=result.status.value,
            details=[],
            groups=obj.get_groups(),
        )

    def render_json(
        self,
        obj: TTest,
        include_render: bool = False,
        include: Optional["IncludeOptions"] = None,
        exclude: Optional["IncludeOptions"] = None,
    ) -> dict:
        """Return the test result as a dict via ``result.get_dict``."""
        return obj.get_result().get_dict(include_render=include_render, include=include, exclude=exclude)


@dataclasses.dataclass
class RenderersDefinitions:
    """Registry of renderer instances used by default."""

    # renderer instances keyed by the ``wrap_type`` given to ``default_renderer``
    typed_renderers: dict = dataclasses.field(default_factory=dict)
    default_html_test_renderer: Optional[TestRenderer] = None
    default_html_metric_renderer: Optional[MetricRenderer] = None


def default_renderer(wrap_type):
    """Class decorator that registers the decorated renderer for *wrap_type*.

    The decorated class is instantiated with no arguments, and the instance is stored in
    ``DEFAULT_RENDERERS.typed_renderers[wrap_type]``. The class itself is returned unchanged.
    """

    def wrapper(cls):
        DEFAULT_RENDERERS.typed_renderers[wrap_type] = cls()
        return cls

    return wrapper


DEFAULT_RENDERERS = RenderersDefinitions(default_html_test_renderer=TestRenderer())


class WidgetIdGenerator:
    """Produce sequential ids of the form ``"{base_id}-{counter}"`` or ``"{base_id}-{counter}-{postfix}"``."""

    def __init__(self, base_id: str):
        self.base_id = base_id
        self.counter = 0

    def get_id(self, postfix: Optional[str] = None) -> str:
        """Return the next id and advance the counter.

        Args:
            postfix: optional suffix appended (after a dash) to the generated id.
        """
        val = f"{self.base_id}-{self.counter}"
        if postfix is not None:
            val = f"{val}-{postfix}"
        self.counter += 1
        return val


def replace_widgets_ids(widgets: List[BaseWidgetInfo], generator: WidgetIdGenerator):
    """Apply ``replace_widget_ids`` to each widget in order."""
    for widget in widgets:
        replace_widget_ids(widget, generator)


def replace_test_widget_ids(widget: TestHtmlInfo, generator: WidgetIdGenerator):
    """Assign fresh ids to every detail section of a test and to the widgets it holds."""
    for detail in widget.details:
        detail.id = generator.get_id()
        replace_widget_ids(detail.info, generator)


def _detail_parts(details: Any) -> list:
    """Return ``details["parts"]`` when *details* is a dict holding a non-empty value, else ``[]``."""
    if isinstance(details, dict):
        return details.get("parts") or []
    return []


def _collect_detail_parts(params: Any) -> list:
    """Gather detail "parts" from widget params: per row of ``params["data"]`` and from ``params["details"]``.

    Malformed entries (non-dict rows or details, ``None`` data/parts) are skipped
    instead of raising ``TypeError``.
    """
    parts: list = []
    if not isinstance(params, dict):
        return parts
    for item in params.get("data") or []:
        if isinstance(item, dict):
            parts.extend(_detail_parts(item.get("details")))
    parts.extend(_detail_parts(params.get("details")))
    return parts


def replace_widget_ids(widget: BaseWidgetInfo, generator: WidgetIdGenerator):
    """Recursively assign fresh ids to *widget*, its additional graphs, and its child widgets.

    Ids are drawn from *generator* in this order:
    1. the widget itself;
    2. each entry of ``widget.additionalGraphs``, recursing into nested widgets;
    3. each child in ``widget.widgets``.

    Detail "parts" in ``widget.params`` whose ``"id"`` referenced an additional graph
    are rewritten to that graph's new id.

    Raises:
        ValueError: if an additional graph is of an unsupported type.
    """
    widget.id = generator.get_id()

    # Maps each old graph id to the graph object. Each object's ``id`` is replaced in
    # place below, so looking up an old id afterwards yields the object carrying the new id.
    add_graph_id_mapping: Dict[str, Union[BaseWidgetInfo, AdditionalGraphInfo, PlotlyGraphInfo]] = {}
    for add_graph in widget.additionalGraphs:
        if isinstance(add_graph, BaseWidgetInfo):
            add_graph_id_mapping[add_graph.id] = add_graph
            replace_widget_ids(add_graph, generator)
        elif isinstance(add_graph, (AdditionalGraphInfo, PlotlyGraphInfo)):
            add_graph_id_mapping[add_graph.id] = add_graph
            # The old id (spaces -> dashes) is kept as the postfix of the new id.
            add_graph.id = generator.get_id(add_graph.id.replace(" ", "-"))
        else:
            raise ValueError(f"Unknown add graph type {add_graph.__class__.__name__}")

    for part in _collect_detail_parts(widget.params):
        if isinstance(part, dict) and "id" in part:
            widget_id = part["id"]
            if widget_id in add_graph_id_mapping:
                part["id"] = add_graph_id_mapping[widget_id].id

    for child in widget.widgets:
        replace_widget_ids(child, generator)