tests/utils/test_databricks_tracing_utils.py
test_databricks_tracing_utils.py
  1  import json
  2  
  3  import pytest
  4  from google.protobuf.timestamp_pb2 import Timestamp
  5  
  6  import mlflow
  7  from mlflow.entities import (
  8      AssessmentSource,
  9      Expectation,
 10      Feedback,
 11      Trace,
 12      TraceData,
 13      TraceInfo,
 14      TraceState,
 15  )
 16  from mlflow.entities.trace_location import (
 17      InferenceTableLocation,
 18      MlflowExperimentLocation,
 19      TraceLocation,
 20      TraceLocationType,
 21      UCSchemaLocation,
 22      UnityCatalog,
 23  )
 24  from mlflow.protos import assessments_pb2
 25  from mlflow.protos import databricks_tracing_pb2 as pb
 26  from mlflow.protos.assessments_pb2 import AssessmentSource as ProtoAssessmentSource
 27  from mlflow.tracing.constant import (
 28      TRACE_ID_V4_PREFIX,
 29      TRACE_SCHEMA_VERSION,
 30      TRACE_SCHEMA_VERSION_KEY,
 31      SpanAttributeKey,
 32  )
 33  from mlflow.tracing.utils import TraceMetadataKey, add_size_stats_to_trace_metadata
 34  from mlflow.utils.databricks_tracing_utils import (
 35      assessment_to_proto,
 36      get_trace_id_from_assessment_proto,
 37      inference_table_location_to_proto,
 38      mlflow_experiment_location_to_proto,
 39      parse_uc_location,
 40      trace_from_proto,
 41      trace_location_from_proto,
 42      trace_location_to_proto,
 43      trace_to_proto,
 44      uc_location_to_str,
 45      uc_schema_location_from_proto,
 46      uc_schema_location_to_proto,
 47  )
 48  
 49  
 50  def test_trace_location_to_proto_uc_schema():
 51      trace_location = TraceLocation.from_databricks_uc_schema(
 52          catalog_name="test_catalog", schema_name="test_schema"
 53      )
 54      proto = trace_location_to_proto(trace_location)
 55      assert proto.type == pb.TraceLocation.TraceLocationType.UC_SCHEMA
 56      assert proto.uc_schema.catalog_name == "test_catalog"
 57      assert proto.uc_schema.schema_name == "test_schema"
 58  
 59  
 60  def test_parse_uc_location():
 61      assert parse_uc_location("catalog.schema") == ("catalog", "schema", None)
 62      assert parse_uc_location("catalog.schema.prefix") == ("catalog", "schema", "prefix")
 63  
 64      with pytest.raises(ValueError, match="Invalid UC location"):
 65          parse_uc_location("a.b.c.d")
 66  
 67  
 68  def test_uc_location_to_str():
 69      assert uc_location_to_str("catalog", "schema") == "catalog.schema"
 70      assert uc_location_to_str("catalog", "schema", "prefix") == "catalog.schema.prefix"
 71  
 72  
 73  def test_trace_location_to_proto_mlflow_experiment():
 74      trace_location = TraceLocation.from_experiment_id(experiment_id="1234")
 75      proto = trace_location_to_proto(trace_location)
 76      assert proto.type == pb.TraceLocation.TraceLocationType.MLFLOW_EXPERIMENT
 77      assert proto.mlflow_experiment.experiment_id == "1234"
 78  
 79  
 80  def test_trace_location_to_proto_inference_table():
 81      trace_location = TraceLocation(
 82          type=TraceLocationType.INFERENCE_TABLE,
 83          inference_table=InferenceTableLocation(
 84              full_table_name="test_catalog.test_schema.test_table"
 85          ),
 86      )
 87      proto = trace_location_to_proto(trace_location)
 88      assert proto.type == pb.TraceLocation.TraceLocationType.INFERENCE_TABLE
 89      assert proto.inference_table.full_table_name == "test_catalog.test_schema.test_table"
 90  
 91  
 92  def test_uc_schema_location_to_proto():
 93      schema_location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 94      proto = uc_schema_location_to_proto(schema_location)
 95      assert proto.catalog_name == "test_catalog"
 96      assert proto.schema_name == "test_schema"
 97  
 98  
 99  def test_uc_schema_location_from_proto():
