# tests/entities/test_evaluation_dataset.py
  1  import json
  2  from unittest.mock import Mock, patch
  3  
  4  import pandas as pd
  5  import pytest
  6  from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
  7  
  8  from mlflow.entities.dataset_record import DatasetRecord
  9  from mlflow.entities.dataset_record_source import DatasetRecordSourceType
 10  from mlflow.entities.evaluation_dataset import EvaluationDataset
 11  from mlflow.entities.span import Span, SpanType
 12  from mlflow.entities.trace import Trace
 13  from mlflow.entities.trace_data import TraceData
 14  from mlflow.entities.trace_info import TraceInfo
 15  from mlflow.entities.trace_location import TraceLocation
 16  from mlflow.entities.trace_state import TraceState
 17  from mlflow.exceptions import MlflowException
 18  from mlflow.tracing.utils import build_otel_context
 19  
 20  
 21  def test_evaluation_dataset_creation():
 22      dataset = EvaluationDataset(
 23          dataset_id="dataset123",
 24          name="test_dataset",
 25          digest="abc123",
 26          created_time=123456789,
 27          last_update_time=987654321,
 28          tags={"source": "manual", "type": "HUMAN"},
 29          schema='{"fields": ["input", "output"]}',
 30          profile='{"count": 100}',
 31          created_by="user1",
 32          last_updated_by="user2",
 33      )
 34  
 35      assert dataset.dataset_id == "dataset123"
 36      assert dataset.name == "test_dataset"
 37      assert dataset.tags == {"source": "manual", "type": "HUMAN"}
 38      assert dataset.schema == '{"fields": ["input", "output"]}'
 39      assert dataset.profile == '{"count": 100}'
 40      assert dataset.digest == "abc123"
 41      assert dataset.created_by == "user1"
 42      assert dataset.last_updated_by == "user2"
 43      assert dataset.created_time == 123456789
 44      assert dataset.last_update_time == 987654321
 45  
 46      dataset.experiment_ids = ["exp1", "exp2"]
 47      assert dataset.experiment_ids == ["exp1", "exp2"]
 48  
 49  
 50  def test_evaluation_dataset_timestamps_required():
 51      dataset = EvaluationDataset(
 52          dataset_id="dataset123",
 53          name="test_dataset",
 54          digest="digest123",
 55          created_time=123456789,
 56          last_update_time=987654321,
 57      )
 58  
 59      assert dataset.created_time == 123456789
 60      assert dataset.last_update_time == 987654321
 61  
 62  
 63  def test_evaluation_dataset_experiment_ids_setter():
 64      dataset = EvaluationDataset(
 65          dataset_id="dataset123",
 66          name="test_dataset",
 67          digest="digest123",
 68          created_time=123456789,
 69          last_update_time=123456789,
 70      )
 71  
 72      new_experiment_ids = ["exp1", "exp2"]
 73      dataset.experiment_ids = new_experiment_ids
 74      assert dataset._experiment_ids == new_experiment_ids
 75      assert dataset.experiment_ids == new_experiment_ids
 76  
 77      dataset.experiment_ids = []
 78      assert dataset._experiment_ids == []
 79      assert dataset.experiment_ids == []
 80  
 81      dataset.experiment_ids = None
 82      assert dataset._experiment_ids == []
 83      assert dataset.experiment_ids == []
 84  
 85  
 86  def test_evaluation_dataset_to_from_proto():
 87      dataset = EvaluationDataset(
 88          dataset_id="dataset123",
 89          name="test_dataset",
 90          tags={"source": "manual", "type": "HUMAN"},
 91          schema='{"fields": ["input", "output"]}',
 92          profile='{"count": 100}',
 93          digest="abc123",
 94          created_time=123456789,
 95          last_update_time=987654321,
 96          created_by="user1",
 97          last_updated_by="user2",
 98      )
 99      dataset.experiment_ids = ["exp1", "exp2"]
