# test_databricks_tracing_utils.py
"""Tests for the proto conversion helpers in ``mlflow.utils.databricks_tracing_utils``."""

import json

import pytest
from google.protobuf.timestamp_pb2 import Timestamp

import mlflow
from mlflow.entities import (
    AssessmentSource,
    Expectation,
    Feedback,
    Trace,
    TraceData,
    TraceInfo,
    TraceState,
)
from mlflow.entities.trace_location import (
    InferenceTableLocation,
    MlflowExperimentLocation,
    TraceLocation,
    TraceLocationType,
    UCSchemaLocation,
    UnityCatalog,
)
from mlflow.protos import assessments_pb2
from mlflow.protos import databricks_tracing_pb2 as pb
from mlflow.protos.assessments_pb2 import AssessmentSource as ProtoAssessmentSource
from mlflow.tracing.constant import (
    TRACE_ID_V4_PREFIX,
    TRACE_SCHEMA_VERSION,
    TRACE_SCHEMA_VERSION_KEY,
    SpanAttributeKey,
)
from mlflow.tracing.utils import TraceMetadataKey, add_size_stats_to_trace_metadata
from mlflow.utils.databricks_tracing_utils import (
    assessment_to_proto,
    get_trace_id_from_assessment_proto,
    inference_table_location_to_proto,
    mlflow_experiment_location_to_proto,
    parse_uc_location,
    trace_from_proto,
    trace_location_from_proto,
    trace_location_to_proto,
    trace_to_proto,
    uc_location_to_str,
    uc_schema_location_from_proto,
    uc_schema_location_to_proto,
)


def test_trace_location_to_proto_uc_schema():
    """A UC-schema trace location converts to a ``UC_SCHEMA`` proto with its names."""
    trace_location = TraceLocation.from_databricks_uc_schema(
        catalog_name="test_catalog", schema_name="test_schema"
    )
    proto = trace_location_to_proto(trace_location)
    assert proto.type == pb.TraceLocation.TraceLocationType.UC_SCHEMA
    assert proto.uc_schema.catalog_name == "test_catalog"
    assert proto.uc_schema.schema_name == "test_schema"


def test_parse_uc_location():
    """Two- and three-part locations parse to a 3-tuple; four parts raise ``ValueError``."""
    assert parse_uc_location("catalog.schema") == ("catalog", "schema", None)
    assert parse_uc_location("catalog.schema.prefix") == ("catalog", "schema", "prefix")

    with pytest.raises(ValueError, match="Invalid UC location"):
        parse_uc_location("a.b.c.d")


def test_uc_location_to_str():
    """Location parts are joined with dots, with or without a table prefix."""
    assert uc_location_to_str("catalog", "schema") == "catalog.schema"
    assert uc_location_to_str("catalog", "schema", "prefix") == "catalog.schema.prefix"


def test_trace_location_to_proto_mlflow_experiment():
    """An experiment trace location converts to an ``MLFLOW_EXPERIMENT`` proto."""
    trace_location = TraceLocation.from_experiment_id(experiment_id="1234")
    proto = trace_location_to_proto(trace_location)
    assert proto.type == pb.TraceLocation.TraceLocationType.MLFLOW_EXPERIMENT
    assert proto.mlflow_experiment.experiment_id == "1234"


def test_trace_location_to_proto_inference_table():
    """An inference-table trace location converts to an ``INFERENCE_TABLE`` proto."""
    trace_location = TraceLocation(
        type=TraceLocationType.INFERENCE_TABLE,
        inference_table=InferenceTableLocation(
            full_table_name="test_catalog.test_schema.test_table"
        ),
    )
    proto = trace_location_to_proto(trace_location)
    assert proto.type == pb.TraceLocation.TraceLocationType.INFERENCE_TABLE
    assert proto.inference_table.full_table_name == "test_catalog.test_schema.test_table"