100      proto = pb.UCSchemaLocation(
101          catalog_name="test_catalog",
102          schema_name="test_schema",
103          otel_spans_table_name="test_spans",
104          otel_logs_table_name="test_logs",
105      )
106      schema_location = uc_schema_location_from_proto(proto)
107      assert schema_location.catalog_name == "test_catalog"
108      assert schema_location.schema_name == "test_schema"
109      assert schema_location.full_otel_spans_table_name == "test_catalog.test_schema.test_spans"
110      assert schema_location.full_otel_logs_table_name == "test_catalog.test_schema.test_logs"
111  
112  
def test_inference_table_location_to_proto():
    """inference_table_location_to_proto preserves the full table name."""
    full_name = "test_catalog.test_schema.test_table"
    converted = inference_table_location_to_proto(
        InferenceTableLocation(full_table_name=full_name)
    )
    assert converted.full_table_name == full_name
117  
118  
def test_mlflow_experiment_location_to_proto():
    """mlflow_experiment_location_to_proto preserves the experiment ID."""
    converted = mlflow_experiment_location_to_proto(
        MlflowExperimentLocation(experiment_id="1234")
    )
    assert converted.experiment_id == "1234"
123  
124  
def test_schema_location_to_proto():
    """Explicitly set OTel table names are carried into the UC schema proto."""
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
    # Assign the private table-name fields directly to simulate a location whose
    # OTel tables are already known.
    location._otel_spans_table_name = "test_spans"
    location._otel_logs_table_name = "test_logs"

    converted = uc_schema_location_to_proto(location)

    assert converted.catalog_name == "test_catalog"
    assert converted.schema_name == "test_schema"
    assert converted.otel_spans_table_name == "test_spans"
    assert converted.otel_logs_table_name == "test_logs"
137  
138  
def test_trace_location_from_proto_uc_schema():
    """A UC_SCHEMA proto deserializes to a UC-schema TraceLocation with qualified tables."""
    proto = pb.TraceLocation(
        type=pb.TraceLocation.TraceLocationType.UC_SCHEMA,
        uc_schema=pb.UCSchemaLocation(
            catalog_name="catalog",
            schema_name="schema",
            otel_spans_table_name="spans",
            otel_logs_table_name="logs",
        ),
    )
    trace_location = trace_location_from_proto(proto)

    # Check the location type as well, like the other *_from_proto tests do.
    assert trace_location.type == TraceLocationType.UC_SCHEMA
    assert trace_location.uc_schema.catalog_name == "catalog"
    assert trace_location.uc_schema.schema_name == "schema"
    assert trace_location.uc_schema.full_otel_spans_table_name == "catalog.schema.spans"
    assert trace_location.uc_schema.full_otel_logs_table_name == "catalog.schema.logs"
154  
155  
def test_trace_location_from_proto_mlflow_experiment():
    """An MLFLOW_EXPERIMENT proto deserializes to an experiment-backed TraceLocation."""
    experiment_proto = mlflow_experiment_location_to_proto(
        MlflowExperimentLocation(experiment_id="1234")
    )
    location = trace_location_from_proto(
        pb.TraceLocation(
            type=pb.TraceLocation.TraceLocationType.MLFLOW_EXPERIMENT,
            mlflow_experiment=experiment_proto,
        )
    )

    assert location.type == TraceLocationType.MLFLOW_EXPERIMENT
    assert location.mlflow_experiment.experiment_id == "1234"
166  
167  
def test_trace_location_from_proto_inference_table():
    """An INFERENCE_TABLE proto deserializes to an inference-table TraceLocation."""
    table_name = "test_catalog.test_schema.test_table"
    table_proto = inference_table_location_to_proto(
        InferenceTableLocation(full_table_name=table_name)
    )
    location = trace_location_from_proto(
        pb.TraceLocation(
            type=pb.TraceLocation.TraceLocationType.INFERENCE_TABLE,
            inference_table=table_proto,
        )
    )

    assert location.type == TraceLocationType.INFERENCE_TABLE
    assert location.inference_table.full_table_name == table_name