100  
101      proto = dataset.to_proto()
102      assert proto.name == "test_dataset"
103      assert proto.tags == '{"source": "manual", "type": "HUMAN"}'
104      assert proto.schema == '{"fields": ["input", "output"]}'
105      assert proto.profile == '{"count": 100}'
106      assert proto.digest == "abc123"
107      assert proto.created_time == 123456789
108      assert proto.last_update_time == 987654321
109      assert proto.created_by == "user1"
110      assert proto.last_updated_by == "user2"
111      assert list(proto.experiment_ids) == ["exp1", "exp2"]
112  
113      dataset2 = EvaluationDataset.from_proto(proto)
114      assert dataset2.dataset_id == dataset.dataset_id
115      assert dataset2.name == dataset.name
116      assert dataset2.tags == dataset.tags
117      assert dataset2.schema == dataset.schema
118      assert dataset2.profile == dataset.profile
119      assert dataset2.digest == dataset.digest
120      assert dataset2.created_time == dataset.created_time
121      assert dataset2.last_update_time == dataset.last_update_time
122      assert dataset2.created_by == dataset.created_by
123      assert dataset2.last_updated_by == dataset.last_updated_by
124      assert dataset2._experiment_ids == ["exp1", "exp2"]
125      assert dataset2.experiment_ids == ["exp1", "exp2"]
126  
127  
def test_evaluation_dataset_to_from_proto_minimal():
    """Optional fields come back as None after round-tripping a minimal dataset."""
    proto = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    ).to_proto()

    restored = EvaluationDataset.from_proto(proto)

    assert restored.dataset_id == "dataset123"
    assert restored.name == "test_dataset"
    assert restored.digest == "digest123"
    for attr in ("tags", "schema", "profile", "created_by", "last_updated_by"):
        assert getattr(restored, attr) is None
    assert restored._experiment_ids is None
149  
150  
def test_evaluation_dataset_to_from_dict():
    """A dataset with records round-trips losslessly through to_dict/from_dict."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        tags={"source": "manual", "type": "HUMAN"},
        schema='{"fields": ["input", "output"]}',
        profile='{"count": 100}',
        digest="abc123",
        created_time=123456789,
        last_update_time=987654321,
        created_by="user1",
        last_updated_by="user2",
    )
    ds.experiment_ids = ["exp1", "exp2"]
    ds._records = [
        DatasetRecord(
            dataset_record_id="rec789",
            dataset_id="dataset123",
            inputs={"question": "What is MLflow?"},
            created_time=123456789,
            last_update_time=123456789,
        )
    ]

    payload = ds.to_dict()
    expected_fields = {
        "dataset_id": "dataset123",
        "name": "test_dataset",
        "tags": {"source": "manual", "type": "HUMAN"},
        "schema": '{"fields": ["input", "output"]}',
        "profile": '{"count": 100}',
        "digest": "abc123",
        "created_time": 123456789,
        "last_update_time": 987654321,
        "created_by": "user1",
        "last_updated_by": "user2",
        "experiment_ids": ["exp1", "exp2"],
    }
    for key, value in expected_fields.items():
        assert payload[key] == value
    assert len(payload["records"]) == 1
    assert payload["records"][0]["inputs"]["question"] == "What is MLflow?"

    restored = EvaluationDataset.from_dict(payload)
    for attr in (
        "dataset_id",
        "name",
        "tags",
        "schema",
        "profile",
        "digest",
        "created_time",
        "last_update_time",
        "created_by",
        "last_updated_by",
    ):
        assert getattr(restored, attr) == getattr(ds, attr)
    assert restored._experiment_ids == ["exp1", "exp2"]
    assert restored.experiment_ids == ["exp1", "exp2"]
    assert len(restored._records) == 1
    assert restored._records[0].inputs["question"] == "What is MLflow?"
206  
207  
def test_evaluation_dataset_to_from_dict_minimal():
    """Optional fields stay None and empty collections stay empty across a dict round trip."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    ds._experiment_ids = []
    ds._records = []

    restored = EvaluationDataset.from_dict(ds.to_dict())

    assert restored.dataset_id == "dataset123"
    assert restored.name == "test_dataset"
    assert restored.digest == "digest123"
    for attr in ("tags", "schema", "profile", "created_by", "last_updated_by"):
        assert getattr(restored, attr) is None
    assert restored._experiment_ids == []
    assert restored._records == []
232  
233  
def test_evaluation_dataset_has_records():
    """has_records() distinguishes 'records loaded' from 'records never fetched'."""
    dataset = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    # Fresh dataset: no records have been assigned or loaded yet.
    assert dataset.has_records() is False

    dataset._records = [
        DatasetRecord(
            dataset_record_id="rec123",
            dataset_id="dataset123",
            inputs={"test": "data"},
            created_time=123456789,
            last_update_time=123456789,
        )
    ]
    assert dataset.has_records() is True

    # NOTE(review): even an *empty* record list reports True — has_records()
    # appears to mean "records were loaded" (i.e. `_records is not None`)
    # rather than "at least one record exists"; confirm against
    # EvaluationDataset.has_records.
    dataset._records = []
    assert dataset.has_records() is True
