# test_dataset_record_source.py
import json

import pytest

from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
from mlflow.exceptions import MlflowException
from mlflow.protos.datasets_pb2 import DatasetRecordSource as ProtoDatasetRecordSource

# Source types exercised by these tests, paired with their expected string values.
_SOURCE_TYPE_VALUES = [
    (DatasetRecordSourceType.TRACE, "TRACE"),
    (DatasetRecordSourceType.HUMAN, "HUMAN"),
    (DatasetRecordSourceType.DOCUMENT, "DOCUMENT"),
    (DatasetRecordSourceType.CODE, "CODE"),
    (DatasetRecordSourceType.UNSPECIFIED, "UNSPECIFIED"),
]


def _assert_proto_round_trip(source, expected_type_name):
    """Convert ``source`` to proto and back, asserting nothing is lost in transit.

    Checks that the proto carries the expected ``SourceType`` enum value, and that the
    reconstructed entity has the same source type and compares equal to ``source``.

    Returns the reconstructed ``DatasetRecordSource`` so callers can make further
    payload-specific assertions.
    """
    proto = source.to_proto()
    assert isinstance(proto, ProtoDatasetRecordSource)
    assert proto.source_type == ProtoDatasetRecordSource.SourceType.Value(expected_type_name)

    restored = DatasetRecordSource.from_proto(proto)
    assert isinstance(restored, DatasetRecordSource)
    # The type must survive the round trip, not just the payload.
    assert restored.source_type == source.source_type
    assert restored == source
    return restored


def test_dataset_record_source_type_constants():
    for member, expected in _SOURCE_TYPE_VALUES:
        assert member == expected


def test_dataset_record_source_type_enum_values():
    # Equality with the plain strings is covered by the constants test; this one checks
    # the str-enum contract (str subclass with a matching ``.value``) for every member.
    for member, expected in _SOURCE_TYPE_VALUES:
        assert isinstance(member, str)
        assert member.value == expected


@pytest.mark.parametrize(
    ("raw_type", "expected"),
    [
        ("trace", DatasetRecordSourceType.TRACE),
        ("HUMAN", DatasetRecordSourceType.HUMAN),
        ("Document", DatasetRecordSourceType.DOCUMENT),
        (DatasetRecordSourceType.CODE, DatasetRecordSourceType.CODE),
    ],
)
def test_dataset_record_source_string_normalization(raw_type, expected):
    source = DatasetRecordSource(source_type=raw_type, source_data={})
    assert source.source_type == expected


def test_dataset_record_source_invalid_type():
    with pytest.raises(MlflowException, match="Invalid dataset record source type"):
        DatasetRecordSource(source_type="INVALID", source_data={})


def test_dataset_record_source_creation():
    source1 = DatasetRecordSource(
        source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
    )

    assert source1.source_type == DatasetRecordSourceType.TRACE
    assert source1.source_data == {"trace_id": "trace123", "span_id": "span456"}

    source2 = DatasetRecordSource(
        source_type=DatasetRecordSourceType.HUMAN, source_data={"user_id": "user123"}
    )

    assert source2.source_type == DatasetRecordSourceType.HUMAN
    assert source2.source_data == {"user_id": "user123"}


def test_dataset_record_source_auto_normalization():
    source = DatasetRecordSource(source_type="trace", source_data={"trace_id": "trace123"})

    assert source.source_type == DatasetRecordSourceType.TRACE
    # Normalizing the type must leave the payload untouched.
    assert source.source_data == {"trace_id": "trace123"}


def test_dataset_record_source_empty_data():
    # A None payload is stored as an empty dict.
    source = DatasetRecordSource(source_type="HUMAN", source_data=None)
    assert source.source_data == {}


def test_trace_source():
    source1 = DatasetRecordSource(
        source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
    )
    assert source1.source_type == DatasetRecordSourceType.TRACE
    assert source1.source_data["trace_id"] == "trace123"
    assert source1.source_data.get("span_id") == "span456"

    source2 = DatasetRecordSource(
        source_type=DatasetRecordSourceType.TRACE, source_data={"trace_id": "trace789"}
    )
    assert source2.source_data["trace_id"] == "trace789"
    assert source2.source_data.get("span_id") is None


def test_human_source():
    source1 = DatasetRecordSource(source_type="HUMAN", source_data={"user_id": "user123"})
    assert source1.source_type == DatasetRecordSourceType.HUMAN
    assert source1.source_data["user_id"] == "user123"

    source2 = DatasetRecordSource(
        source_type=DatasetRecordSourceType.HUMAN,
        source_data={"user_id": "user456", "timestamp": "2024-01-01"},
    )
    assert source2.source_data["user_id"] == "user456"
    assert source2.source_data["timestamp"] == "2024-01-01"


