"""datasets.py — Litestar API routes for dataset management.

Endpoints: upload, update, delete, paginated read, download (parquet/csv,
optionally SDK multipart), listing, metadata, data definition, materialization
from a data source, and tracing-dataset creation.
"""

import datetime
import json
from io import BytesIO
from json import JSONDecodeError
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

import pandas as pd
from litestar import Router
from litestar import delete
from litestar import get
from litestar import patch
from litestar import post
from litestar.datastructures import UploadFile
from litestar.enums import RequestEncodingType
from litestar.exceptions import HTTPException
from litestar.exceptions import ValidationException
from litestar.params import Body
from litestar.params import Dependency
from litestar.params import Parameter
from litestar.response.base import ASGIResponse
from typing_extensions import Annotated

from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import Extra
from evidently._pydantic_compat import ValidationError
from evidently._pydantic_compat import parse_obj_as
from evidently.core.datasets import DataDefinition
from evidently.core.datasets import Dataset
from evidently.legacy.suite.base_suite import MetadataValueType
from evidently.legacy.ui.api.models import EvidentlyAPIModel
from evidently.legacy.ui.type_aliases import ProjectID
from evidently.legacy.ui.type_aliases import UserID
from evidently.ui.service.datasets.data_source import DataSourceDTO
from evidently.ui.service.datasets.data_source import SortBy
from evidently.ui.service.datasets.filters import FilterBy
from evidently.ui.service.datasets.metadata import DatasetMetadata
from evidently.ui.service.datasets.metadata import DatasetMetadataFull
from evidently.ui.service.datasets.metadata import DatasetOrigin
from evidently.ui.service.datasets.models import DatasetPagination
from evidently.ui.service.datasets.snapshot_links import SnapshotDatasetLinksManager
from evidently.ui.service.managers.datasets import DatasetManager
from evidently.ui.service.type_aliases import DatasetID
from evidently.ui.service.type_aliases import SnapshotID


class UploadDatasetRequest(BaseModel):
    """Multipart request for uploading a dataset.

    Complex fields (metadata, tags, data definition) arrive as JSON-encoded
    strings in the form and are parsed lazily via the properties below.
    """

    class Config:
        extra = Extra.forbid
        # UploadFile is not a pydantic-native type.
        arbitrary_types_allowed = True

    name: str
    file: UploadFile
    description: Optional[str] = None
    data_definition_str: Optional[str] = None
    metadata_str: str = ""
    tags_str: str = ""

    @property
    def metadata(self) -> Dict[str, MetadataValueType]:
        """Parse metadata from its JSON string form; empty string means no metadata."""
        if not self.metadata_str:
            return {}
        return parse_obj_as(Dict[str, MetadataValueType], json.loads(self.metadata_str))

    @property
    def tags(self) -> List[str]:
        """Parse tags from their JSON string form; empty string means no tags."""
        if not self.tags_str:
            return []
        return parse_obj_as(List[str], json.loads(self.tags_str))

    @property
    def data_definition(self) -> Optional[DataDefinition]:
        """Parse the data definition from its JSON string form, if provided."""
        if self.data_definition_str:
            return parse_obj_as(DataDefinition, json.loads(self.data_definition_str))
        return None


class UploadDatasetResponse(EvidentlyAPIModel):
    """Response for dataset upload."""

    dataset: DatasetMetadata


class PatchDatasetRequest(EvidentlyAPIModel):
    """Request for updating a dataset; None fields are left unchanged by the manager."""

    name: Optional[str] = None
    description: Optional[str] = None
    data_definition: Optional[DataDefinition] = None
    metadata: Optional[Dict[str, MetadataValueType]] = None
    tags: Optional[List[str]] = None
    human_feedback_custom_shortcut_labels: Optional[List[str]] = None