258  
259  
def test_evaluation_dataset_proto_with_unloaded_experiment_ids():
    """Serializing to proto must not force-load the lazy experiment_ids field."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    assert ds._experiment_ids is None

    proto = ds.to_proto()

    assert len(proto.experiment_ids) == 0
    # Serialization is side-effect free on the backing field.
    assert ds._experiment_ids is None
274  
275  
def test_evaluation_dataset_complex_tags():
    """Nested tag structures survive both proto and dict round trips."""
    nested_tags = {
        "source": "automated",
        "metadata": {"version": "1.0", "config": {"temperature": 0.7, "max_tokens": 100}},
        "labels": ["production", "evaluated"],
    }

    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
        tags=nested_tags,
    )

    # Proto round trip.
    assert EvaluationDataset.from_proto(ds.to_proto()).tags == nested_tags

    # Dict round trip (requires loaded experiment IDs and records).
    ds._experiment_ids = []
    ds._records = []
    assert EvaluationDataset.from_dict(ds.to_dict()).tags == nested_tags
302  
303  
def test_evaluation_dataset_to_df():
    """to_df() yields a DataFrame with the canonical columns, empty or populated."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    columns = [
        "inputs",
        "outputs",
        "expectations",
        "tags",
        "source_type",
        "source_id",
        "source",
        "created_time",
        "dataset_record_id",
    ]

    # With no records the frame is empty but fully shaped.
    empty_df = ds.to_df()
    assert isinstance(empty_df, pd.DataFrame)
    assert list(empty_df.columns) == columns
    assert len(empty_df) == 0

    ds._records = [
        DatasetRecord(
            dataset_record_id="rec123",
            dataset_id="dataset123",
            inputs={"question": "What is MLflow?"},
            outputs={
                "answer": "MLflow is an ML platform for managing machine learning lifecycle",
                "key1": "value1",
            },
            expectations={"answer": "MLflow is an ML platform"},
            tags={"source": "manual"},
            source_type="HUMAN",
            source_id="user123",
            created_time=123456789,
            last_update_time=123456789,
        ),
        DatasetRecord(
            dataset_record_id="rec456",
            dataset_id="dataset123",
            inputs={"question": "What is Spark?"},
            outputs={"answer": "Apache Spark is a unified analytics engine for data processing"},
            expectations={"answer": "Spark is a data engine"},
            tags={"source": "automated"},
            source_type="CODE",
            source_id="script456",
            created_time=123456790,
            last_update_time=123456790,
        ),
    ]

    df = ds.to_df()
    assert isinstance(df, pd.DataFrame)
    assert list(df.columns) == columns
    assert len(df) == 2

    # Expected cell values per column, row-by-row.
    expected_cells = {
        "inputs": [{"question": "What is MLflow?"}, {"question": "What is Spark?"}],
        "outputs": [
            {
                "answer": "MLflow is an ML platform for managing machine learning lifecycle",
                "key1": "value1",
            },
            {"answer": "Apache Spark is a unified analytics engine for data processing"},
        ],
        "expectations": [
            {"answer": "MLflow is an ML platform"},
            {"answer": "Spark is a data engine"},
        ],
        "tags": [{"source": "manual"}, {"source": "automated"}],
        "source_type": ["HUMAN", "CODE"],
        "source_id": ["user123", "script456"],
        "dataset_record_id": ["rec123", "rec456"],
    }
    for column, values in expected_cells.items():
        for row, value in enumerate(values):
            assert df[column].iloc[row] == value
389  
390  
def create_test_span(
    span_id=1,
    parent_id=None,
    name="test_span",
    inputs=None,
    outputs=None,
    span_type=SpanType.UNKNOWN,
):
    """Build an MLflow Span wrapping a minimal OpenTelemetry ReadableSpan.

    Inputs/outputs, when given, are JSON-encoded into the MLflow span
    attributes; a parent context is attached only when ``parent_id`` is set.
    """
    attrs = {"mlflow.spanType": json.dumps(span_type)}
    if inputs is not None:
        attrs["mlflow.spanInputs"] = json.dumps(inputs)
    if outputs is not None:
        attrs["mlflow.spanOutputs"] = json.dumps(outputs)

    parent_ctx = None
    if parent_id:
        parent_ctx = build_otel_context(trace_id=123456789, span_id=parent_id)

    return Span(
        OTelReadableSpan(
            name=name,
            context=build_otel_context(trace_id=123456789, span_id=span_id),
            parent=parent_ctx,
            start_time=100000000,
            end_time=200000000,
            attributes=attrs,
        )
    )