def test_document_source():
    source1 = DatasetRecordSource(
        source_type="DOCUMENT",
        source_data={"doc_uri": "s3://bucket/doc.txt", "content": "Document content"},
    )
    assert source1.source_type == DatasetRecordSourceType.DOCUMENT
    assert source1.source_data["doc_uri"] == "s3://bucket/doc.txt"
    assert source1.source_data["content"] == "Document content"

    source2 = DatasetRecordSource(
        source_type=DatasetRecordSourceType.DOCUMENT,
        source_data={"doc_uri": "https://example.com/doc.pdf"},
    )
    assert source2.source_data["doc_uri"] == "https://example.com/doc.pdf"
    assert source2.source_data.get("content") is None


def test_dataset_record_source_to_from_proto():
    source = DatasetRecordSource(source_type="CODE", source_data={"file": "example.py", "line": 42})

    proto = source.to_proto()
    assert isinstance(proto, ProtoDatasetRecordSource)
    assert proto.source_type == ProtoDatasetRecordSource.SourceType.Value("CODE")
    # The proto stores source_data as a JSON-encoded string.
    assert json.loads(proto.source_data) == {"file": "example.py", "line": 42}

    source2 = DatasetRecordSource.from_proto(proto)
    assert isinstance(source2, DatasetRecordSource)
    assert source2.source_type == DatasetRecordSourceType.CODE
    assert source2.source_data == {"file": "example.py", "line": 42}
    assert source2 == source


def test_trace_source_proto_conversion():
    source = DatasetRecordSource(
        source_type="TRACE", source_data={"trace_id": "trace123", "span_id": "span456"}
    )

    source2 = _assert_proto_round_trip(source, "TRACE")
    assert source2.source_data["trace_id"] == "trace123"
    assert source2.source_data["span_id"] == "span456"


def test_human_source_proto_conversion():
    source = DatasetRecordSource(source_type="HUMAN", source_data={"user_id": "user123"})

    source2 = _assert_proto_round_trip(source, "HUMAN")
    assert source2.source_data["user_id"] == "user123"


def test_document_source_proto_conversion():
    source = DatasetRecordSource(
        source_type="DOCUMENT",
        source_data={"doc_uri": "s3://bucket/doc.txt", "content": "Test content"},
    )

    source2 = _assert_proto_round_trip(source, "DOCUMENT")
    assert source2.source_data["doc_uri"] == "s3://bucket/doc.txt"
    assert source2.source_data["content"] == "Test content"


def test_dataset_record_source_to_from_dict():
    source = DatasetRecordSource(source_type="CODE", source_data={"file": "example.py", "line": 42})

    data = source.to_dict()
    assert data == {"source_type": "CODE", "source_data": {"file": "example.py", "line": 42}}

    source2 = DatasetRecordSource.from_dict(data)
    assert source2.source_type == DatasetRecordSourceType.CODE
    assert source2.source_data == {"file": "example.py", "line": 42}


def test_specific_source_dict_conversion():
    trace_data = {"source_type": "TRACE", "source_data": {"trace_id": "trace123"}}
    trace_source = DatasetRecordSource.from_dict(trace_data)
    assert isinstance(trace_source, DatasetRecordSource)
    assert trace_source.source_type == DatasetRecordSourceType.TRACE
    assert trace_source.source_data["trace_id"] == "trace123"

    human_data = {"source_type": "HUMAN", "source_data": {"user_id": "user123"}}
    human_source = DatasetRecordSource.from_dict(human_data)
    assert isinstance(human_source, DatasetRecordSource)
    assert human_source.source_type == DatasetRecordSourceType.HUMAN
    assert human_source.source_data["user_id"] == "user123"

    doc_data = {"source_type": "DOCUMENT", "source_data": {"doc_uri": "file.txt"}}
    doc_source = DatasetRecordSource.from_dict(doc_data)
    assert isinstance(doc_source, DatasetRecordSource)
    assert doc_source.source_type == DatasetRecordSourceType.DOCUMENT
    assert doc_source.source_data["doc_uri"] == "file.txt"


def test_dataset_record_source_equality():
    source1 = DatasetRecordSource(source_type="TRACE", source_data={"trace_id": "trace123"})
    # Same type and payload as source1.
    source2 = DatasetRecordSource(source_type="TRACE", source_data={"trace_id": "trace123"})
    # Same type, different payload.
    source3 = DatasetRecordSource(source_type="TRACE", source_data={"trace_id": "trace456"})
    # Different type, same payload.
    source4 = DatasetRecordSource(source_type="HUMAN", source_data={"trace_id": "trace123"})

    assert source1 == source2
    assert source1 != source3
    assert source1 != source4
    assert source1 != "not a source"


def test_dataset_record_source_with_extra_fields():
    source = DatasetRecordSource(
        source_type="HUMAN",
        source_data={
            "user_id": "user123",
            "timestamp": "2024-01-01T00:00:00Z",
            "annotation_tool": "labelstudio",
            "confidence": 0.95,
        },
    )

    assert source.source_data["user_id"] == "user123"
    assert source.source_data["timestamp"] == "2024-01-01T00:00:00Z"
    assert source.source_data["annotation_tool"] == "labelstudio"
    assert source.source_data["confidence"] == 0.95

    # Arbitrary extra keys (including non-string values) must survive the proto round trip.
    source2 = _assert_proto_round_trip(source, "HUMAN")
    assert source2.source_data == source.source_data