/ tests / entities / test_dataset_record_source.py
test_dataset_record_source.py
  1  import json
  2  
  3  import pytest
  4  
  5  from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
  6  from mlflow.exceptions import MlflowException
  7  from mlflow.protos.datasets_pb2 import DatasetRecordSource as ProtoDatasetRecordSource
  8  
  9  
 10  def test_dataset_record_source_type_constants():
 11      assert DatasetRecordSourceType.TRACE == "TRACE"
 12      assert DatasetRecordSourceType.HUMAN == "HUMAN"
 13      assert DatasetRecordSourceType.DOCUMENT == "DOCUMENT"
 14      assert DatasetRecordSourceType.CODE == "CODE"
 15      assert DatasetRecordSourceType.UNSPECIFIED == "UNSPECIFIED"
 16  
 17  
 18  def test_dataset_record_source_type_enum_values():
 19      assert DatasetRecordSourceType.TRACE == "TRACE"
 20      assert DatasetRecordSourceType.HUMAN == "HUMAN"
 21      assert DatasetRecordSourceType.DOCUMENT == "DOCUMENT"
 22      assert DatasetRecordSourceType.CODE == "CODE"
 23      assert DatasetRecordSourceType.UNSPECIFIED == "UNSPECIFIED"
 24  
 25      assert isinstance(DatasetRecordSourceType.TRACE, str)
 26      assert DatasetRecordSourceType.TRACE.value == "TRACE"
 27  
 28  
 29  def test_dataset_record_source_string_normalization():
 30      source1 = DatasetRecordSource(source_type="trace", source_data={})
 31      assert source1.source_type == DatasetRecordSourceType.TRACE
 32  
 33      source2 = DatasetRecordSource(source_type="HUMAN", source_data={})
 34      assert source2.source_type == DatasetRecordSourceType.HUMAN
 35  
 36      source3 = DatasetRecordSource(source_type="Document", source_data={})
 37      assert source3.source_type == DatasetRecordSourceType.DOCUMENT
 38  
 39      source4 = DatasetRecordSource(source_type=DatasetRecordSourceType.CODE, source_data={})
 40      assert source4.source_type == DatasetRecordSourceType.CODE
 41  
 42  
 43  def test_dataset_record_source_invalid_type():
 44      with pytest.raises(MlflowException, match="Invalid dataset record source type"):
 45          DatasetRecordSource(source_type="INVALID", source_data={})
 46  
 47  
 48  def test_dataset_record_source_creation():
 49      source1 = DatasetRecordSource(
 50          source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
 51      )
 52  
 53      assert source1.source_type == DatasetRecordSourceType.TRACE
 54      assert source1.source_data == {"trace_id": "trace123", "span_id": "span456"}
 55  
 56      source2 = DatasetRecordSource(
 57          source_type=DatasetRecordSourceType.HUMAN, source_data={"user_id": "user123"}
 58      )
 59  
 60      assert source2.source_type == DatasetRecordSourceType.HUMAN
 61      assert source2.source_data == {"user_id": "user123"}
 62  
 63  
 64  def test_dataset_record_source_auto_normalization():
 65      source = DatasetRecordSource(source_type="trace", source_data={"trace_id": "trace123"})
 66  
 67      assert source.source_type == DatasetRecordSourceType.TRACE
 68  
 69  
 70  def test_dataset_record_source_empty_data():
 71      source = DatasetRecordSource(source_type="HUMAN", source_data=None)
 72      assert source.source_data == {}
 73  
 74  
 75  def test_trace_source():
 76      source1 = DatasetRecordSource(
 77          source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
 78      )
 79      assert source1.source_type == DatasetRecordSourceType.TRACE
 80      assert source1.source_data["trace_id"] == "trace123"
 81      assert source1.source_data.get("span_id") == "span456"
 82  
 83      source2 = DatasetRecordSource(
 84          source_type=DatasetRecordSourceType.TRACE, source_data={"trace_id": "trace789"}
 85      )
 86      assert source2.source_data["trace_id"] == "trace789"
 87      assert source2.source_data.get("span_id") is None
 88  
 89  
 90  def test_human_source():
 91      source1 = DatasetRecordSource(source_type="HUMAN", source_data={"user_id": "user123"})
 92      assert source1.source_type == DatasetRecordSourceType.HUMAN
 93      assert source1.source_data["user_id"] == "user123"
 94  
 95      source2 = DatasetRecordSource(
 96          source_type=DatasetRecordSourceType.HUMAN,
 97          source_data={"user_id": "user456", "timestamp": "2024-01-01"},
 98      )
 99      assert source2.source_data["user_id"] == "user456"