418  
419  
def create_test_trace(
    trace_id="test-trace-123",
    inputs=None,
    outputs=None,
    expectations=None,
    trace_metadata=None,
    _no_defaults=False,
):
    """Construct a Trace with one root CHAIN span and optional expectation assessments.

    Unless ``_no_defaults`` is set, ``None`` inputs/outputs fall back to
    canned question/answer values so callers can omit them.
    """
    assessments = []
    if expectations:
        from mlflow.entities.assessment import AssessmentSource, AssessmentSourceType, Expectation

        assessments = [
            Expectation(
                name=key,
                value=val,
                source=AssessmentSource(
                    source_type=AssessmentSourceType.HUMAN, source_id="test_user"
                ),
            )
            for key, val in expectations.items()
        ]

    info = TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=1234567890,
        execution_duration=1000,
        state=TraceState.OK,
        assessments=assessments,
        trace_metadata=trace_metadata or {},
    )

    if _no_defaults:
        span_inputs, span_outputs = inputs, outputs
    else:
        span_inputs = {"question": "What is MLflow?"} if inputs is None else inputs
        span_outputs = {"answer": "MLflow is a platform"} if outputs is None else outputs

    root = create_test_span(
        span_id=1,
        parent_id=None,
        name="root_span",
        inputs=span_inputs,
        outputs=span_outputs,
        span_type=SpanType.CHAIN,
    )
    return Trace(info=info, data=TraceData(spans=[root]))
475  
476  
def test_process_trace_records_with_dict_outputs():
    """Dict span outputs are preserved verbatim on the generated record."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    trace = create_test_trace(
        trace_id="trace1",
        inputs={"question": "What is MLflow?"},
        outputs={"answer": "MLflow is a platform", "confidence": 0.95},
    )

    records = ds._process_trace_records([trace])

    assert len(records) == 1
    record = records[0]
    assert record["inputs"] == {"question": "What is MLflow?"}
    assert record["outputs"] == {"answer": "MLflow is a platform", "confidence": 0.95}
    assert record["expectations"] == {}
    source = record["source"]
    assert source["source_type"] == DatasetRecordSourceType.TRACE.value
    assert source["source_data"]["trace_id"] == "trace1"
500  
501  
def test_process_trace_records_with_string_outputs():
    """Plain-string span outputs pass through unchanged."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    trace = create_test_trace(
        trace_id="trace2",
        inputs={"query": "Tell me about Python"},
        outputs="Python is a programming language",
    )

    records = ds._process_trace_records([trace])

    assert len(records) == 1
    record = records[0]
    assert record["inputs"] == {"query": "Tell me about Python"}
    assert record["outputs"] == "Python is a programming language"
    assert record["expectations"] == {}
    assert record["source"]["source_type"] == DatasetRecordSourceType.TRACE.value
524  
525  
def test_process_trace_records_with_non_dict_non_string_outputs():
    """List-valued span outputs are kept as lists on the record."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    trace = create_test_trace(
        trace_id="trace3", inputs={"x": 1, "y": 2}, outputs=["result1", "result2", "result3"]
    )

    records = ds._process_trace_records([trace])

    assert len(records) == 1
    record = records[0]
    assert record["inputs"] == {"x": 1, "y": 2}
    assert record["outputs"] == ["result1", "result2", "result3"]
    assert record["source"]["source_type"] == DatasetRecordSourceType.TRACE.value
545  
546  
def test_process_trace_records_with_numeric_outputs():
    """Numeric span outputs survive without coercion."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    records = ds._process_trace_records(
        [create_test_trace(trace_id="trace4", inputs={"number": 42}, outputs=42)]
    )

    assert len(records) == 1
    assert records[0]["outputs"] == 42