178  
179  
def test_trace_info_to_v4_proto():
    """A UC-schema TraceInfo serializes to the V4 proto and deserializes back unchanged."""
    otel_trace_id = "2efb31387ff19263f92b2c0a61b0a8bc"
    trace_id = f"trace:/catalog.schema/{otel_trace_id}"
    trace_info = TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_databricks_uc_schema(
            catalog_name="catalog", schema_name="schema"
        ),
        request_time=0,
        state=TraceState.OK,
        request_preview="request",
        response_preview="response",
        client_request_id="client_request_id",
        tags={"key": "value"},
    )
    proto_trace_info = trace_info.to_proto()

    # The proto stores the bare OTel trace ID; the UC location lives in trace_location.
    assert proto_trace_info.trace_id == otel_trace_id
    assert proto_trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_trace_info.trace_location.uc_schema.schema_name == "schema"
    # Compare against the enum's proto value rather than a magic integer.
    assert proto_trace_info.state == TraceState.OK.to_proto()
    assert proto_trace_info.request_preview == "request"
    assert proto_trace_info.response_preview == "response"
    assert proto_trace_info.client_request_id == "client_request_id"
    assert proto_trace_info.tags == {"key": "value"}
    assert len(proto_trace_info.assessments) == 0

    trace_info_from_proto = TraceInfo.from_proto(proto_trace_info)
    assert trace_info_from_proto == trace_info
208  
209  
def test_trace_to_proto_and_from_proto():
    """Round-trip a UC-schema trace through ``trace_to_proto`` and ``trace_from_proto``.

    The span is given inputs, outputs and a custom attribute. Without them, the span-level
    comparisons at the end would compare ``None`` to ``None`` and check nothing.
    """
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id = f"trace:/{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id)
        span.set_inputs({"question": "What is MLflow?"})
        span.set_outputs({"answer": "An open source ML platform"})
        span.set_attribute("custom", "custom_value")
        mlflow_span = span.to_immutable_span()

    assert mlflow_span.trace_id == trace_id
    # Sanity-check the fixture so the round-trip comparisons below are meaningful.
    assert mlflow_span.inputs == {"question": "What is MLflow?"}
    assert mlflow_span.outputs == {"answer": "An open source ML platform"}
    assert mlflow_span.get_attribute("custom") == "custom_value"

    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
            request_preview="request",
            response_preview="response",
            client_request_id="client_request_id",
            tags={"key": "value"},
        ),
        data=TraceData(spans=[mlflow_span]),
    )

    proto_trace_v4 = trace_to_proto(trace)

    # The proto carries the bare OTel trace ID; the UC location is stored separately.
    assert proto_trace_v4.trace_info.trace_id == otel_trace_id
    assert proto_trace_v4.trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_trace_v4.trace_info.trace_location.uc_schema.schema_name == "schema"
    assert len(proto_trace_v4.spans) == len(trace.data.spans)

    reconstructed_trace = trace_from_proto(proto_trace_v4, location_id=uc_schema)

    assert reconstructed_trace.info.trace_id == trace_id
    assert reconstructed_trace.info.trace_location.uc_schema.catalog_name == "catalog"
    assert reconstructed_trace.info.trace_location.uc_schema.schema_name == "schema"
    assert len(reconstructed_trace.data.spans) == len(trace.data.spans)

    original_span = trace.data.spans[0]
    reconstructed_span = reconstructed_trace.data.spans[0]

    assert reconstructed_span.name == original_span.name
    assert reconstructed_span.span_id == original_span.span_id
    assert reconstructed_span.trace_id == original_span.trace_id
    assert reconstructed_span.inputs == original_span.inputs
    assert reconstructed_span.outputs == original_span.outputs
    assert reconstructed_span.get_attribute("custom") == original_span.get_attribute("custom")
