/ src / evidently / ui / service / api / datasets.py
datasets.py
  1  import datetime
  2  import json
  3  from io import BytesIO
  4  from json import JSONDecodeError
  5  from typing import Callable
  6  from typing import Dict
  7  from typing import List
  8  from typing import Optional
  9  
 10  import pandas as pd
 11  from litestar import Router
 12  from litestar import delete
 13  from litestar import get
 14  from litestar import patch
 15  from litestar import post
 16  from litestar.datastructures import UploadFile
 17  from litestar.enums import RequestEncodingType
 18  from litestar.exceptions import HTTPException
 19  from litestar.exceptions import ValidationException
 20  from litestar.params import Body
 21  from litestar.params import Dependency
 22  from litestar.params import Parameter
 23  from litestar.response.base import ASGIResponse
 24  from typing_extensions import Annotated
 25  
 26  from evidently._pydantic_compat import BaseModel
 27  from evidently._pydantic_compat import Extra
 28  from evidently._pydantic_compat import ValidationError
 29  from evidently._pydantic_compat import parse_obj_as
 30  from evidently.core.datasets import DataDefinition
 31  from evidently.core.datasets import Dataset
 32  from evidently.legacy.suite.base_suite import MetadataValueType
 33  from evidently.legacy.ui.api.models import EvidentlyAPIModel
 34  from evidently.legacy.ui.type_aliases import ProjectID
 35  from evidently.legacy.ui.type_aliases import UserID
 36  from evidently.ui.service.datasets.data_source import DataSourceDTO
 37  from evidently.ui.service.datasets.data_source import SortBy
 38  from evidently.ui.service.datasets.filters import FilterBy
 39  from evidently.ui.service.datasets.metadata import DatasetMetadata
 40  from evidently.ui.service.datasets.metadata import DatasetMetadataFull
 41  from evidently.ui.service.datasets.metadata import DatasetOrigin
 42  from evidently.ui.service.datasets.models import DatasetPagination
 43  from evidently.ui.service.datasets.snapshot_links import SnapshotDatasetLinksManager
 44  from evidently.ui.service.managers.datasets import DatasetManager
 45  from evidently.ui.service.type_aliases import DatasetID
 46  from evidently.ui.service.type_aliases import SnapshotID
 47  
 48  
 49  class UploadDatasetRequest(BaseModel):
 50      """Request for uploading a dataset."""
 51  
 52      class Config:
 53          extra = Extra.forbid
 54          arbitrary_types_allowed = True
 55  
 56      name: str
 57      file: UploadFile
 58      description: Optional[str] = None
 59      data_definition_str: Optional[str] = None
 60      metadata_str: str = ""
 61      tags_str: str = ""
 62  
 63      @property
 64      def metadata(self) -> Dict[str, MetadataValueType]:
 65          """Parse metadata from string."""
 66          if not self.metadata_str:
 67              return {}
 68          return parse_obj_as(Dict[str, MetadataValueType], json.loads(self.metadata_str))
 69  
 70      @property
 71      def tags(self) -> List[str]:
 72          """Parse tags from string."""
 73          if not self.tags_str:
 74              return []
 75          return parse_obj_as(List[str], json.loads(self.tags_str))
 76  
 77      @property
 78      def data_definition(self) -> Optional[DataDefinition]:
 79          """Parse data definition from string."""
 80          if self.data_definition_str:
 81              return parse_obj_as(DataDefinition, json.loads(self.data_definition_str))
 82          return None
 83  
 84  
class UploadDatasetResponse(EvidentlyAPIModel):
    """Response body for a successful dataset upload."""

    # Metadata record of the newly created dataset.
    dataset: DatasetMetadata
 89  
 90  
class PatchDatasetRequest(EvidentlyAPIModel):
    """Partial-update payload for a dataset.

    Every field is optional; fields are forwarded as-is to
    DatasetManager.update_dataset (see update_dataset handler).
    """

    name: Optional[str] = None
    description: Optional[str] = None
    data_definition: Optional[DataDefinition] = None
    metadata: Optional[Dict[str, MetadataValueType]] = None
    tags: Optional[List[str]] = None
    human_feedback_custom_shortcut_labels: Optional[List[str]] = None
