# mlflow/entities/evaluation_dataset.py

from __future__ import annotations

import json
from enum import Enum
from typing import TYPE_CHECKING, Any

from mlflow.data import Dataset
from mlflow.data.evaluation_dataset_source import EvaluationDatasetSource
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset_record import DatasetRecord
from mlflow.entities.dataset_record_source import DatasetRecordSourceType
from mlflow.exceptions import MlflowException
from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset
from mlflow.telemetry.events import DatasetToDataFrameEvent, MergeRecordsEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracking.context import registry as context_registry
from mlflow.utils.mlflow_tags import MLFLOW_USER

if TYPE_CHECKING:
    import pandas as pd

    from mlflow.entities.trace import Trace


# Field-name sets used to classify a record's inputs as session-level or trace-level.
SESSION_IDENTIFIER_FIELDS = frozenset({"goal"})
SESSION_INPUT_FIELDS = frozenset({"persona", "goal", "context", "simulation_guidelines"})
SESSION_ALLOWED_COLUMNS = SESSION_INPUT_FIELDS | {"expectations", "tags", "source"}


class DatasetGranularity(Enum):
    """Granularity of dataset records: one trace per record, one session per record, or unknown."""

    TRACE = "trace"
    SESSION = "session"
    UNKNOWN = "unknown"


class EvaluationDataset(_MlflowObject, Dataset, PyFuncConvertibleDatasetMixin):
    """
    Evaluation dataset for storing inputs and expectations for GenAI evaluation.

    This class supports lazy loading of records: when retrieved via get_evaluation_dataset(),
    only metadata is loaded. Records are fetched on first access, for example via the
    ``records`` property or ``to_df()``.
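
    Example (a minimal sketch; assumes a ``get_evaluation_dataset()`` helper is
    available as referenced above, with the exact import path omitted):
        .. code-block:: python

            dataset = get_evaluation_dataset(dataset_id="d-123")  # metadata only
            df = dataset.to_df()  # records are fetched here, on first access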
    """

    def __init__(
        self,
        dataset_id: str,
        name: str,
        digest: str,
        created_time: int,
        last_update_time: int,
        tags: dict[str, Any] | None = None,
        schema: str | None = None,
        profile: str | None = None,
        created_by: str | None = None,
        last_updated_by: str | None = None,
    ):
        """Initialize the EvaluationDataset."""
        self.dataset_id = dataset_id
        self.created_time = created_time
        self.last_update_time = last_update_time
        self.tags = tags
        self._schema = schema
        self._profile = profile
        self.created_by = created_by
        self.last_updated_by = last_updated_by
        self._experiment_ids = None
        self._records = None

        source = EvaluationDatasetSource(dataset_id=self.dataset_id)
        Dataset.__init__(self, source=source, name=name, digest=digest)

    def _compute_digest(self) -> str:
        """
        Compute digest for the dataset. This is called by Dataset.__init__ if no digest is
        provided. Because this class always receives a digest in its constructor, the method
        should not be invoked in practice.
        """
        return self.digest

    @property
    def source(self) -> EvaluationDatasetSource:
        """Override source property to return the correct type."""
        return self._source

    @property
    def schema(self) -> str | None:
        """
        Dataset schema information.
        """
        return self._schema

    @property
    def profile(self) -> str | None:
        """
        Dataset profile information.
        """
        return self._profile

    @property
    def experiment_ids(self) -> list[str]:
        """
        Get associated experiment IDs, loading them if necessary.

        This property implements lazy loading - experiment IDs are only fetched from the backend
        when accessed for the first time.
        """
        if self._experiment_ids is None:
            self._load_experiment_ids()
        return self._experiment_ids or []

    @experiment_ids.setter
    def experiment_ids(self, value: list[str]):
        """Set experiment IDs directly."""
        self._experiment_ids = value or []

    def _load_experiment_ids(self):
        """Load experiment IDs from the backend."""
        from mlflow.tracking._tracking_service.utils import _get_store

        tracking_store = _get_store()
        self._experiment_ids = tracking_store.get_dataset_experiment_ids(self.dataset_id)

    @property
    def records(self) -> list[DatasetRecord]:
        """
        Get dataset records, loading them if necessary.

        This property implements lazy loading - records are only fetched from the backend
        when accessed for the first time.
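
        Example (a minimal sketch of the lazy-loading behavior):
            .. code-block:: python

                assert not dataset.has_records()  # nothing fetched yet
                records = dataset.records  # first access loads from the tracking store
                assert dataset.has_records()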
        """
        if self._records is None:
            from mlflow.tracking._tracking_service.utils import _get_store

            tracking_store = _get_store()
            # For lazy loading, we want all records (no pagination)
            self._records, _ = tracking_store._load_dataset_records(
                self.dataset_id, max_results=None
            )
        return self._records or []

    def has_records(self) -> bool:
        """Check if dataset records are loaded without triggering a load."""
        return self._records is not None

    def _process_trace_records(self, traces: list["Trace"]) -> list[dict[str, Any]]:
        """Convert a list of Trace objects to dataset record dictionaries.

        Args:
            traces: List of Trace objects to convert

        Returns:
            List of dictionaries with 'inputs', 'outputs', 'expectations', and 'source' fields
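
        Example (illustrative shape of one returned record; values abbreviated):
            .. code-block:: python

                {
                    "inputs": {"question": "..."},
                    "outputs": {"answer": "..."},
                    "expectations": {"expected_answer": "..."},
                    "source": {
                        "source_type": DatasetRecordSourceType.TRACE.value,
                        "source_data": {"trace_id": "tr-123"},
                    },
                }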
        """
        from mlflow.entities.trace import Trace

        record_dicts = []
        for i, trace in enumerate(traces):
            if not isinstance(trace, Trace):
                raise MlflowException.invalid_parameter_value(
                    f"Mixed types in trace list. Expected all elements to be Trace objects, "
                    f"but element at index {i} is {type(trace).__name__}"
                )

            root_span = trace.data._get_root_span()
            inputs = root_span.inputs if root_span and root_span.inputs is not None else {}
            outputs = root_span.outputs if root_span and root_span.outputs is not None else None

            expectations = {}
            expectation_assessments = trace.search_assessments(type="expectation")
            for expectation in expectation_assessments:
                expectations[expectation.name] = expectation.value

            # Preserve session metadata from the original trace
            source_data = {"trace_id": trace.info.trace_id}
            if session_id := trace.info.trace_metadata.get(TraceMetadataKey.TRACE_SESSION):
                source_data["session_id"] = session_id

            record_dict = {
                "inputs": inputs,
                "outputs": outputs,
                "expectations": expectations,
                "source": {
                    "source_type": DatasetRecordSourceType.TRACE.value,
                    "source_data": source_data,
                },
            }
            record_dicts.append(record_dict)

        return record_dicts

    def _process_dataframe_records(self, df: "pd.DataFrame") -> list[dict[str, Any]]:
        """Process a DataFrame into dataset record dictionaries.

        Args:
            df: DataFrame to process. Can be either:
                - DataFrame from search_traces with 'trace' column containing Trace objects/JSON
                - Standard DataFrame with 'inputs', 'expectations' columns

        Returns:
            List of dictionaries with 'inputs', 'expectations', and optionally 'source' fields
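
        Example (a small sketch; the DataFrame below uses the standard layout, so
        rows pass through ``to_dict("records")`` unchanged):
            .. code-block:: python

                import pandas as pd

                df = pd.DataFrame([{"inputs": {"q": "What?"}, "expectations": {"a": "42"}}])
                dataset._process_dataframe_records(df)
                # -> [{"inputs": {"q": "What?"}, "expectations": {"a": "42"}}]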
        """
        if "trace" in df.columns:
            from mlflow.entities.trace import Trace

            traces = [
                Trace.from_json(trace_item) if isinstance(trace_item, str) else trace_item
                for trace_item in df["trace"]
            ]

            return self._process_trace_records(traces)
        else:
            return df.to_dict("records")

    @record_usage_event(MergeRecordsEvent)
    def merge_records(
        self, records: list[dict[str, Any]] | "pd.DataFrame" | list["Trace"]
    ) -> "EvaluationDataset":
        """
        Merge new records with existing ones.

        Args:
            records: Records to merge. Can be:
                - List of dictionaries with 'inputs' and optionally 'expectations' and 'tags'
                - Session format with 'persona', 'goal', 'context' nested inside 'inputs'
                - DataFrame from mlflow.search_traces() - automatically parsed and converted
                - DataFrame with 'inputs' column and optionally 'expectations' and 'tags' columns
                - List of Trace objects

        Returns:
            Self for method chaining

        Example:
            .. code-block:: python

                # Direct usage with search_traces DataFrame output
                traces_df = mlflow.search_traces()  # Returns DataFrame by default
                dataset.merge_records(traces_df)  # No extraction needed

                # Or with standard DataFrame
                df = pd.DataFrame([{"inputs": {"q": "What?"}, "expectations": {"a": "Answer"}}])
                dataset.merge_records(df)

                # Session format in inputs
                test_cases = [
                    {
                        "inputs": {
                            "persona": "Student",
                            "goal": "Find articles",
                            "context": {"student_id": "U1"},
                        }
                    },
                ]
                dataset.merge_records(test_cases)
        """
        import pandas as pd

        from mlflow.entities.trace import Trace
        from mlflow.tracking._tracking_service.utils import _get_store, get_tracking_uri

        if isinstance(records, pd.DataFrame):
            record_dicts = self._process_dataframe_records(records)
        elif isinstance(records, list) and records and isinstance(records[0], Trace):
            record_dicts = self._process_trace_records(records)
        else:
            record_dicts = records

        self._validate_record_dicts(record_dicts)

        self._infer_source_types(record_dicts)

        tracking_store = _get_store()

        try:
            existing_dataset = tracking_store.get_dataset(self.dataset_id)
            self._schema = existing_dataset.schema
        except Exception as e:
            raise MlflowException.invalid_parameter_value(
                f"Cannot add records to dataset {self.dataset_id}: Dataset not found. "
                f"Please verify the dataset exists and check your tracking URI is set correctly "
                f"(currently set to: {get_tracking_uri()})."
            ) from e

        self._validate_schema(record_dicts)

        context_tags = context_registry.resolve_tags()
        if user_tag := context_tags.get(MLFLOW_USER):
            for record in record_dicts:
                if "tags" not in record:
                    record["tags"] = {}
                if MLFLOW_USER not in record["tags"]:
                    record["tags"][MLFLOW_USER] = user_tag

        tracking_store.upsert_dataset_records(dataset_id=self.dataset_id, records=record_dicts)
        self._records = None  # Invalidate cached records so the next access reloads them

        return self

    def _validate_record_dicts(self, record_dicts: list[dict[str, Any]]) -> None:
        """Validate that record dictionaries have the required structure.

        Args:
            record_dicts: List of record dictionaries to validate

        Raises:
            MlflowException: If records don't have the required structure
        """
        for record in record_dicts:
            if not isinstance(record, dict):
                raise MlflowException.invalid_parameter_value("Each record must be a dictionary")
            if "inputs" not in record:
                raise MlflowException.invalid_parameter_value(
                    "Each record must have an 'inputs' field"
                )

    def _infer_source_types(self, record_dicts: list[dict[str, Any]]) -> None:
        """Infer source types for records without explicit source information.

        Simple inference rules:
        - Records with expectations -> HUMAN (manual test cases/ground truth)
        - Records with inputs but no expectations -> CODE (programmatically generated)

        Inference can be overridden by providing explicit source information.

        Note that records built from traces (a list of Trace objects or a DataFrame of
        trace data) are always assigned a TRACE source during trace processing.

        Args:
            record_dicts: List of record dictionaries to process (modified in place)
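
        Example (an illustrative sketch; sources are filled in only on records
        that lack one):
            .. code-block:: python

                records = [
                    {"inputs": {"q": "2+2?"}, "expectations": {"a": "4"}},  # -> HUMAN
                    {"inputs": {"q": "2+2?"}},  # -> CODE
                ]
                dataset._infer_source_types(records)
                assert records[0]["source"]["source_type"] == DatasetRecordSourceType.HUMAN.value
                assert records[1]["source"]["source_type"] == DatasetRecordSourceType.CODE.value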
        """
        for record in record_dicts:
            if "source" in record:
                continue

            if "expectations" in record and record["expectations"]:
                record["source"] = {
                    "source_type": DatasetRecordSourceType.HUMAN.value,
                    "source_data": {},
                }
            elif "inputs" in record and "expectations" not in record:
                record["source"] = {
                    "source_type": DatasetRecordSourceType.CODE.value,
                    "source_data": {},
                }

    def _validate_schema(self, record_dicts: list[dict[str, Any]]) -> None:
        """
        Validate schema consistency of new records and compatibility with existing dataset.

        Args:
            record_dicts: List of normalized record dictionaries

        Raises:
            MlflowException: If records have invalid schema, inconsistent schemas within batch,
                or are incompatible with existing dataset schema
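
        Example (illustrative; a batch that mixes session and trace granularity
        is rejected):
            .. code-block:: python

                dataset.merge_records(
                    [
                        {"inputs": {"goal": "Find articles"}},  # session granularity
                        {"inputs": {"question": "What is X?"}},  # trace granularity
                    ]
                )  # raises MlflowException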
        """
        granularity_counts: dict[DatasetGranularity, int] = {}
        has_empty_inputs = False

        for record in record_dicts:
            input_keys = set(record.get("inputs", {}).keys())
            if not input_keys:
                has_empty_inputs = True
                continue

            record_type = self._classify_input_fields(input_keys)

            if record_type == DatasetGranularity.UNKNOWN:
                session_fields = input_keys & SESSION_IDENTIFIER_FIELDS
                other_fields = input_keys - SESSION_INPUT_FIELDS
                raise MlflowException.invalid_parameter_value(
                    f"Invalid input schema: cannot mix session fields {list(session_fields)} "
                    f"with other fields {list(other_fields)}. "
                    f"Consider placing {list(other_fields)} fields inside 'context'."
                )

            granularity_counts[record_type] = granularity_counts.get(record_type, 0) + 1

        if len(granularity_counts) > 1:
            counts_str = ", ".join(
                f"{count} records with {granularity.value} granularity"
                for granularity, count in granularity_counts.items()
            )
            raise MlflowException.invalid_parameter_value(
                f"All records must use the same granularity. Found {counts_str}."
            )

        batch_granularity = next(iter(granularity_counts), DatasetGranularity.UNKNOWN)
        existing_granularity = self._get_existing_granularity()

        if has_empty_inputs and DatasetGranularity.SESSION in {
            batch_granularity,
            existing_granularity,
        }:
            raise MlflowException.invalid_parameter_value(
                "Empty inputs are not allowed for session records. The 'goal' field is required."
            )

        if DatasetGranularity.UNKNOWN in {batch_granularity, existing_granularity}:
            return

        if batch_granularity != existing_granularity:
            raise MlflowException.invalid_parameter_value(
                f"New records use {batch_granularity.value} granularity, but existing "
                f"dataset uses {existing_granularity.value}. Cannot mix granularities."
            )

    def _get_existing_granularity(self) -> DatasetGranularity:
        """
        Get granularity from the dataset's stored schema.

        Returns:
            DatasetGranularity based on existing records, or UNKNOWN if empty/unparseable
        """
        if self._schema is None:
            if self.has_records():
                return self._classify_input_fields(set(self.records[0].inputs.keys()))
            return DatasetGranularity.UNKNOWN
        try:
            schema = json.loads(self._schema)
            input_keys = set(schema.get("inputs", {}).keys())
            return self._classify_input_fields(input_keys)
        except (json.JSONDecodeError, TypeError):
            return DatasetGranularity.UNKNOWN

    @staticmethod
    def _classify_input_fields(input_keys: set[str]) -> DatasetGranularity:
        """
        Classify a set of input field names into a granularity type:
        - SESSION: Has 'goal' field, and only session fields (persona, goal, context,
          simulation_guidelines)
        - TRACE: No 'goal' field present
        - UNKNOWN: Empty or has 'goal' mixed with non-session fields

        Args:
            input_keys: Set of field names from a record's inputs

        Returns:
            DatasetGranularity classification for the input fields
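
        Example (a small sketch of the classification rules):
            .. code-block:: python

                EvaluationDataset._classify_input_fields({"question"})  # TRACE
                EvaluationDataset._classify_input_fields({"goal", "persona"})  # SESSION
                EvaluationDataset._classify_input_fields({"goal", "question"})  # UNKNOWN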
        """
        if not input_keys:
            return DatasetGranularity.UNKNOWN

        has_session_identifier = bool(input_keys & SESSION_IDENTIFIER_FIELDS)

        if not has_session_identifier:
            return DatasetGranularity.TRACE

        if input_keys <= SESSION_INPUT_FIELDS:
            return DatasetGranularity.SESSION

        return DatasetGranularity.UNKNOWN

    def delete_records(self, record_ids: list[str]) -> int:
        """
        Delete specific records from the dataset.

        Args:
            record_ids: List of record IDs to delete.

        Returns:
            The number of records deleted.

        Example:
            .. code-block:: python

                # Get record IDs to delete
                df = dataset.to_df()
                record_ids_to_delete = df["dataset_record_id"].tolist()[:2]

                # Delete the records
                deleted_count = dataset.delete_records(record_ids_to_delete)
                print(f"Deleted {deleted_count} records")
        """
        from mlflow.tracking._tracking_service.utils import _get_store

        tracking_store = _get_store()
        deleted_count = tracking_store.delete_dataset_records(
            dataset_id=self.dataset_id,
            dataset_record_ids=record_ids,
        )
        self._records = None  # Clear cached records
        return deleted_count

    @record_usage_event(DatasetToDataFrameEvent)
    def to_df(self) -> "pd.DataFrame":
        """
        Convert dataset records to a pandas DataFrame.

        This method triggers lazy loading of records if they haven't been loaded yet.

        Returns:
            DataFrame with columns for inputs, outputs, expectations, tags, and metadata
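
        Example:
            .. code-block:: python

                df = dataset.to_df()  # triggers the lazy record load
                print(df[["inputs", "expectations"]].head())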
        """
        import pandas as pd

        records = self.records

        if not records:
            return pd.DataFrame(
                columns=[
                    "inputs",
                    "outputs",
                    "expectations",
                    "tags",
                    "source_type",
                    "source_id",
                    "source",
                    "created_time",
                    "dataset_record_id",
                ]
            )

        data = [
            {
                "inputs": record.inputs,
                "outputs": record.outputs,
                "expectations": record.expectations,
                "tags": record.tags,
                "source_type": record.source_type,
                "source_id": record.source_id,
                "source": record.source,
                "created_time": record.created_time,
                "dataset_record_id": record.dataset_record_id,
            }
            for record in records
        ]

        return pd.DataFrame(data)

    def to_proto(self) -> ProtoDataset:
        """Convert to protobuf representation."""
        proto = ProtoDataset()

        proto.dataset_id = self.dataset_id
        proto.name = self.name
        if self.tags is not None:
            proto.tags = json.dumps(self.tags)
        if self.schema is not None:
            proto.schema = self.schema
        if self.profile is not None:
            proto.profile = self.profile
        proto.digest = self.digest
        proto.created_time = self.created_time
        proto.last_update_time = self.last_update_time
        if self.created_by is not None:
            proto.created_by = self.created_by
        if self.last_updated_by is not None:
            proto.last_updated_by = self.last_updated_by
        if self._experiment_ids is not None:
            proto.experiment_ids.extend(self._experiment_ids)

        return proto

    @classmethod
    def from_proto(cls, proto: ProtoDataset) -> "EvaluationDataset":
        """Create instance from protobuf representation."""
        tags = None
        if proto.HasField("tags"):
            tags = json.loads(proto.tags)

        dataset = cls(
            dataset_id=proto.dataset_id,
            name=proto.name,
            digest=proto.digest,
            created_time=proto.created_time,
            last_update_time=proto.last_update_time,
            tags=tags,
            schema=proto.schema if proto.HasField("schema") else None,
            profile=proto.profile if proto.HasField("profile") else None,
            created_by=proto.created_by if proto.HasField("created_by") else None,
            last_updated_by=proto.last_updated_by if proto.HasField("last_updated_by") else None,
        )
        if proto.experiment_ids:
            dataset._experiment_ids = list(proto.experiment_ids)
        return dataset

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        result = super().to_dict()

        result.update({
            "dataset_id": self.dataset_id,
            "tags": self.tags,
            "schema": self.schema,
            "profile": self.profile,
            "created_time": self.created_time,
            "last_update_time": self.last_update_time,
            "created_by": self.created_by,
            "last_updated_by": self.last_updated_by,
            "experiment_ids": self.experiment_ids,
        })

        result["records"] = [record.to_dict() for record in self.records]

        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "EvaluationDataset":
        """Create instance from dictionary representation."""
        required_fields = ("dataset_id", "name", "digest", "created_time", "last_update_time")
        for field in required_fields:
            if field not in data:
                raise ValueError(f"{field} is required")

        dataset = cls(
            dataset_id=data["dataset_id"],
            name=data["name"],
            digest=data["digest"],
            created_time=data["created_time"],
            last_update_time=data["last_update_time"],
            tags=data.get("tags"),
            schema=data.get("schema"),
            profile=data.get("profile"),
            created_by=data.get("created_by"),
            last_updated_by=data.get("last_updated_by"),
        )
        if "experiment_ids" in data:
            dataset._experiment_ids = data["experiment_ids"]

        if "records" in data:
            dataset._records = [
                DatasetRecord.from_dict(record_data) for record_data in data["records"]
            ]

        return dataset