evaluation_dataset.py
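"""Entities for GenAI evaluation datasets.

Defines :class:`EvaluationDataset`, a lazily loaded container of evaluation
inputs and expectations, plus granularity classification (trace vs. session)
used to validate merged records.
"""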
from __future__ import annotations

import json
from enum import Enum
from typing import TYPE_CHECKING, Any

from mlflow.data import Dataset
from mlflow.data.evaluation_dataset_source import EvaluationDatasetSource
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset_record import DatasetRecord
from mlflow.entities.dataset_record_source import DatasetRecordSourceType
from mlflow.exceptions import MlflowException
from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset
from mlflow.telemetry.events import DatasetToDataFrameEvent, MergeRecordsEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracking.context import registry as context_registry
from mlflow.utils.mlflow_tags import MLFLOW_USER

if TYPE_CHECKING:
    import pandas as pd

    from mlflow.entities.trace import Trace


SESSION_IDENTIFIER_FIELDS = frozenset({"goal"})
SESSION_INPUT_FIELDS = frozenset({"persona", "goal", "context", "simulation_guidelines"})
SESSION_ALLOWED_COLUMNS = SESSION_INPUT_FIELDS | {"expectations", "tags", "source"}


class DatasetGranularity(Enum):
    TRACE = "trace"
    SESSION = "session"
    UNKNOWN = "unknown"


class EvaluationDataset(_MlflowObject, Dataset, PyFuncConvertibleDatasetMixin):
    """
    Evaluation dataset for storing inputs and expectations for GenAI evaluation.

    This class supports lazy loading of records - when retrieved via get_evaluation_dataset(),
    only metadata is loaded. Records are fetched on first access, e.g. when to_df() is called
    or the records property is read.
    """

    def __init__(
        self,
        dataset_id: str,
        name: str,
        digest: str,
        created_time: int,
        last_update_time: int,
        tags: dict[str, Any] | None = None,
        schema: str | None = None,
        profile: str | None = None,
        created_by: str | None = None,
        last_updated_by: str | None = None,
    ):
        """Initialize the EvaluationDataset."""
        self.dataset_id = dataset_id
        self.created_time = created_time
        self.last_update_time = last_update_time
        self.tags = tags
        self._schema = schema
        self._profile = profile
        self.created_by = created_by
        self.last_updated_by = last_updated_by
        self._experiment_ids = None
        self._records = None

        source = EvaluationDatasetSource(dataset_id=self.dataset_id)
        Dataset.__init__(self, source=source, name=name, digest=digest)

    def _compute_digest(self) -> str:
        """
        Compute digest for the dataset. This is called by Dataset.__init__ if no digest is
        provided. Since a digest is always supplied to this class's constructor, this should
        not be called in practice.
        """
        return self.digest

    @property
    def source(self) -> EvaluationDatasetSource:
        """Override source property to return the correct type."""
        return self._source

    @property
    def schema(self) -> str | None:
        """
        Dataset schema information.
        """
        return self._schema

    @property
    def profile(self) -> str | None:
        """
        Dataset profile information.
        """
        return self._profile

    @property
    def experiment_ids(self) -> list[str]:
        """
        Get associated experiment IDs, loading them if necessary.

        This property implements lazy loading - experiment IDs are only fetched from the
        backend when accessed for the first time.
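
        Example (illustrative sketch; assumes ``dataset`` was retrieved via
        get_evaluation_dataset(), so only metadata is loaded up front):

        .. code-block:: python

            ids = dataset.experiment_ids  # first access fetches IDs from the backend
            ids = dataset.experiment_ids  # later accesses reuse the cached list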
        """
        if self._experiment_ids is None:
            self._load_experiment_ids()
        return self._experiment_ids or []

    @experiment_ids.setter
    def experiment_ids(self, value: list[str]):
        """Set experiment IDs directly."""
        self._experiment_ids = value or []

    def _load_experiment_ids(self):
        """Load experiment IDs from the backend."""
        from mlflow.tracking._tracking_service.utils import _get_store

        tracking_store = _get_store()
        self._experiment_ids = tracking_store.get_dataset_experiment_ids(self.dataset_id)

    @property
    def records(self) -> list[DatasetRecord]:
        """
        Get dataset records, loading them if necessary.

        This property implements lazy loading - records are only fetched from the backend
        when accessed for the first time.
        """
        if self._records is None:
            from mlflow.tracking._tracking_service.utils import _get_store

            tracking_store = _get_store()
            # For lazy loading, we want all records (no pagination)
            self._records, _ = tracking_store._load_dataset_records(
                self.dataset_id, max_results=None
            )
        return self._records or []

    def has_records(self) -> bool:
        """Check if dataset records are loaded without triggering a load."""
        return self._records is not None

    def _process_trace_records(self, traces: list["Trace"]) -> list[dict[str, Any]]:
        """Convert a list of Trace objects to dataset record dictionaries.

        Args:
            traces: List of Trace objects to convert

        Returns:
            List of dictionaries with 'inputs', 'expectations', and 'source' fields
        """
        from mlflow.entities.trace import Trace

        record_dicts = []
        for i, trace in enumerate(traces):
            if not isinstance(trace, Trace):
                raise MlflowException.invalid_parameter_value(
                    f"Mixed types in trace list. Expected all elements to be Trace objects, "
                    f"but element at index {i} is {type(trace).__name__}"
                )

            root_span = trace.data._get_root_span()
            inputs = root_span.inputs if root_span and root_span.inputs is not None else {}
            outputs = root_span.outputs if root_span and root_span.outputs is not None else None

            expectations = {}
            expectation_assessments = trace.search_assessments(type="expectation")
            for expectation in expectation_assessments:
                expectations[expectation.name] = expectation.value

            # Preserve session metadata from the original trace
            source_data = {"trace_id": trace.info.trace_id}
            if session_id := trace.info.trace_metadata.get(TraceMetadataKey.TRACE_SESSION):
                source_data["session_id"] = session_id

            record_dict = {
                "inputs": inputs,
                "outputs": outputs,
                "expectations": expectations,
                "source": {
                    "source_type": DatasetRecordSourceType.TRACE.value,
                    "source_data": source_data,
                },
            }
            record_dicts.append(record_dict)

        return record_dicts

    def _process_dataframe_records(self, df: "pd.DataFrame") -> list[dict[str, Any]]:
        """Process a DataFrame into dataset record dictionaries.

        Args:
            df: DataFrame to process. Can be either:

                - DataFrame from search_traces with 'trace' column containing Trace objects/JSON
                - Standard DataFrame with 'inputs', 'expectations' columns

        Returns:
            List of dictionaries with 'inputs', 'expectations', and optionally 'source' fields
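
        Example of the pass-through case (illustrative; a DataFrame without a
        'trace' column is converted with DataFrame.to_dict("records")):

        .. code-block:: python

            df = pd.DataFrame([{"inputs": {"q": "What?"}, "expectations": {"a": "42"}}])
            self._process_dataframe_records(df)
            # -> [{"inputs": {"q": "What?"}, "expectations": {"a": "42"}}]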
        """
        if "trace" in df.columns:
            from mlflow.entities.trace import Trace

            traces = [
                Trace.from_json(trace_item) if isinstance(trace_item, str) else trace_item
                for trace_item in df["trace"]
            ]

            return self._process_trace_records(traces)
        else:
            return df.to_dict("records")

    @record_usage_event(MergeRecordsEvent)
    def merge_records(
        self, records: list[dict[str, Any]] | "pd.DataFrame" | list["Trace"]
    ) -> "EvaluationDataset":
        """
        Merge new records with existing ones.

        Args:
            records: Records to merge. Can be:

                - List of dictionaries with 'inputs' and optionally 'expectations' and 'tags'
                - Session format with 'persona', 'goal', 'context' nested inside 'inputs'
                - DataFrame from mlflow.search_traces() - automatically parsed and converted
                - DataFrame with 'inputs' column and optionally 'expectations' and 'tags' columns
                - List of Trace objects

        Returns:
            Self for method chaining

        Example:
            .. code-block:: python

                # Direct usage with search_traces DataFrame output
                traces_df = mlflow.search_traces()  # Returns DataFrame by default
                dataset.merge_records(traces_df)  # No extraction needed

                # Or with standard DataFrame
                df = pd.DataFrame([{"inputs": {"q": "What?"}, "expectations": {"a": "Answer"}}])
                dataset.merge_records(df)

                # Session format in inputs
                test_cases = [
                    {
                        "inputs": {
                            "persona": "Student",
                            "goal": "Find articles",
                            "context": {"student_id": "U1"},
                        }
                    },
                ]
                dataset.merge_records(test_cases)
        """
        import pandas as pd

        from mlflow.entities.trace import Trace
        from mlflow.tracking._tracking_service.utils import _get_store, get_tracking_uri

        if isinstance(records, pd.DataFrame):
            record_dicts = self._process_dataframe_records(records)
        elif isinstance(records, list) and records and isinstance(records[0], Trace):
            record_dicts = self._process_trace_records(records)
        else:
            record_dicts = records

        self._validate_record_dicts(record_dicts)

        self._infer_source_types(record_dicts)

        tracking_store = _get_store()

        try:
            existing_dataset = tracking_store.get_dataset(self.dataset_id)
            self._schema = existing_dataset.schema
        except Exception as e:
            raise MlflowException.invalid_parameter_value(
                f"Cannot add records to dataset {self.dataset_id}: Dataset not found. "
                f"Please verify the dataset exists and check your tracking URI is set correctly "
                f"(currently set to: {get_tracking_uri()})."
            ) from e

        self._validate_schema(record_dicts)

        context_tags = context_registry.resolve_tags()
        if user_tag := context_tags.get(MLFLOW_USER):
            for record in record_dicts:
                if "tags" not in record:
                    record["tags"] = {}
                if MLFLOW_USER not in record["tags"]:
                    record["tags"][MLFLOW_USER] = user_tag

        tracking_store.upsert_dataset_records(dataset_id=self.dataset_id, records=record_dicts)
        self._records = None

        return self

    def _validate_record_dicts(self, record_dicts: list[dict[str, Any]]) -> None:
        """Validate that record dictionaries have the required structure.

        Args:
            record_dicts: List of record dictionaries to validate

        Raises:
            MlflowException: If records don't have the required structure
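
        Example of a minimal valid record (illustrative; 'inputs' is the only
        required field):

        .. code-block:: python

            {"inputs": {"question": "What is MLflow?"}}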
        """
        for record in record_dicts:
            if not isinstance(record, dict):
                raise MlflowException.invalid_parameter_value("Each record must be a dictionary")
            if "inputs" not in record:
                raise MlflowException.invalid_parameter_value(
                    "Each record must have an 'inputs' field"
                )

    def _infer_source_types(self, record_dicts: list[dict[str, Any]]) -> None:
        """Infer source types for records without explicit source information.

        Simple inference rules:

        - Records with expectations -> HUMAN (manual test cases/ground truth)
        - Records with inputs but no expectations -> CODE (programmatically generated)

        Inference can be overridden by providing explicit source information.

        Note that trace inputs (from List[Trace] or pd.DataFrame of Trace data) will
        always be inferred as a trace source type when processing trace records.

        Args:
            record_dicts: List of record dictionaries to process (modified in place)
        """
        for record in record_dicts:
            if "source" in record:
                continue

            if "expectations" in record and record["expectations"]:
                record["source"] = {
                    "source_type": DatasetRecordSourceType.HUMAN.value,
                    "source_data": {},
                }
            elif "inputs" in record and "expectations" not in record:
                record["source"] = {
                    "source_type": DatasetRecordSourceType.CODE.value,
                    "source_data": {},
                }

    def _validate_schema(self, record_dicts: list[dict[str, Any]]) -> None:
        """
        Validate schema consistency of new records and compatibility with existing dataset.

        Args:
            record_dicts: List of normalized record dictionaries

        Raises:
            MlflowException: If records have invalid schema, inconsistent schemas within
                batch, or are incompatible with existing dataset schema
        """
        granularity_counts: dict[DatasetGranularity, int] = {}
        has_empty_inputs = False

        for record in record_dicts:
            input_keys = set(record.get("inputs", {}).keys())
            if not input_keys:
                has_empty_inputs = True
                continue

            record_type = self._classify_input_fields(input_keys)

            if record_type == DatasetGranularity.UNKNOWN:
                session_fields = input_keys & SESSION_IDENTIFIER_FIELDS
                other_fields = input_keys - SESSION_INPUT_FIELDS
                raise MlflowException.invalid_parameter_value(
                    f"Invalid input schema: cannot mix session fields {list(session_fields)} "
                    f"with other fields {list(other_fields)}. "
                    f"Consider placing {list(other_fields)} fields inside 'context'."
                )

            granularity_counts[record_type] = granularity_counts.get(record_type, 0) + 1

        if len(granularity_counts) > 1:
            counts_str = ", ".join(
                f"{count} records with {granularity.value} granularity"
                for granularity, count in granularity_counts.items()
            )
            raise MlflowException.invalid_parameter_value(
                f"All records must use the same granularity. Found {counts_str}."
            )

        batch_granularity = next(iter(granularity_counts), DatasetGranularity.UNKNOWN)
        existing_granularity = self._get_existing_granularity()

        if has_empty_inputs and DatasetGranularity.SESSION in {
            batch_granularity,
            existing_granularity,
        }:
            raise MlflowException.invalid_parameter_value(
                "Empty inputs are not allowed for session records. The 'goal' field is required."
            )
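
        # UNKNOWN on either side means the granularity cannot be determined
        # (e.g. an empty dataset or an unparseable stored schema), so no
        # cross-granularity compatibility check is enforced.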
        if DatasetGranularity.UNKNOWN in {batch_granularity, existing_granularity}:
            return

        if batch_granularity != existing_granularity:
            raise MlflowException.invalid_parameter_value(
                f"New records use {batch_granularity.value} granularity, but existing "
                f"dataset uses {existing_granularity.value}. Cannot mix granularities."
            )

    def _get_existing_granularity(self) -> DatasetGranularity:
        """
        Get granularity from the dataset's stored schema.

        Returns:
            DatasetGranularity based on existing records, or UNKNOWN if empty/unparseable
        """
        if self._schema is None:
            if self.has_records():
                return self._classify_input_fields(set(self.records[0].inputs.keys()))
            return DatasetGranularity.UNKNOWN
        try:
            schema = json.loads(self._schema)
            input_keys = set(schema.get("inputs", {}).keys())
            return self._classify_input_fields(input_keys)
        except (json.JSONDecodeError, TypeError):
            return DatasetGranularity.UNKNOWN

    @staticmethod
    def _classify_input_fields(input_keys: set[str]) -> DatasetGranularity:
        """
        Classify a set of input field names into a granularity type:

        - SESSION: Has 'goal' field, and only session fields (persona, goal, context,
          simulation_guidelines)
        - TRACE: No 'goal' field present
        - UNKNOWN: Empty or has 'goal' mixed with non-session fields

        Args:
            input_keys: Set of field names from a record's inputs

        Returns:
            DatasetGranularity classification for the input fields
        """
        if not input_keys:
            return DatasetGranularity.UNKNOWN

        has_session_identifier = bool(input_keys & SESSION_IDENTIFIER_FIELDS)

        if not has_session_identifier:
            return DatasetGranularity.TRACE

        if input_keys <= SESSION_INPUT_FIELDS:
            return DatasetGranularity.SESSION

        return DatasetGranularity.UNKNOWN

    def delete_records(self, record_ids: list[str]) -> int:
        """
        Delete specific records from the dataset.

        Args:
            record_ids: List of record IDs to delete.

        Returns:
            The number of records deleted.

        Example:
            .. code-block:: python

                # Get record IDs to delete
                df = dataset.to_df()
                record_ids_to_delete = df["dataset_record_id"].tolist()[:2]

                # Delete the records
                deleted_count = dataset.delete_records(record_ids_to_delete)
                print(f"Deleted {deleted_count} records")
        """
        from mlflow.tracking._tracking_service.utils import _get_store

        tracking_store = _get_store()
        deleted_count = tracking_store.delete_dataset_records(
            dataset_id=self.dataset_id,
            dataset_record_ids=record_ids,
        )
        self._records = None  # Clear cached records
        return deleted_count

    @record_usage_event(DatasetToDataFrameEvent)
    def to_df(self) -> "pd.DataFrame":
        """
        Convert dataset records to a pandas DataFrame.

        This method triggers lazy loading of records if they haven't been loaded yet.

        Returns:
            DataFrame with columns for inputs, outputs, expectations, tags, and metadata
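
        Example (illustrative; ``dataset`` is an existing EvaluationDataset):

        .. code-block:: python

            df = dataset.to_df()  # loads records on first call
            df[["inputs", "expectations", "dataset_record_id"]].head()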
        """
        import pandas as pd

        records = self.records

        if not records:
            return pd.DataFrame(
                columns=[
                    "inputs",
                    "outputs",
                    "expectations",
                    "tags",
                    "source_type",
                    "source_id",
                    "source",
                    "created_time",
                    "dataset_record_id",
                ]
            )

        data = [
            {
                "inputs": record.inputs,
                "outputs": record.outputs,
                "expectations": record.expectations,
                "tags": record.tags,
                "source_type": record.source_type,
                "source_id": record.source_id,
                "source": record.source,
                "created_time": record.created_time,
                "dataset_record_id": record.dataset_record_id,
            }
            for record in records
        ]

        return pd.DataFrame(data)

    def to_proto(self) -> ProtoDataset:
        """Convert to protobuf representation."""
        proto = ProtoDataset()

        proto.dataset_id = self.dataset_id
        proto.name = self.name
        if self.tags is not None:
            proto.tags = json.dumps(self.tags)
        if self.schema is not None:
            proto.schema = self.schema
        if self.profile is not None:
            proto.profile = self.profile
        proto.digest = self.digest
        proto.created_time = self.created_time
        proto.last_update_time = self.last_update_time
        if self.created_by is not None:
            proto.created_by = self.created_by
        if self.last_updated_by is not None:
            proto.last_updated_by = self.last_updated_by
        if self._experiment_ids is not None:
            proto.experiment_ids.extend(self._experiment_ids)

        return proto

    @classmethod
    def from_proto(cls, proto: ProtoDataset) -> "EvaluationDataset":
        """Create instance from protobuf representation."""
        tags = None
        if proto.HasField("tags"):
            tags = json.loads(proto.tags)

        dataset = cls(
            dataset_id=proto.dataset_id,
            name=proto.name,
            digest=proto.digest,
            created_time=proto.created_time,
            last_update_time=proto.last_update_time,
            tags=tags,
            schema=proto.schema if proto.HasField("schema") else None,
            profile=proto.profile if proto.HasField("profile") else None,
            created_by=proto.created_by if proto.HasField("created_by") else None,
            last_updated_by=proto.last_updated_by if proto.HasField("last_updated_by") else None,
        )
        if proto.experiment_ids:
            dataset._experiment_ids = list(proto.experiment_ids)
        return dataset

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        result = super().to_dict()

        result.update(
            {
                "dataset_id": self.dataset_id,
                "tags": self.tags,
                "schema": self.schema,
                "profile": self.profile,
                "created_time": self.created_time,
                "last_update_time": self.last_update_time,
                "created_by": self.created_by,
                "last_updated_by": self.last_updated_by,
                "experiment_ids": self.experiment_ids,
            }
        )

        result["records"] = [record.to_dict() for record in self.records]

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "EvaluationDataset":
        """Create instance from dictionary representation."""
        for field in ("dataset_id", "name", "digest", "created_time", "last_update_time"):
            if field not in data:
                raise ValueError(f"{field} is required")
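
        # Optional fields are read with dict.get() and default to None when
        # absent from the input dictionary.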
        dataset = cls(
            dataset_id=data["dataset_id"],
            name=data["name"],
            digest=data["digest"],
            created_time=data["created_time"],
            last_update_time=data["last_update_time"],
            tags=data.get("tags"),
            schema=data.get("schema"),
            profile=data.get("profile"),
            created_by=data.get("created_by"),
            last_updated_by=data.get("last_updated_by"),
        )
        if "experiment_ids" in data:
            dataset._experiment_ids = data["experiment_ids"]

        if "records" in data:
            dataset._records = [
                DatasetRecord.from_dict(record_data) for record_data in data["records"]
            ]

        return dataset