100      assert source2.source_data["timestamp"] == "2024-01-01"
101  
102  
def test_document_source():
    # DOCUMENT source carrying both a URI and inline content.
    with_content = DatasetRecordSource(
        source_type="DOCUMENT",
        source_data={"doc_uri": "s3://bucket/doc.txt", "content": "Document content"},
    )
    assert with_content.source_type == DatasetRecordSourceType.DOCUMENT
    full_payload = with_content.source_data
    assert full_payload["doc_uri"] == "s3://bucket/doc.txt"
    assert full_payload["content"] == "Document content"

    # DOCUMENT source with only a URI: no "content" key is present in source_data.
    uri_only = DatasetRecordSource(
        source_type=DatasetRecordSourceType.DOCUMENT,
        source_data={"doc_uri": "https://example.com/doc.pdf"},
    )
    uri_payload = uri_only.source_data
    assert uri_payload["doc_uri"] == "https://example.com/doc.pdf"
    assert uri_payload.get("content") is None
118  
119  
def test_dataset_record_source_to_from_proto():
    # Reference copy of the payload, used only for comparisons below.
    expected_data = {"file": "example.py", "line": 42}
    original = DatasetRecordSource(
        source_type="CODE",
        source_data={"file": "example.py", "line": 42},
    )

    # Entity -> proto: the type becomes the proto enum value and the test expects
    # source_data to be a JSON string on the proto message.
    message = original.to_proto()
    assert isinstance(message, ProtoDatasetRecordSource)
    assert message.source_type == ProtoDatasetRecordSource.SourceType.Value("CODE")
    assert json.loads(message.source_data) == expected_data

    # Proto -> entity: type and data must come back unchanged.
    restored = DatasetRecordSource.from_proto(message)
    assert isinstance(restored, DatasetRecordSource)
    assert restored.source_type == DatasetRecordSourceType.CODE
    assert restored.source_data == expected_data
132  
133  
def test_trace_source_proto_conversion():
    # Round-trip a TRACE source through its protobuf form and confirm that both the
    # source type and every source_data field survive.
    source = DatasetRecordSource(
        source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
    )

    proto = source.to_proto()
    assert proto.source_type == ProtoDatasetRecordSource.SourceType.Value("TRACE")

    source2 = DatasetRecordSource.from_proto(proto)
    assert isinstance(source2, DatasetRecordSource)
    # Without this check a from_proto that dropped or mis-mapped the type would still pass.
    assert source2.source_type == DatasetRecordSourceType.TRACE
    assert source2.source_data["trace_id"] == "trace123"
    assert source2.source_data["span_id"] == "span456"
146  
147  
def test_human_source_proto_conversion():
    # Round-trip a HUMAN source through its protobuf form and confirm that both the
    # source type and the user id survive.
    source = DatasetRecordSource(source_type="HUMAN", source_data={"user_id": "user123"})

    proto = source.to_proto()
    assert proto.source_type == ProtoDatasetRecordSource.SourceType.Value("HUMAN")

    source2 = DatasetRecordSource.from_proto(proto)
    assert isinstance(source2, DatasetRecordSource)
    # Without this check a from_proto that dropped or mis-mapped the type would still pass.
    assert source2.source_type == DatasetRecordSourceType.HUMAN
    assert source2.source_data["user_id"] == "user123"