258  
259  
def test_trace_from_proto_with_location_preserves_v4_trace_id():
    """Spans rebuilt by ``trace_from_proto(..., location_id=...)`` keep the V4 trace ID."""
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id_v4 = f"{TRACE_ID_V4_PREFIX}{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id_v4)
        mlflow_span = span.to_immutable_span()

    # Create trace with v4 trace ID
    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id_v4,
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mlflow_span]),
    )

    # Convert to proto
    proto_trace = trace_to_proto(trace)

    # Reconstruct with location parameter
    reconstructed_trace = trace_from_proto(proto_trace, location_id=uc_schema)

    # Guard against a vacuous loop below: the trace was built with exactly one span.
    assert len(reconstructed_trace.data.spans) == 1

    # Verify that all spans have the correct v4 trace_id format
    for reconstructed_span in reconstructed_trace.data.spans:
        assert reconstructed_span.trace_id == trace_id_v4
        assert reconstructed_span.trace_id.startswith(TRACE_ID_V4_PREFIX)
        # The REQUEST_ID attribute (set above via the same key) must also be in V4 format.
        request_id = reconstructed_span.get_attribute(SpanAttributeKey.REQUEST_ID)
        assert request_id == trace_id_v4
294  
295  
def test_trace_info_from_proto_handles_uc_schema_location():
    """TraceInfo.from_proto rebuilds a UC-schema trace ID and location from a V4 proto."""
    request_time = Timestamp()
    request_time.FromMilliseconds(1234567890)
    proto = pb.TraceInfo(
        trace_id="test_trace_id",
        trace_location=trace_location_to_proto(
            TraceLocation.from_databricks_uc_schema(catalog_name="catalog", schema_name="schema")
        ),
        request_preview="test request",
        response_preview="test response",
        request_time=request_time,
        state=TraceState.OK.to_proto(),
        trace_metadata={
            TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION),
            "other_key": "other_value",
        },
        tags={"test_tag": "test_value"},
    )
    trace_info = TraceInfo.from_proto(proto)

    # The bare proto ID is re-qualified with the UC location, mirroring the
    # uc_table_prefix variant of this test.
    assert trace_info.trace_id == "trace:/catalog.schema/test_trace_id"
    assert trace_info.trace_location.type == TraceLocationType.UC_SCHEMA
    assert trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert trace_info.trace_location.uc_schema.schema_name == "schema"
    assert trace_info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == str(TRACE_SCHEMA_VERSION)
    assert trace_info.trace_metadata["other_key"] == "other_value"
    assert trace_info.tags == {"test_tag": "test_value"}
320  
321  
def test_add_size_stats_to_trace_metadata_for_v4_trace():
    """Size stats are added to the metadata of a trace identified by a V4 (UC) trace ID."""
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id = f"trace:/{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id)
        mlflow_span = span.to_immutable_span()

    trace = Trace(
        info=TraceInfo(
            # Use the span's V4 trace ID so TraceInfo and span agree on a V4 trace.
            trace_id=trace_id,
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
            request_preview="request",
            response_preview="response",
            client_request_id="client_request_id",
            tags={"key": "value"},
        ),
        data=TraceData(spans=[mlflow_span]),
    )
    add_size_stats_to_trace_metadata(trace)

    assert TraceMetadataKey.SIZE_STATS in trace.info.trace_metadata
    # The stats are expected to be stored as a JSON-encoded object.
    size_stats = json.loads(trace.info.trace_metadata[TraceMetadataKey.SIZE_STATS])
    assert isinstance(size_stats, dict)