562  
563  
def test_process_trace_records_with_none_outputs():
    """With defaults disabled, a None output stays None on the record."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    # _no_defaults prevents the helper from substituting canned outputs.
    trace = create_test_trace(
        trace_id="trace5", inputs={"input": "test"}, outputs=None, _no_defaults=True
    )

    records = ds._process_trace_records([trace])

    assert len(records) == 1
    assert records[0]["outputs"] is None
581  
582  
def test_process_trace_records_with_expectations():
    """Expectation assessments on the trace become the record's expectations map."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    trace = create_test_trace(
        trace_id="trace6",
        inputs={"question": "What is 2+2?"},
        outputs={"answer": "4"},
        expectations={"correctness": True, "tone": "neutral"},
    )

    records = ds._process_trace_records([trace])

    assert len(records) == 1
    assert records[0]["expectations"] == {"correctness": True, "tone": "neutral"}
603  
604  
def test_process_trace_records_multiple_traces():
    """Each trace yields one record, and output types are preserved per trace."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    expected_outputs = [{"result": "answer1"}, "string answer", [1, 2, 3]]
    traces = [
        create_test_trace(trace_id=f"trace{i}", outputs=out)
        for i, out in enumerate(expected_outputs, start=1)
    ]

    records = ds._process_trace_records(traces)

    assert len(records) == 3
    for record, out in zip(records, expected_outputs):
        assert record["outputs"] == out
626  
627  
def test_process_trace_records_mixed_types_error():
    """A non-Trace element mixed into the list raises a descriptive MlflowException."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    # The error message must identify both the expectation and the offender.
    expected_pattern = (
        "Mixed types in trace list.*Expected all elements to be Trace objects.*"
        "element at index 1 is dict"
    )
    with pytest.raises(MlflowException, match=expected_pattern):
        ds._process_trace_records([create_test_trace(trace_id="trace1"), {"not": "a trace"}])
648  
649  
def test_process_trace_records_preserves_session_metadata():
    """The mlflow.trace.session metadata key surfaces as session_id in source_data."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    with_session = create_test_trace(
        trace_id="tr-123",
        trace_metadata={"mlflow.trace.session": "session_1"},
    )
    without_session = create_test_trace(
        trace_id="tr-456",
        trace_metadata={},
    )

    first, second = ds._process_trace_records([with_session, without_session])

    # Session-bearing trace carries both identifiers.
    data_with = first["source"]["source_data"]
    assert data_with["trace_id"] == "tr-123"
    assert data_with["session_id"] == "session_1"

    # Session-less trace carries only the trace ID.
    data_without = second["source"]["source_data"]
    assert data_without["trace_id"] == "tr-456"
    assert "session_id" not in data_without
680  
681  
def test_to_df_includes_source_column():
    """to_df() carries each record's DatasetRecordSource through in the ``source`` column."""
    # Only DatasetRecordSource needs a local import; DatasetRecord is already
    # imported at module level (the previous duplicate import shadowed it).
    from mlflow.entities.dataset_record_source import DatasetRecordSource

    dataset = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )

    # Manually add a record with a source attached.
    source = DatasetRecordSource(
        source_type=DatasetRecordSourceType.TRACE,
        source_data={"trace_id": "tr-123"},
    )
    record = DatasetRecord(
        dataset_id="dataset123",
        dataset_record_id="record123",
        inputs={"question": "test"},
        outputs={"answer": "test answer"},
        expectations={},
        tags={},
        created_time=123456789,
        last_update_time=123456789,
        source=source,
    )
    dataset._records = [record]

    df = dataset.to_df()

    # The source column exists, is populated, and holds the original object.
    assert "source" in df.columns
    assert df["source"].notna().all()
    assert df["source"].iloc[0] == source
717  
718  
def test_delete_records():
    """delete_records() delegates to the store and invalidates the record cache."""
    ds = EvaluationDataset(
        dataset_id="dataset123",
        name="test_dataset",
        digest="digest123",
        created_time=123456789,
        last_update_time=123456789,
    )
    # Pre-populate the cache so we can observe it being cleared.
    ds._records = [Mock(), Mock()]

    store = Mock()
    store.delete_dataset_records.return_value = 2

    with patch("mlflow.tracking._tracking_service.utils._get_store", return_value=store):
        assert ds.delete_records(["record1", "record2"]) == 2

    store.delete_dataset_records.assert_called_once_with(
        dataset_id="dataset123",
        dataset_record_ids=["record1", "record2"],
    )
    # The cache must be dropped so a subsequent read refetches from the store.
    assert ds._records is None