100  
101  
@post("/upload")
async def upload_dataset(
    data: Annotated[UploadDatasetRequest, Body(media_type=RequestEncodingType.MULTI_PART)],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    snapshot_dataset_links: Annotated["SnapshotDatasetLinksManager", Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
    snapshot_id: Annotated[Optional[SnapshotID], Parameter(title="snapshot id")] = None,
    dataset_type: Annotated[Optional[str], Parameter(title="dataset type (input/output)")] = None,
    dataset_subtype: Annotated[Optional[str], Parameter(title="dataset subtype (current/reference)")] = None,
) -> UploadDatasetResponse:
    """Upload a dataset file, optionally linking it to a snapshot.

    When snapshot_id is given, dataset_type and dataset_subtype must be
    given as well so the new dataset can be linked to that snapshot.
    """
    # Linking to a snapshot requires the full (type, subtype) pair.
    if snapshot_id is not None and (dataset_type is None or dataset_subtype is None):
        raise HTTPException(
            status_code=400, detail="snapshot_id, dataset_type and dataset_subtype must be specified together"
        )

    # Fall back to the data definition of an empty frame when none was sent.
    if data.data_definition is not None:
        definition = data.data_definition
    else:
        definition = Dataset.from_pandas(pd.DataFrame()).data_definition

    dataset = await dataset_manager.upload_dataset(
        user_id,
        project_id,
        data.name,
        data.description,
        data.file,
        definition,
        origin=DatasetOrigin.file,
        metadata=data.metadata,
        tags=data.tags,
    )

    link_requested = (
        snapshot_id is not None
        and dataset_type is not None
        and dataset_subtype is not None
        and snapshot_dataset_links is not None
    )
    if link_requested:
        await snapshot_dataset_links.link_dataset_snapshot(
            project_id=project_id,
            snapshot_id=snapshot_id,
            dataset_id=dataset.id,
            dataset_type=dataset_type,
            dataset_subtype=dataset_subtype,
        )

    return UploadDatasetResponse(dataset=dataset)
150  
151  
@patch("/{dataset_id:uuid}")
async def update_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    data: Annotated[PatchDatasetRequest, Body()],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> None:
    """Forward a partial dataset update to the dataset manager."""
    # Positional argument order expected by DatasetManager.update_dataset.
    changes = (
        data.name,
        data.description,
        data.data_definition,
        data.metadata,
        data.tags,
        data.human_feedback_custom_shortcut_labels,
    )
    await dataset_manager.update_dataset(user_id, dataset_id, *changes)
170  
171  
@delete("/{dataset_id:uuid}")
async def delete_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> None:
    """Delete a dataset."""
    # Access control is delegated to the manager — presumably keyed on
    # user_id; verify against DatasetManager.delete_dataset.
    await dataset_manager.delete_dataset(user_id, dataset_id)
180  
181  
@get("/{dataset_id:uuid}")
async def get_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    page_size: Annotated[int, Parameter(gt=0)] = 10,
    current_page: Annotated[int, Parameter(gt=0)] = 1,
    sort_by_column: Optional[str] = None,
    sort_ascending: Optional[bool] = True,
    filters: Optional[List[str]] = None,
) -> DatasetPagination:
    """Return one page of a dataset, optionally filtered and sorted.

    Each entry of ``filters`` is a JSON-encoded FilterBy; malformed JSON or
    a schema mismatch is reported to the client as a validation error.
    """
    filter_queries = None
    if filters:
        try:
            filter_queries = [parse_obj_as(FilterBy, json.loads(raw)) for raw in filters]
        except (ValidationError, JSONDecodeError) as e:
            raise ValidationException(str(e)) from e

    sort_by = None
    if sort_by_column:
        sort_by = SortBy(column=sort_by_column, ascending=sort_ascending)

    return await dataset_manager.get_dataset_pagination(
        user_id, dataset_id, page_size, current_page, sort_by, filter_queries
    )