347  
348  
def test_assessment_to_proto():
    """assessment_to_proto converts Feedback and Expectation objects into V4 protos.

    For both kinds, the ``trace:/<catalog>.<schema>/<id>`` trace ID is split into the bare
    ID and a UC-schema ``trace_location``.
    """
    # Test with Feedback assessment
    feedback = Feedback(
        name="correctness",
        value=0.95,
        source=AssessmentSource(source_type="LLM_JUDGE", source_id="gpt-4"),
        trace_id="trace:/catalog.schema/trace123",
        metadata={"model": "gpt-4", "temperature": "0.7"},
        span_id="span456",
        rationale="The response is accurate and complete",
        overrides="old_assessment_id",
        valid=False,
    )
    feedback.assessment_id = "assessment789"

    proto_v4 = assessment_to_proto(feedback)

    # Validate proto structure
    assert isinstance(proto_v4, pb.Assessment)
    assert proto_v4.assessment_name == "correctness"
    assert proto_v4.assessment_id == "assessment789"
    assert proto_v4.span_id == "span456"
    assert proto_v4.rationale == "The response is accurate and complete"
    assert proto_v4.overrides == "old_assessment_id"
    assert proto_v4.valid is False

    # Check TraceIdentifier: bare ID plus a UC_SCHEMA location (two-part location ID).
    assert proto_v4.trace_id == "trace123"
    assert proto_v4.trace_location.type == pb.TraceLocation.TraceLocationType.UC_SCHEMA
    assert proto_v4.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_v4.trace_location.uc_schema.schema_name == "schema"

    # Check source
    assert proto_v4.source.source_type == ProtoAssessmentSource.SourceType.Value("LLM_JUDGE")
    assert proto_v4.source.source_id == "gpt-4"

    # Check metadata
    assert proto_v4.metadata["model"] == "gpt-4"
    assert proto_v4.metadata["temperature"] == "0.7"

    # Check feedback value
    assert proto_v4.HasField("feedback")
    assert proto_v4.feedback.value.number_value == 0.95

    # Test with Expectation assessment
    expectation = Expectation(
        name="expected_answer",
        value={"answer": "Paris", "confidence": 0.99},
        source=AssessmentSource(source_type="HUMAN", source_id="user@example.com"),
        trace_id="trace:/main.default/trace789",
        metadata={"question": "What is the capital of France?"},
        span_id="span111",
    )
    expectation.assessment_id = "exp_assessment123"

    proto_v4_exp = assessment_to_proto(expectation)

    assert isinstance(proto_v4_exp, pb.Assessment)
    assert proto_v4_exp.assessment_name == "expected_answer"
    assert proto_v4_exp.assessment_id == "exp_assessment123"
    assert proto_v4_exp.span_id == "span111"

    # Check TraceIdentifier for expectation
    assert proto_v4_exp.trace_id == "trace789"
    assert proto_v4_exp.trace_location.type == pb.TraceLocation.TraceLocationType.UC_SCHEMA
    assert proto_v4_exp.trace_location.uc_schema.catalog_name == "main"
    assert proto_v4_exp.trace_location.uc_schema.schema_name == "default"

    # Check expectation value: dict values are stored JSON-serialized.
    assert proto_v4_exp.HasField("expectation")
    assert proto_v4_exp.expectation.HasField("serialized_value")
    assert json.loads(proto_v4_exp.expectation.serialized_value.value) == {
        "answer": "Paris",
        "confidence": 0.99,
    }
422  
423  
def test_get_trace_id_from_assessment_proto():
    """Location-bearing assessment protos yield a qualified trace ID; plain ones pass through."""
    uc_location_proto = trace_location_to_proto(
        TraceLocation.from_databricks_uc_schema(catalog_name="catalog", schema_name="schema")
    )
    located_assessment = pb.Assessment(trace_id="1234", trace_location=uc_location_proto)
    assert get_trace_id_from_assessment_proto(located_assessment) == "trace:/catalog.schema/1234"

    # An ``assessments_pb2.Assessment`` built with only a trace ID yields that ID unchanged.
    plain_assessment = assessments_pb2.Assessment(trace_id="tr-123")
    assert get_trace_id_from_assessment_proto(plain_assessment) == "tr-123"