@post("/upload")
async def upload_dataset(
    data: Annotated[UploadDatasetRequest, Body(media_type=RequestEncodingType.MULTI_PART)],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    snapshot_dataset_links: Annotated["SnapshotDatasetLinksManager", Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
    snapshot_id: Annotated[Optional[SnapshotID], Parameter(title="snapshot id")] = None,
    dataset_type: Annotated[Optional[str], Parameter(title="dataset type (input/output)")] = None,
    dataset_subtype: Annotated[Optional[str], Parameter(title="dataset subtype (current/reference)")] = None,
) -> UploadDatasetResponse:
    """Upload a dataset file and optionally link it to a snapshot.

    Raises:
        HTTPException: 400 when snapshot_id is given without both dataset_type
            and dataset_subtype.
    """
    if snapshot_id is not None:
        if dataset_type is None or dataset_subtype is None:
            raise HTTPException(
                status_code=400, detail="snapshot_id, dataset_type and dataset_subtype must be specified together"
            )

    # Fall back to an empty-frame data definition when the client did not supply one.
    default_dd = (
        Dataset.from_pandas(pd.DataFrame()).data_definition if data.data_definition is None else data.data_definition
    )
    dataset = await dataset_manager.upload_dataset(
        user_id,
        project_id,
        data.name,
        data.description,
        data.file,
        default_dd,
        origin=DatasetOrigin.file,
        metadata=data.metadata,
        tags=data.tags,
    )

    if (
        snapshot_id is not None
        and dataset_type is not None
        and dataset_subtype is not None
        and snapshot_dataset_links is not None
    ):
        await snapshot_dataset_links.link_dataset_snapshot(
            project_id=project_id,
            snapshot_id=snapshot_id,
            dataset_id=dataset.id,
            dataset_type=dataset_type,
            dataset_subtype=dataset_subtype,
        )

    return UploadDatasetResponse(dataset=dataset)


@patch("/{dataset_id:uuid}")
async def update_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    data: Annotated[PatchDatasetRequest, Body()],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> None:
    """Update a dataset's mutable attributes."""
    await dataset_manager.update_dataset(
        user_id,
        dataset_id,
        data.name,
        data.description,
        data.data_definition,
        data.metadata,
        data.tags,
        data.human_feedback_custom_shortcut_labels,
    )


@delete("/{dataset_id:uuid}")
async def delete_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> None:
    """Delete a dataset."""
    await dataset_manager.delete_dataset(user_id, dataset_id)


@get("/{dataset_id:uuid}")
async def get_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    page_size: Annotated[int, Parameter(gt=0)] = 10,
    current_page: Annotated[int, Parameter(gt=0)] = 1,
    sort_by_column: Optional[str] = None,
    sort_ascending: Optional[bool] = True,
    filters: Optional[List[str]] = None,
) -> DatasetPagination:
    """Get a page of dataset rows, optionally sorted and filtered.

    Each entry of `filters` is a JSON-encoded FilterBy object.

    Raises:
        ValidationException: when a filter string is not valid JSON or not a
            valid FilterBy.
    """
    try:
        filter_queries = [parse_obj_as(FilterBy, json.loads(_filter)) for _filter in filters] if filters else None
    except (ValidationError, JSONDecodeError) as e:
        raise ValidationException(str(e)) from e

    sort_by = SortBy(column=sort_by_column, ascending=sort_ascending) if sort_by_column else None
    result = await dataset_manager.get_dataset_pagination(
        user_id, dataset_id, page_size, current_page, sort_by, filter_queries
    )
    return result


@get("/{dataset_id:uuid}/download")
async def download_dataset(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    format: str = "parquet+sdk",
) -> ASGIResponse:
    """Download a dataset as parquet or csv, optionally as an SDK multipart payload.

    `format` is a "+"-separated set of flags: "parquet" or "csv" selects the
    file encoding; "sdk" additionally wraps dataset metadata (JSON) and the
    encoded file bytes into a multipart/mixed response.

    Raises:
        HTTPException: 400 when neither "parquet" nor "csv" is requested.
    """
    dataset = await dataset_manager.get_dataset_metadata(user_id, dataset_id)
    df, _ = await dataset_manager.get_dataset(user_id, dataset_id)

    # Split once; the original recomputed format.split("+") for every membership test.
    format_parts = format.split("+")

    buf = BytesIO()
    file_format: str
    if "parquet" in format_parts:
        df.to_parquet(buf)
        file_format = "parquet"
    elif "csv" in format_parts:
        df.to_csv(buf, index=False)
        file_format = "csv"
    else:
        raise HTTPException(status_code=400, detail="unsupported file format")

    metadata: str = dataset.json(exclude_none=True, exclude_defaults=True)

    boundary = "----LitestarBoundary123"

    def generate():
        # Part 1: dataset metadata as JSON; part 2: the encoded file bytes.
        yield f"--{boundary}\r\n".encode()
        yield b"Content-Type: application/json\r\n\r\n"
        yield metadata.encode() + b"\r\n"

        yield f"--{boundary}\r\n".encode()
        yield b"Content-Type: application/octet-stream\r\n\r\n"
        yield buf.getvalue() + b"\r\n"

        yield f"--{boundary}--\r\n".encode()

    if "sdk" in format_parts:
        return ASGIResponse(
            body=b"".join(generate()),
            media_type=f"multipart/mixed; boundary={boundary}",
        )
    return ASGIResponse(
        body=buf.getvalue(),
        media_type="application/octet-stream",
        headers=[("Content-Disposition", f"attachment; filename={dataset_id}.{file_format}")],
    )