def test_uc_schema_location_to_proto():
    """Catalog and schema names are copied onto the ``UCSchemaLocation`` proto."""
    schema_location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
    proto = uc_schema_location_to_proto(schema_location)
    assert proto.catalog_name == "test_catalog"
    assert proto.schema_name == "test_schema"


def test_uc_schema_location_from_proto():
    """A ``UCSchemaLocation`` proto yields an entity exposing fully qualified table names."""
    proto = pb.UCSchemaLocation(
        catalog_name="test_catalog",
        schema_name="test_schema",
        otel_spans_table_name="test_spans",
        otel_logs_table_name="test_logs",
    )
    schema_location = uc_schema_location_from_proto(proto)
    assert schema_location.catalog_name == "test_catalog"
    assert schema_location.schema_name == "test_schema"
    assert schema_location.full_otel_spans_table_name == "test_catalog.test_schema.test_spans"
    assert schema_location.full_otel_logs_table_name == "test_catalog.test_schema.test_logs"


def test_inference_table_location_to_proto():
    """The full table name is copied onto the inference-table proto."""
    table_location = InferenceTableLocation(full_table_name="test_catalog.test_schema.test_table")
    proto = inference_table_location_to_proto(table_location)
    assert proto.full_table_name == "test_catalog.test_schema.test_table"


def test_mlflow_experiment_location_to_proto():
    """The experiment ID is copied onto the experiment-location proto."""
    experiment_location = MlflowExperimentLocation(experiment_id="1234")
    proto = mlflow_experiment_location_to_proto(experiment_location)
    assert proto.experiment_id == "1234"


def test_schema_location_to_proto():
    """OTel table names set on the entity are carried over to the proto."""
    schema_location = UCSchemaLocation(
        catalog_name="test_catalog",
        schema_name="test_schema",
    )
    # The table names are assigned through the private attributes directly.
    schema_location._otel_spans_table_name = "test_spans"
    schema_location._otel_logs_table_name = "test_logs"
    proto = uc_schema_location_to_proto(schema_location)
    assert proto.catalog_name == "test_catalog"
    assert proto.schema_name == "test_schema"
    assert proto.otel_spans_table_name == "test_spans"
    assert proto.otel_logs_table_name == "test_logs"


def test_trace_location_from_proto_uc_schema():
    """A ``UC_SCHEMA`` proto converts back to a trace location with its UC schema."""
    proto = pb.TraceLocation(
        type=pb.TraceLocation.TraceLocationType.UC_SCHEMA,
        uc_schema=pb.UCSchemaLocation(
            catalog_name="catalog",
            schema_name="schema",
            otel_spans_table_name="spans",
            otel_logs_table_name="logs",
        ),
    )
    trace_location = trace_location_from_proto(proto)
    assert trace_location.uc_schema.catalog_name == "catalog"
    assert trace_location.uc_schema.schema_name == "schema"
    assert trace_location.uc_schema.full_otel_spans_table_name == "catalog.schema.spans"
    assert trace_location.uc_schema.full_otel_logs_table_name == "catalog.schema.logs"


def test_trace_location_from_proto_mlflow_experiment():
    """An ``MLFLOW_EXPERIMENT`` proto converts back to an experiment trace location."""
    proto = pb.TraceLocation(
        type=pb.TraceLocation.TraceLocationType.MLFLOW_EXPERIMENT,
        mlflow_experiment=mlflow_experiment_location_to_proto(
            MlflowExperimentLocation(experiment_id="1234")
        ),
    )
    trace_location = trace_location_from_proto(proto)
    assert trace_location.type == TraceLocationType.MLFLOW_EXPERIMENT
    assert trace_location.mlflow_experiment.experiment_id == "1234"