157  
158  
def test_document_source_proto_conversion():
    # Round-trip a DOCUMENT source through its protobuf form and confirm that both the
    # source type and every source_data field survive.
    source = DatasetRecordSource(
        source_type="DOCUMENT",
        source_data={"doc_uri": "s3://bucket/doc.txt", "content": "Test content"},
    )

    proto = source.to_proto()
    assert proto.source_type == ProtoDatasetRecordSource.SourceType.Value("DOCUMENT")

    source2 = DatasetRecordSource.from_proto(proto)
    assert isinstance(source2, DatasetRecordSource)
    # Without this check a from_proto that dropped or mis-mapped the type would still pass.
    assert source2.source_type == DatasetRecordSourceType.DOCUMENT
    assert source2.source_data["doc_uri"] == "s3://bucket/doc.txt"
    assert source2.source_data["content"] == "Test content"
172  
173  
def test_dataset_record_source_to_from_dict():
    original = DatasetRecordSource(
        source_type="CODE",
        source_data={"file": "example.py", "line": 42},
    )

    # Entity -> dict: the type is expected as its plain string, the data as a nested dict.
    as_dict = original.to_dict()
    assert as_dict == {
        "source_type": "CODE",
        "source_data": {"file": "example.py", "line": 42},
    }

    # Dict -> entity: type and data must come back unchanged.
    restored = DatasetRecordSource.from_dict(as_dict)
    assert restored.source_type == DatasetRecordSourceType.CODE
    assert restored.source_data == {"file": "example.py", "line": 42}
183  
184  
def test_specific_source_dict_conversion():
    # Build TRACE, HUMAN and DOCUMENT sources from plain dicts and confirm that the
    # source type and the type-specific payload field are both populated. The
    # source_type assertions guard against from_dict silently ignoring the type.
    trace_data = {"source_type": "TRACE", "source_data": {"trace_id": "trace123"}}
    trace_source = DatasetRecordSource.from_dict(trace_data)
    assert isinstance(trace_source, DatasetRecordSource)
    assert trace_source.source_type == DatasetRecordSourceType.TRACE
    assert trace_source.source_data["trace_id"] == "trace123"

    human_data = {"source_type": "HUMAN", "source_data": {"user_id": "user123"}}
    human_source = DatasetRecordSource.from_dict(human_data)
    assert isinstance(human_source, DatasetRecordSource)
    assert human_source.source_type == DatasetRecordSourceType.HUMAN
    assert human_source.source_data["user_id"] == "user123"

    doc_data = {"source_type": "DOCUMENT", "source_data": {"doc_uri": "file.txt"}}
    doc_source = DatasetRecordSource.from_dict(doc_data)
    assert isinstance(doc_source, DatasetRecordSource)
    assert doc_source.source_type == DatasetRecordSourceType.DOCUMENT
    assert doc_source.source_data["doc_uri"] == "file.txt"
200  
201  
def test_dataset_record_source_equality():
    # Reference instance plus one identical twin and two near-misses.
    baseline = DatasetRecordSource(source_type="TRACE", source_data={"trace_id": "trace123"})
    twin = DatasetRecordSource(source_type="TRACE", source_data={"trace_id": "trace123"})
    different_data = DatasetRecordSource(
        source_type="TRACE", source_data={"trace_id": "trace456"}
    )
    different_type = DatasetRecordSource(
        source_type="HUMAN", source_data={"trace_id": "trace123"}
    )

    # Same type and same data compare equal.
    assert baseline == twin
    # A difference in either data or type makes the instances unequal.
    assert baseline != different_data
    assert baseline != different_type
    # Comparison against an unrelated object is unequal as well.
    assert baseline != "not a source"
215  
216  
def test_dataset_record_source_with_extra_fields():
    # Payload mixing string and float values under arbitrary keys.
    expected_fields = {
        "user_id": "user123",
        "timestamp": "2024-01-01T00:00:00Z",
        "annotation_tool": "labelstudio",
        "confidence": 0.95,
    }
    # Hand the entity its own copy so the reference dict above stays independent.
    source = DatasetRecordSource(source_type="HUMAN", source_data=dict(expected_fields))

    # Every field must be readable back from the entity unchanged.
    for key, value in expected_fields.items():
        assert source.source_data[key] == value

    # The whole payload must also survive a protobuf round trip.
    restored = DatasetRecordSource.from_proto(source.to_proto())
    assert restored.source_data == source.source_data