class DatasetMetadataResponse(EvidentlyAPIModel):
    """Response for dataset metadata."""

    id: DatasetID
    project_id: ProjectID
    name: str
    description: str
    size_bytes: int
    created_at: datetime.datetime
    author_name: str
    row_count: int
    column_count: int
    origin: DatasetOrigin
    metadata: Dict[str, MetadataValueType]
    tags: List[str]

    @classmethod
    def from_dataset_metadata(cls, dataset: DatasetMetadataFull):
        """Create from DatasetMetadataFull, keeping only this model's fields."""
        return cls(**{k: v for k, v in dataset.__dict__.items() if k in cls.__fields__})


class ListDatasetResponse(EvidentlyAPIModel):
    """Response for listing datasets."""

    datasets: List[DatasetMetadataResponse]


@get("/")
async def list_datasets(
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
    limit: Annotated[Optional[int], Parameter(title="Page size")] = None,
    origin: Annotated[Optional[List[DatasetOrigin]], Parameter(schema_extra={"type": "array"})] = None,
    draft: Annotated[Optional[bool], Parameter(title="Return draft datasets")] = False,
) -> ListDatasetResponse:
    """List datasets of a project, optionally filtered by origin / draft status."""
    datasets = await dataset_manager.list_datasets(user_id, project_id, limit, origin, draft)
    return ListDatasetResponse(datasets=[DatasetMetadataResponse.from_dataset_metadata(d) for d in datasets])


@get("/{dataset_id:uuid}/metadata")
async def get_dataset_metadata(
    dataset_id: Annotated[DatasetID, Parameter(title="dataset id")],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> DatasetMetadata:
    """Get dataset metadata."""
    dataset = await dataset_manager.get_dataset_metadata(user_id, dataset_id)
    return dataset


class DatasetDataDefinitionResponse(EvidentlyAPIModel):
    """Response for data definition."""

    data_definition: DataDefinition
    all_columns: List[str]


@get("/{dataset_id:uuid}/data_definition")
async def get_data_definition(
    dataset_id: Annotated[DatasetID, Parameter()],
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
) -> DatasetDataDefinitionResponse:
    """Get the data definition and column list for a dataset."""
    dataset = await dataset_manager.get_dataset_metadata(user_id, dataset_id)
    return DatasetDataDefinitionResponse(data_definition=dataset.data_definition, all_columns=dataset.all_columns)


class MaterializeDatasetRequest(BaseModel):
    """Request for materializing a dataset from a source."""

    name: str
    description: Optional[str] = None
    source: DataSourceDTO
    metadata: Dict[str, MetadataValueType] = {}
    tags: List[str] = []


class MaterializeDatasetResponse(EvidentlyAPIModel):
    """Response for materializing a dataset."""

    dataset_id: DatasetID


class AddTracingDatasetRequest(BaseModel):
    """Request for creating a tracing dataset."""

    name: str

    class Config:
        extra = "forbid"


class AddTracingDatasetResponse(EvidentlyAPIModel):
    """Response for creating a tracing dataset."""

    dataset_id: DatasetID
    external_dataset_id: str = ""


@post("/tracing")
async def add_tracing_dataset(
    data: AddTracingDatasetRequest,
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
) -> AddTracingDatasetResponse:
    """Create a tracing dataset."""
    dataset = await dataset_manager.create_tracing_dataset(
        user_id=user_id,
        project_id=project_id,
        name=data.name,
    )
    return AddTracingDatasetResponse(dataset_id=dataset.id, external_dataset_id="")


@post("/materialize")
async def materialize_from_source(
    data: MaterializeDatasetRequest,
    dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
    user_id: UserID,
    project_id: ProjectID,
) -> MaterializeDatasetResponse:
    """Materialize a dataset from a data source and store it as a new dataset."""
    df = await data.source.to_data_source(user_id=user_id, project_id=project_id).materialize(dataset_manager)
    dataset = await dataset_manager.upload_dataset(
        user_id=user_id,
        project_id=project_id,
        name=data.name,
        description=data.description,
        data=df,
        data_definition=Dataset.from_pandas(df).data_definition,
        origin=DatasetOrigin.dataset,
        metadata=data.metadata,
        tags=data.tags,
    )
    return MaterializeDatasetResponse(dataset_id=dataset.id)


def datasets_api(guard: Callable) -> Router:
    """Create the datasets API router; write endpoints are protected by `guard`."""
    return Router(
        "/datasets",
        route_handlers=[
            # read
            Router(
                "",
                route_handlers=[
                    get_dataset,
                    list_datasets,
                    download_dataset,
                    get_dataset_metadata,
                    get_data_definition,
                ],
            ),
            # write
            Router(
                "",
                route_handlers=[
                    upload_dataset,
                    update_dataset,
                    delete_dataset,
                    materialize_from_source,
                    add_tracing_dataset,
                ],
                guards=[guard],
            ),
        ],
    )