def test_trace_location_from_proto_inference_table():
    """An ``INFERENCE_TABLE`` proto converts back to an inference-table trace location."""
    proto = pb.TraceLocation(
        type=pb.TraceLocation.TraceLocationType.INFERENCE_TABLE,
        inference_table=inference_table_location_to_proto(
            InferenceTableLocation(full_table_name="test_catalog.test_schema.test_table")
        ),
    )
    trace_location = trace_location_from_proto(proto)
    assert trace_location.type == TraceLocationType.INFERENCE_TABLE
    assert trace_location.inference_table.full_table_name == "test_catalog.test_schema.test_table"


def test_trace_info_to_v4_proto():
    """A UC-schema ``TraceInfo`` converts to a proto and back to an equal ``TraceInfo``."""
    otel_trace_id = "2efb31387ff19263f92b2c0a61b0a8bc"
    trace_id = f"trace:/catalog.schema/{otel_trace_id}"
    trace_info = TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_databricks_uc_schema(
            catalog_name="catalog", schema_name="schema"
        ),
        request_time=0,
        state=TraceState.OK,
        request_preview="request",
        response_preview="response",
        client_request_id="client_request_id",
        tags={"key": "value"},
    )
    proto_trace_info = trace_info.to_proto()
    # The proto carries only the trailing ID segment, not the full "trace:/..." string.
    assert proto_trace_info.trace_id == otel_trace_id
    assert proto_trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_trace_info.trace_location.uc_schema.schema_name == "schema"
    assert proto_trace_info.state == 1
    assert proto_trace_info.request_preview == "request"
    assert proto_trace_info.response_preview == "response"
    assert proto_trace_info.client_request_id == "client_request_id"
    assert proto_trace_info.tags == {"key": "value"}
    assert len(proto_trace_info.assessments) == 0

    trace_info_from_proto = TraceInfo.from_proto(proto_trace_info)
    assert trace_info_from_proto == trace_info


def test_trace_to_proto_and_from_proto():
    """A trace with one span round-trips through ``trace_to_proto``/``trace_from_proto``."""
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id = f"trace:/{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id)
        # Populate the fields compared after the round trip; without these the
        # inputs/outputs/"custom" assertions below would only compare None to None.
        span.set_inputs({"question": "What is MLflow?"})
        span.set_outputs({"answer": "An open source platform"})
        span.set_attribute("custom", "value")
        mlflow_span = span.to_immutable_span()

    assert mlflow_span.trace_id == trace_id
    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
            request_preview="request",
            response_preview="response",
            client_request_id="client_request_id",
            tags={"key": "value"},
        ),
        data=TraceData(spans=[mlflow_span]),
    )

    proto_trace_v4 = trace_to_proto(trace)

    assert proto_trace_v4.trace_info.trace_id == otel_trace_id
    assert proto_trace_v4.trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_trace_v4.trace_info.trace_location.uc_schema.schema_name == "schema"
    assert len(proto_trace_v4.spans) == len(trace.data.spans)

    reconstructed_trace = trace_from_proto(proto_trace_v4, location_id="catalog.schema")

    assert reconstructed_trace.info.trace_id == trace_id
    assert reconstructed_trace.info.trace_location.uc_schema.catalog_name == "catalog"
    assert reconstructed_trace.info.trace_location.uc_schema.schema_name == "schema"
    assert len(reconstructed_trace.data.spans) == len(trace.data.spans)

    original_span = trace.data.spans[0]
    reconstructed_span = reconstructed_trace.data.spans[0]

    # Pin the original values so the equality checks below are meaningful.
    assert original_span.inputs == {"question": "What is MLflow?"}
    assert original_span.outputs == {"answer": "An open source platform"}
    assert original_span.get_attribute("custom") == "value"

    assert reconstructed_span.name == original_span.name
    assert reconstructed_span.span_id == original_span.span_id
    assert reconstructed_span.trace_id == original_span.trace_id
    assert reconstructed_span.inputs == original_span.inputs
    assert reconstructed_span.outputs == original_span.outputs
    assert reconstructed_span.get_attribute("custom") == original_span.get_attribute("custom")


