# src/evidently/legacy/ui/demo_projects/bikes_v2.py
  1  from datetime import datetime
  2  from datetime import timedelta
  3  from typing import Tuple
  4  
  5  from pandas import DataFrame
  6  
  7  from evidently.core.datasets import DataDefinition
  8  from evidently.core.datasets import Dataset
  9  from evidently.core.report import Report
 10  from evidently.legacy import metrics
 11  from evidently.legacy.pipeline.column_mapping import ColumnMapping
 12  from evidently.legacy.renderers.html_widgets import WidgetSize
 13  from evidently.legacy.ui.dashboards import CounterAgg
 14  from evidently.legacy.ui.dashboards import DashboardPanelCounter
 15  from evidently.legacy.ui.dashboards import DashboardPanelPlot
 16  from evidently.legacy.ui.dashboards import PanelValue
 17  from evidently.legacy.ui.dashboards import PlotType
 18  from evidently.legacy.ui.dashboards import ReportFilter
 19  from evidently.legacy.ui.demo_projects import DemoProject
 20  from evidently.legacy.ui.demo_projects.bikes import create_data
 21  from evidently.legacy.ui.workspace.base import WorkspaceBase
 22  from evidently.metrics.column_statistics import MaxValue
 23  from evidently.metrics.column_statistics import MeanValue
 24  from evidently.metrics.column_statistics import MedianValue
 25  from evidently.metrics.column_statistics import MinValue
 26  from evidently.metrics.column_statistics import QuantileValue
 27  from evidently.metrics.column_statistics import StdValue
 28  from evidently.tests import gte
 29  from evidently.tests import lte
 30  from evidently.ui.backport import snapshot_v2_to_v1
 31  
 32  
 33  def create_snapshot(i: int, data: Tuple[DataFrame, DataFrame, ColumnMapping]):
 34      current, reference, column_mapping = data
 35      if column_mapping.numerical_features is None or len(column_mapping.numerical_features) < 1:
 36          raise ValueError("ColumnMapping must have at least one numerical feature")
 37  
 38      report = Report(
 39          [
 40              MinValue(column=column_mapping.numerical_features[0], tests=[lte(0.2), gte(2)]),
 41              MaxValue(column=column_mapping.numerical_features[0]),
 42              MedianValue(column=column_mapping.numerical_features[0]),
 43              MeanValue(column=column_mapping.numerical_features[0]),
 44              StdValue(column=column_mapping.numerical_features[0]),
 45              QuantileValue(column=column_mapping.numerical_features[0]),
 46              QuantileValue(column=column_mapping.numerical_features[0], quantile=0.95),
 47          ]
 48      )
 49  
 50      # report.set_batch_size("daily")
 51  
 52      data_chunk = current.loc[datetime(2023, 1, 29) + timedelta(days=i) : datetime(2023, 1, 29) + timedelta(i + 1)]
 53  
 54      dataset = Dataset.from_pandas(
 55          data=data_chunk,
 56          data_definition=DataDefinition(
 57              numerical_columns=column_mapping.numerical_features,
 58              categorical_columns=column_mapping.categorical_features,
 59              text_columns=column_mapping.text_features,
 60          ),
 61      )
 62  
 63      snapshot = report.run(dataset, None)
 64  
 65      return snapshot_v2_to_v1(snapshot)
 66  
 67  
 68  def noop():
 69      pass
 70  
 71  
 72  def create_project(workspace: WorkspaceBase, name: str):
 73      project = workspace.create_project(name)
 74      project.description = "A toy demo project using Bike Demand forecasting dataset"
 75  
 76      # feel free to change
 77      is_create_dashboard = False
 78  
 79      if is_create_dashboard:
 80          project.dashboard.add_panel(
 81              DashboardPanelCounter(
 82                  filter=ReportFilter(metadata_values={}, tag_values=[]),
 83                  agg=CounterAgg.NONE,
 84                  title="Bike Rental Demand Forecast",
 85              )
 86          )
 87  
 88          project.dashboard.add_panel(
 89              DashboardPanelCounter(
 90                  title="Model Calls",
 91                  filter=ReportFilter(metadata_values={}, tag_values=[]),
 92                  value=PanelValue(
 93                      metric_id="DatasetSummaryMetric",
 94                      field_path=metrics.DatasetSummaryMetric.fields.current.number_of_rows,
 95                      legend="count",
 96                  ),
 97                  text="count",
 98                  agg=CounterAgg.SUM,
 99                  size=WidgetSize.HALF,
100              )
101          )
102          project.dashboard.add_panel(
103              DashboardPanelCounter(
104                  title="Share of Drifted Features",
105                  filter=ReportFilter(metadata_values={}, tag_values=[]),
106                  value=PanelValue(
107                      metric_id="DatasetDriftMetric",
108                      field_path="share_of_drifted_columns",
109                      legend="share",
110                  ),
111                  text="share",
112                  agg=CounterAgg.LAST,
113                  size=WidgetSize.HALF,
114              )
115          )
116          project.dashboard.add_panel(
117              DashboardPanelPlot(
118                  title="Target and Prediction",
119                  filter=ReportFilter(metadata_values={}, tag_values=[]),
120                  values=[
121                      PanelValue(
122                          metric_id="ColumnSummaryMetric",
123                          field_path="current_characteristics.mean",
124                          metric_args={"column_name.name": "cnt"},
125                          legend="Target (daily mean)",
126                      ),
127                      PanelValue(
128                          metric_id="ColumnSummaryMetric",
129                          field_path="current_characteristics.mean",
130                          metric_args={"column_name.name": "prediction"},
131                          legend="Prediction (daily mean)",
132                      ),
133                  ],
134                  plot_type=PlotType.LINE,
135                  size=WidgetSize.FULL,
136              )
137          )
138          project.dashboard.add_panel(
139              DashboardPanelPlot(
140                  title="MAE",
141                  filter=ReportFilter(metadata_values={}, tag_values=[]),
142                  values=[
143                      PanelValue(
144                          metric_id="RegressionQualityMetric",
145                          field_path=metrics.RegressionQualityMetric.fields.current.mean_abs_error,
146                          legend="MAE",
147                      ),
148                  ],
149                  plot_type=PlotType.LINE,
150                  size=WidgetSize.HALF,
151              )
152          )
153          project.dashboard.add_panel(
154              DashboardPanelPlot(
155                  title="MAPE",
156                  filter=ReportFilter(metadata_values={}, tag_values=[]),
157                  values=[
158                      PanelValue(
159                          metric_id="RegressionQualityMetric",
160                          field_path=metrics.RegressionQualityMetric.fields.current.mean_abs_perc_error,
161                          legend="MAPE",
162                      ),
163                  ],
164                  plot_type=PlotType.LINE,
165                  size=WidgetSize.HALF,
166              )
167          )
168          project.dashboard.add_panel(
169              DashboardPanelPlot(
170                  title="Features Drift (Wasserstein Distance)",
171                  filter=ReportFilter(metadata_values={}, tag_values=[]),
172                  values=[
173                      PanelValue(
174                          metric_id="ColumnDriftMetric",
175                          metric_args={"column_name.name": "temp"},
176                          field_path=metrics.ColumnDriftMetric.fields.drift_score,
177                          legend="temp",
178                      ),
179                      PanelValue(
180                          metric_id="ColumnDriftMetric",
181                          metric_args={"column_name.name": "atemp"},
182                          field_path=metrics.ColumnDriftMetric.fields.drift_score,
183                          legend="atemp",
184                      ),
185                      PanelValue(
186                          metric_id="ColumnDriftMetric",
187                          metric_args={"column_name.name": "hum"},
188                          field_path=metrics.ColumnDriftMetric.fields.drift_score,
189                          legend="hum",
190                      ),
191                      PanelValue(
192                          metric_id="ColumnDriftMetric",
193                          metric_args={"column_name.name": "windspeed"},
194                          field_path=metrics.ColumnDriftMetric.fields.drift_score,
195                          legend="windspeed",
196                      ),
197                  ],
198                  plot_type=PlotType.LINE,
199                  size=WidgetSize.FULL,
200              )
201          )
202          project.save()
203      return project
204  
205  
# Registration of the "Bikes v2" demo. Data comes from the legacy bikes demo's
# `create_data`; snapshots and the project come from this module's factories.
# NOTE(review): `count` presumably controls how many snapshots are generated (one per
# day offset passed to `create_snapshot`). Confirm against DemoProject.
bikes_v2_demo_project = DemoProject(
    name="Demo project - Bikes v2",
    count=28,
    create_data=create_data,
    create_project=create_project,
    create_snapshot=create_snapshot,
    # No report or test-suite factories are provided; this demo supplies snapshots only.
    create_report=None,
    create_test_suite=None,
)