205  
206  
@get("/{dataset_id:uuid}/download")
async def download_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    format: str = "parquet+sdk",
) -> ASGIResponse:
    """Download a dataset as a file, optionally wrapped for the SDK.

    ``format`` is a "+"-separated option list: a file format ("parquet" or
    "csv"; parquet wins when both are present) plus an optional "sdk" flag.
    With "sdk" the response is a hand-rolled multipart/mixed body carrying
    the dataset metadata JSON followed by the serialized file; otherwise it
    is a plain attachment download.

    Raises:
        HTTPException: 400 when neither "parquet" nor "csv" is requested.
    """
    dataset = await dataset_manager.get_dataset_metadata(user_id, dataset_id)
    df, _ = await dataset_manager.get_dataset(user_id, dataset_id)

    # Split the format string once instead of re-splitting for every check.
    # NOTE: `format` shadows the builtin, but it is the public query
    # parameter name and must stay as-is for callers.
    format_options = set(format.split("+"))

    buf = BytesIO()
    file_format: str
    if "parquet" in format_options:
        df.to_parquet(buf)
        file_format = "parquet"
    elif "csv" in format_options:
        df.to_csv(buf, index=False)
        file_format = "csv"
    else:
        raise HTTPException(status_code=400, detail="unsupported file format")

    metadata: str = dataset.json(exclude_none=True, exclude_defaults=True)

    boundary = "----LitestarBoundary123"

    def generate():
        # Multipart/mixed body: JSON metadata part, then the file payload.
        yield f"--{boundary}\r\n".encode()
        yield b"Content-Type: application/json\r\n\r\n"
        yield metadata.encode() + b"\r\n"

        yield f"--{boundary}\r\n".encode()
        yield b"Content-Type: application/octet-stream\r\n\r\n"
        yield buf.getvalue() + b"\r\n"

        yield f"--{boundary}--\r\n".encode()

    if "sdk" in format_options:
        return ASGIResponse(
            body=b"".join(generate()),
            media_type=f"multipart/mixed; boundary={boundary}",
        )
    return ASGIResponse(
        body=buf.getvalue(),
        media_type="application/octet-stream",
        headers=[("Content-Disposition", f"attachment; filename={dataset_id}.{file_format}")],
    )
254  
255  
class DatasetMetadataResponse(EvidentlyAPIModel):
    """API view of a single dataset's metadata record."""

    id: DatasetID
    project_id: ProjectID
    name: str
    description: str
    size_bytes: int
    created_at: datetime.datetime
    author_name: str
    row_count: int
    column_count: int
    origin: DatasetOrigin
    metadata: Dict[str, MetadataValueType]
    tags: List[str]

    @classmethod
    def from_dataset_metadata(cls, dataset: DatasetMetadataFull):
        """Build a response from a DatasetMetadataFull, keeping declared fields only."""
        declared = cls.__fields__
        selected = {key: value for key, value in dataset.__dict__.items() if key in declared}
        return cls(**selected)
276  
277  
class ListDatasetResponse(EvidentlyAPIModel):
    """Response body for the dataset listing endpoint."""

    # One entry per dataset visible to the caller in the project.
    datasets: List[DatasetMetadataResponse]
282  
283  
@get("/")
async def list_datasets(
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
    limit: Annotated[Optional[int], Parameter(title="Page size")] = None,
    origin: Annotated[Optional[List[DatasetOrigin]], Parameter(schema_extra={"type": "array"})] = None,
    draft: Annotated[Optional[bool], Parameter(title="Return draft datasets")] = False,
) -> ListDatasetResponse:
    """List datasets of a project, optionally filtered by origin and draft flag."""
    found = await dataset_manager.list_datasets(user_id, project_id, limit, origin, draft)
    items = [DatasetMetadataResponse.from_dataset_metadata(meta) for meta in found]
    return ListDatasetResponse(datasets=items)