def test_trace_from_proto_with_location_preserves_v4_trace_id():
    """Spans rebuilt with a ``location_id`` keep the v4 trace ID."""
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id_v4 = f"{TRACE_ID_V4_PREFIX}{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id_v4)
        mlflow_span = span.to_immutable_span()

    # Create trace with v4 trace ID
    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id_v4,
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mlflow_span]),
    )

    # Convert to proto
    proto_trace = trace_to_proto(trace)

    # Reconstruct with location parameter
    reconstructed_trace = trace_from_proto(proto_trace, location_id=uc_schema)

    # Guard against a vacuous pass: the loop below asserts nothing if no spans come back.
    assert len(reconstructed_trace.data.spans) == len(trace.data.spans)

    # Verify that all spans have the correct v4 trace_id format
    for reconstructed_span in reconstructed_trace.data.spans:
        assert reconstructed_span.trace_id == trace_id_v4
        assert reconstructed_span.trace_id.startswith(TRACE_ID_V4_PREFIX)
        # Verify the REQUEST_ID attribute is also in v4 format
        request_id = reconstructed_span.get_attribute("mlflow.traceRequestId")
        assert request_id == trace_id_v4


def test_trace_info_from_proto_handles_uc_schema_location():
    """``TraceInfo.from_proto`` keeps the UC schema location, metadata and tags."""
    request_time = Timestamp()
    request_time.FromMilliseconds(1234567890)
    proto = pb.TraceInfo(
        trace_id="test_trace_id",
        trace_location=trace_location_to_proto(
            TraceLocation.from_databricks_uc_schema(catalog_name="catalog", schema_name="schema")
        ),
        request_preview="test request",
        response_preview="test response",
        request_time=request_time,
        state=TraceState.OK.to_proto(),
        trace_metadata={
            TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION),
            "other_key": "other_value",
        },
        tags={"test_tag": "test_value"},
    )
    trace_info = TraceInfo.from_proto(proto)
    assert trace_info.trace_location.uc_schema.catalog_name == "catalog"
    assert trace_info.trace_location.uc_schema.schema_name == "schema"
    assert trace_info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == str(TRACE_SCHEMA_VERSION)
    assert trace_info.trace_metadata["other_key"] == "other_value"
    assert trace_info.tags == {"test_tag": "test_value"}


def test_add_size_stats_to_trace_metadata_for_v4_trace():
    """``add_size_stats_to_trace_metadata`` adds the size-stats key for a UC-schema trace."""
    with mlflow.start_span() as span:
        otel_trace_id = span.trace_id.removeprefix("tr-")
        uc_schema = "catalog.schema"
        trace_id = f"trace:/{uc_schema}/{otel_trace_id}"
        span.set_attribute(SpanAttributeKey.REQUEST_ID, trace_id)
        mlflow_span = span.to_immutable_span()

    trace = Trace(
        info=TraceInfo(
            trace_id="test_trace_id",
            trace_location=TraceLocation.from_databricks_uc_schema(
                catalog_name="catalog", schema_name="schema"
            ),
            request_time=0,
            state=TraceState.OK,
            request_preview="request",
            response_preview="response",
            client_request_id="client_request_id",
            tags={"key": "value"},
        ),
        data=TraceData(spans=[mlflow_span]),
    )
    add_size_stats_to_trace_metadata(trace)
    assert TraceMetadataKey.SIZE_STATS in trace.info.trace_metadata