437  
438  
def test_trace_location_uc_table_prefix_proto_round_trip():
    """A UC table-prefix location, including its table names, survives a proto round trip."""
    spans_table = "catalog.schema.prefix_otel_spans"
    logs_table = "catalog.schema.prefix_otel_logs"
    annotations_table = "catalog.schema.prefix_otel_annotations"

    uc_location = UnityCatalog(catalog_name="catalog", schema_name="schema", table_prefix="prefix")
    # Assign the private fields directly to simulate a location whose tables are known.
    uc_location._otel_spans_table_name = spans_table
    uc_location._otel_logs_table_name = logs_table
    uc_location._annotations_table_name = annotations_table

    proto = trace_location_to_proto(
        TraceLocation(type=TraceLocationType.UC_TABLE_PREFIX, uc_table_prefix=uc_location)
    )
    assert proto.type == pb.TraceLocation.TraceLocationType.UC_TABLE_PREFIX
    prefix_proto = proto.uc_table_prefix
    assert prefix_proto.catalog_name == "catalog"
    assert prefix_proto.schema_name == "schema"
    assert prefix_proto.table_prefix == "prefix"
    assert prefix_proto.spans_table_name == spans_table
    assert prefix_proto.logs_table_name == logs_table
    assert prefix_proto.annotations_table_name == annotations_table

    round_tripped = trace_location_from_proto(proto)
    assert round_tripped.type == TraceLocationType.UC_TABLE_PREFIX
    restored = round_tripped.uc_table_prefix
    assert restored.catalog_name == "catalog"
    assert restored.schema_name == "schema"
    assert restored.table_prefix == "prefix"
    assert restored.full_otel_spans_table_name == spans_table
    assert restored.full_otel_logs_table_name == logs_table
    assert restored.full_annotations_table_name == annotations_table
468  
469  
def test_trace_info_from_proto_handles_uc_table_prefix_location():
    """TraceInfo.from_proto rebuilds a table-prefix trace ID and location from a V4 proto."""
    timestamp = Timestamp()
    timestamp.FromMilliseconds(1234567890)
    location_proto = trace_location_to_proto(
        TraceLocation.from_databricks_uc_table_prefix(
            catalog_name="catalog", schema_name="schema", table_prefix="prefix"
        )
    )
    info = TraceInfo.from_proto(
        pb.TraceInfo(
            trace_id="test_trace_id",
            trace_location=location_proto,
            request_preview="test request",
            response_preview="test response",
            request_time=timestamp,
            state=TraceState.OK.to_proto(),
            trace_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)},
        )
    )

    # The three-part location is folded back into the trace ID.
    assert info.trace_id == "trace:/catalog.schema.prefix/test_trace_id"
    location = info.trace_location
    assert location.type == TraceLocationType.UC_TABLE_PREFIX
    assert location.uc_table_prefix.catalog_name == "catalog"
    assert location.uc_table_prefix.schema_name == "schema"
    assert location.uc_table_prefix.table_prefix == "prefix"
492  
493  
def test_assessment_to_proto_uc_table_prefix():
    """A table-prefix trace ID is split into the bare ID and a UC_TABLE_PREFIX location."""
    assessment = Feedback(
        name="correctness",
        value=0.95,
        source=AssessmentSource(source_type="LLM_JUDGE", source_id="gpt-4"),
        trace_id="trace:/catalog.schema.prefix/trace123",
    )
    converted = assessment_to_proto(assessment)
    location = converted.trace_location

    assert converted.trace_id == "trace123"
    assert location.type == pb.TraceLocation.TraceLocationType.UC_TABLE_PREFIX
    table_prefix = location.uc_table_prefix
    assert table_prefix.catalog_name == "catalog"
    assert table_prefix.schema_name == "schema"
    assert table_prefix.table_prefix == "prefix"
507  
508  
def test_get_trace_id_from_assessment_proto_uc_table_prefix():
    """An assessment under a UC table prefix yields a prefix-qualified trace ID."""
    prefix_location = TraceLocation.from_databricks_uc_table_prefix(
        catalog_name="catalog", schema_name="schema", table_prefix="prefix"
    )
    assessment_proto = pb.Assessment(
        trace_id="1234",
        trace_location=trace_location_to_proto(prefix_location),
    )

    resolved_trace_id = get_trace_id_from_assessment_proto(assessment_proto)
    assert resolved_trace_id == "trace:/catalog.schema.prefix/1234"