296  
297  
@get("/{dataset_id:uuid}/metadata")
async def get_dataset_metadata(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> DatasetMetadata:
    """Fetch the stored metadata record for a single dataset."""
    return await dataset_manager.get_dataset_metadata(user_id, dataset_id)
307  
308  
class DatasetDataDefinitionResponse(EvidentlyAPIModel):
    """Response body for the data-definition endpoint."""

    # Column typing/role information for the dataset.
    data_definition: DataDefinition
    # Names of every column present in the dataset.
    all_columns: List[str]
314  
315  
@get("/{dataset_id:uuid}/data_definition")
async def get_data_definition(
    dataset_id: Annotated[DatasetID, Parameter()],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> DatasetDataDefinitionResponse:
    """Return a dataset's data definition together with its column names."""
    meta = await dataset_manager.get_dataset_metadata(user_id, dataset_id)
    return DatasetDataDefinitionResponse(
        data_definition=meta.data_definition,
        all_columns=meta.all_columns,
    )
325  
326  
class MaterializeDatasetRequest(BaseModel):
    """Request for materializing a dataset from a source."""

    name: str
    description: Optional[str] = None
    # Descriptor of the external source to materialize from.
    source: DataSourceDTO
    # Mutable defaults are safe here: pydantic deep-copies field defaults
    # per instance, unlike plain function defaults.
    metadata: Dict[str, MetadataValueType] = {}
    tags: List[str] = []
335  
336  
class MaterializeDatasetResponse(EvidentlyAPIModel):
    """Response body for a materialize-from-source request."""

    # Identifier of the dataset created from the source.
    dataset_id: DatasetID
341  
342  
class AddTracingDatasetRequest(BaseModel):
    """Request body for creating a tracing dataset."""

    # Display name for the new tracing dataset.
    name: str

    class Config:
        # Use the shared Extra enum (same "reject unknown fields" behavior
        # as the string literal "forbid") for consistency with
        # UploadDatasetRequest.Config.
        extra = Extra.forbid
350  
351  
class AddTracingDatasetResponse(EvidentlyAPIModel):
    """Response body for creating a tracing dataset."""

    # Identifier of the newly created tracing dataset.
    dataset_id: DatasetID
    # Identifier in an external tracing system; empty when not assigned.
    external_dataset_id: str = ""
357  
358  
@post("/tracing")
async def add_tracing_dataset(
    data: AddTracingDatasetRequest,
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
) -> AddTracingDatasetResponse:
    """Create a new tracing dataset in the project."""
    created = await dataset_manager.create_tracing_dataset(
        user_id=user_id,
        project_id=project_id,
        name=data.name,
    )
    # No external dataset id is assigned at creation time.
    return AddTracingDatasetResponse(dataset_id=created.id, external_dataset_id="")
373  
374  
@post("/materialize")
async def materialize_from_source(
    data: MaterializeDatasetRequest,
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
) -> MaterializeDatasetResponse:
    """Materialize a dataset from an external data source and store it."""
    source = data.source.to_data_source(user_id=user_id, project_id=project_id)
    df = await source.materialize(dataset_manager)
    # Derive the definition from the materialized frame itself.
    definition = Dataset.from_pandas(df).data_definition
    stored = await dataset_manager.upload_dataset(
        user_id=user_id,
        project_id=project_id,
        name=data.name,
        description=data.description,
        data=df,
        data_definition=definition,
        origin=DatasetOrigin.dataset,
        metadata=data.metadata,
        tags=data.tags,
    )
    return MaterializeDatasetResponse(dataset_id=stored.id)
396  
397  
def datasets_api(guard: Callable) -> Router:
    """Assemble the /datasets router: open read routes plus guarded writes."""
    read_router = Router(
        "",
        route_handlers=[
            get_dataset,
            list_datasets,
            download_dataset,
            get_dataset_metadata,
            get_data_definition,
        ],
    )
    write_router = Router(
        "",
        route_handlers=[
            upload_dataset,
            update_dataset,
            delete_dataset,
            materialize_from_source,
            add_tracing_dataset,
        ],
        guards=[guard],
    )
    return Router("/datasets", route_handlers=[read_router, write_router])