def test_assessment_to_proto():
    """Feedback and Expectation assessments on UC-schema traces convert to ``pb.Assessment``."""
    # Test with Feedback assessment
    feedback = Feedback(
        name="correctness",
        value=0.95,
        source=AssessmentSource(source_type="LLM_JUDGE", source_id="gpt-4"),
        trace_id="trace:/catalog.schema/trace123",
        metadata={"model": "gpt-4", "temperature": "0.7"},
        span_id="span456",
        rationale="The response is accurate and complete",
        overrides="old_assessment_id",
        valid=False,
    )
    feedback.assessment_id = "assessment789"

    proto_v4 = assessment_to_proto(feedback)

    # Validate proto structure
    assert isinstance(proto_v4, pb.Assessment)
    assert proto_v4.assessment_name == "correctness"
    assert proto_v4.assessment_id == "assessment789"
    assert proto_v4.span_id == "span456"
    assert proto_v4.rationale == "The response is accurate and complete"
    assert proto_v4.overrides == "old_assessment_id"
    assert proto_v4.valid is False

    # Check TraceIdentifier
    assert proto_v4.trace_id == "trace123"
    assert proto_v4.trace_location.uc_schema.catalog_name == "catalog"
    assert proto_v4.trace_location.uc_schema.schema_name == "schema"

    # Check source
    assert proto_v4.source.source_type == ProtoAssessmentSource.SourceType.Value("LLM_JUDGE")
    assert proto_v4.source.source_id == "gpt-4"

    # Check metadata
    assert proto_v4.metadata["model"] == "gpt-4"
    assert proto_v4.metadata["temperature"] == "0.7"

    # Check feedback value
    assert proto_v4.HasField("feedback")
    assert proto_v4.feedback.value.number_value == 0.95

    # Test with Expectation assessment
    expectation = Expectation(
        name="expected_answer",
        value={"answer": "Paris", "confidence": 0.99},
        source=AssessmentSource(source_type="HUMAN", source_id="user@example.com"),
        trace_id="trace:/main.default/trace789",
        metadata={"question": "What is the capital of France?"},
        span_id="span111",
    )
    expectation.assessment_id = "exp_assessment123"

    proto_v4_exp = assessment_to_proto(expectation)

    assert isinstance(proto_v4_exp, pb.Assessment)
    assert proto_v4_exp.assessment_name == "expected_answer"
    assert proto_v4_exp.assessment_id == "exp_assessment123"
    assert proto_v4_exp.span_id == "span111"

    # Check TraceIdentifier for expectation
    assert proto_v4_exp.trace_id == "trace789"
    assert proto_v4_exp.trace_location.uc_schema.catalog_name == "main"
    assert proto_v4_exp.trace_location.uc_schema.schema_name == "default"

    # Check expectation value
    assert proto_v4_exp.HasField("expectation")
    assert proto_v4_exp.expectation.HasField("serialized_value")
    assert json.loads(proto_v4_exp.expectation.serialized_value.value) == {
        "answer": "Paris",
        "confidence": 0.99,
    }


def test_get_trace_id_from_assessment_proto():
    """The trace ID is rebuilt with the UC location when present, else returned as-is."""
    proto = pb.Assessment(
        trace_id="1234",
        trace_location=trace_location_to_proto(
            TraceLocation.from_databricks_uc_schema(catalog_name="catalog", schema_name="schema")
        ),
    )
    assert get_trace_id_from_assessment_proto(proto) == "trace:/catalog.schema/1234"

    # An assessment proto from assessments_pb2 with no trace location keeps its raw ID.
    proto = assessments_pb2.Assessment(
        trace_id="tr-123",
    )
    assert get_trace_id_from_assessment_proto(proto) == "tr-123"


def test_trace_location_uc_table_prefix_proto_round_trip():
    """A ``UC_TABLE_PREFIX`` location keeps its names and table names through a round trip."""
    location = UnityCatalog(
        catalog_name="catalog",
        schema_name="schema",
        table_prefix="prefix",
    )
    # The table names are assigned through the private attributes directly.
    location._otel_spans_table_name = "catalog.schema.prefix_otel_spans"
    location._otel_logs_table_name = "catalog.schema.prefix_otel_logs"
    location._annotations_table_name = "catalog.schema.prefix_otel_annotations"

    trace_location = TraceLocation(type=TraceLocationType.UC_TABLE_PREFIX, uc_table_prefix=location)
    proto = trace_location_to_proto(trace_location)
    assert proto.type == pb.TraceLocation.TraceLocationType.UC_TABLE_PREFIX
    assert proto.uc_table_prefix.catalog_name == "catalog"
    assert proto.uc_table_prefix.schema_name == "schema"
    assert proto.uc_table_prefix.table_prefix == "prefix"
    assert proto.uc_table_prefix.spans_table_name == "catalog.schema.prefix_otel_spans"
    assert proto.uc_table_prefix.logs_table_name == "catalog.schema.prefix_otel_logs"
    assert proto.uc_table_prefix.annotations_table_name == "catalog.schema.prefix_otel_annotations"

    reconstructed = trace_location_from_proto(proto)
    assert reconstructed.type == TraceLocationType.UC_TABLE_PREFIX
    uc = reconstructed.uc_table_prefix
    assert uc.catalog_name == "catalog"
    assert uc.schema_name == "schema"
    assert uc.table_prefix == "prefix"
    assert uc.full_otel_spans_table_name == "catalog.schema.prefix_otel_spans"
    assert uc.full_otel_logs_table_name == "catalog.schema.prefix_otel_logs"
    assert uc.full_annotations_table_name == "catalog.schema.prefix_otel_annotations"


def test_trace_info_from_proto_handles_uc_table_prefix_location():
    """``TraceInfo.from_proto`` builds a v4 trace ID from a ``UC_TABLE_PREFIX`` location."""
    request_time = Timestamp()
    request_time.FromMilliseconds(1234567890)
    proto = pb.TraceInfo(
        trace_id="test_trace_id",
        trace_location=trace_location_to_proto(
            TraceLocation.from_databricks_uc_table_prefix(
                catalog_name="catalog", schema_name="schema", table_prefix="prefix"
            )
        ),
        request_preview="test request",
        response_preview="test response",
        request_time=request_time,
        state=TraceState.OK.to_proto(),
        trace_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)},
    )
    trace_info = TraceInfo.from_proto(proto)
    assert trace_info.trace_id == "trace:/catalog.schema.prefix/test_trace_id"
    assert trace_info.trace_location.type == TraceLocationType.UC_TABLE_PREFIX
    assert trace_info.trace_location.uc_table_prefix.catalog_name == "catalog"
    assert trace_info.trace_location.uc_table_prefix.schema_name == "schema"
    assert trace_info.trace_location.uc_table_prefix.table_prefix == "prefix"


def test_assessment_to_proto_uc_table_prefix():
    """A three-part location in the trace ID yields a ``UC_TABLE_PREFIX`` proto location."""
    feedback = Feedback(
        name="correctness",
        value=0.95,
        source=AssessmentSource(source_type="LLM_JUDGE", source_id="gpt-4"),
        trace_id="trace:/catalog.schema.prefix/trace123",
    )
    proto = assessment_to_proto(feedback)
    assert proto.trace_id == "trace123"
    assert proto.trace_location.type == pb.TraceLocation.TraceLocationType.UC_TABLE_PREFIX
    assert proto.trace_location.uc_table_prefix.catalog_name == "catalog"
    assert proto.trace_location.uc_table_prefix.schema_name == "schema"
    assert proto.trace_location.uc_table_prefix.table_prefix == "prefix"


def test_get_trace_id_from_assessment_proto_uc_table_prefix():
    """The rebuilt trace ID includes the table prefix for a ``UC_TABLE_PREFIX`` location."""
    proto = pb.Assessment(
        trace_id="1234",
        trace_location=trace_location_to_proto(
            TraceLocation.from_databricks_uc_table_prefix(
                catalog_name="catalog", schema_name="schema", table_prefix="prefix"
            )
        ),
    )
    assert get_trace_id_from_assessment_proto(proto) == "trace:/catalog.schema.prefix/1234"