"""Integration tests for loading traces that were ingested via the OTLP endpoint.

Spans are exported with the vanilla OpenTelemetry SDK directly against a live
MLflow server, then read back through the MLflow fluent APIs (``get_trace``,
``search_traces``, tags, and assessments) to verify the OTel -> MLflow
translation round-trip.
"""

import uuid
from pathlib import Path

import pytest
from opentelemetry import trace as otel_trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource as OTelSDKResource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.trace import Status, StatusCode
from opentelemetry.util._once import Once

import mlflow
from mlflow.entities import SpanStatusCode
from mlflow.entities.assessment import AssessmentSource, Expectation, Feedback
from mlflow.entities.assessment_source import AssessmentSourceType
from mlflow.server import handlers
from mlflow.server.fastapi_app import app
from mlflow.server.handlers import initialize_backend_stores
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing.otel.translation.base import OtelSchemaTranslator
from mlflow.tracing.otel.translation.genai_semconv import GenAiTranslator
from mlflow.tracing.otel.translation.open_inference import OpenInferenceTranslator
from mlflow.tracing.otel.translation.traceloop import TraceloopTranslator
from mlflow.tracing.provider import _get_trace_exporter
from mlflow.tracing.utils import encode_trace_id
from mlflow.tracing.utils.otlp import MLFLOW_EXPERIMENT_ID_HEADER
from mlflow.tracking._tracking_service.utils import _use_tracking_uri
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.helper_functions import get_safe_port
from tests.tracking.integration_test_utils import ServerThread

if IS_TRACING_SDK_ONLY:
    pytest.skip("OTel get_trace tests require full MLflow server", allow_module_level=True)


@pytest.fixture
def mlflow_server(tmp_path: Path, db_uri: str):
    """Start an in-process MLflow server backed by a fresh store and yield its URL."""
    artifact_uri = tmp_path.joinpath("artifacts").as_uri()

    # Force-reset backend stores before each test so state from a previous
    # test's server does not leak into this one.
    handlers._tracking_store = None
    handlers._model_registry_store = None
    initialize_backend_stores(db_uri, default_artifact_root=artifact_uri)

    with ServerThread(app, get_safe_port()) as url:
        yield url


@pytest.fixture(autouse=True)
def tracking_uri_setup(mlflow_server):
    """Point the fluent MLflow APIs at the per-test server for the test's duration."""
    with _use_tracking_uri(mlflow_server):
        yield


@pytest.fixture(params=[True, False])
def is_async(request, monkeypatch):
    """Run each test with async trace logging both enabled and disabled.

    Returns the parametrized flag so tests can gate the explicit flush on it.
    """
    monkeypatch.setenv("MLFLOW_ASYNC_TRACE_LOGGING", "true" if request.param else "false")
    # BUG FIX: the fixture previously returned None, so `if is_async:` in the
    # tests was always falsy and the async flush path was never exercised.
    return request.param


def _flush_async_logging():
    """Synchronously drain the async trace-export queue so traces are queryable."""
    exporter = _get_trace_exporter()
    assert hasattr(exporter, "_async_queue"), "Async queue is not initialized"
    exporter._async_queue.flush(terminate=True)


def create_tracer(mlflow_server: str, experiment_id: str, service_name: str = "test-service"):
    """Build an OTel tracer that exports spans to the MLflow server's OTLP endpoint.

    Args:
        mlflow_server: Base URL of the running MLflow server.
        experiment_id: Experiment to route the ingested spans to (sent as a header).
        service_name: Value for the ``service.name`` resource attribute.

    Returns:
        An OpenTelemetry tracer bound to a freshly installed global provider.
    """
    resource = OTelSDKResource.create({"service.name": service_name, "service.version": "1.0.0"})
    tracer_provider = TracerProvider(resource=resource)

    exporter = OTLPSpanExporter(
        endpoint=f"{mlflow_server}/v1/traces",
        headers={MLFLOW_EXPERIMENT_ID_HEADER: experiment_id},
        timeout=10,
    )

    # SimpleSpanProcessor exports synchronously on span end, so tests do not
    # need to wait for a batch interval.
    span_processor = SimpleSpanProcessor(exporter)
    tracer_provider.add_span_processor(span_processor)

    # Reset the global tracer provider; OTel normally forbids overriding it,
    # so clear the internal Once/provider state first (test-only hack).
    otel_trace._TRACER_PROVIDER_SET_ONCE = Once()
    otel_trace._TRACER_PROVIDER = None
    otel_trace.set_tracer_provider(tracer_provider)

    return otel_trace.get_tracer(__name__)


def test_get_trace_for_otel_sent_span(mlflow_server: str, is_async):
    """A plain OTel span round-trips through ingestion with attributes intact."""
    experiment = mlflow.set_experiment("otel-get-trace-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "test-service-get-trace")

    # Create a span with various attribute types to test conversion
    with tracer.start_as_current_span("otel-test-span") as span:
        span.set_attribute("test.string", "string-value")
        span.set_attribute("test.number", 42)
        span.set_attribute("test.boolean", True)
        span.set_attribute("operation.type", "llm_request")

        # Capture the OTel trace ID for the format check below
        otel_trace_id = span.get_span_context().trace_id
        assert span.get_span_context().is_valid
        assert otel_trace_id != 0

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    assert len(traces) > 0, "No traces found in the database"

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert retrieved_trace.info.trace_id == trace_id
    assert retrieved_trace.info.trace_location.mlflow_experiment.experiment_id == experiment_id

    assert len(retrieved_trace.data.spans) == 1
    span = retrieved_trace.data.spans[0]

    assert span.name == "otel-test-span"
    assert span.trace_id == trace_id
    # OTel spans default to UNSET status if not explicitly set
    assert span.status.status_code == SpanStatusCode.UNSET

    # Verify attributes were converted correctly
    assert span.attributes["test.string"] == "string-value"
    assert span.attributes["test.number"] == 42
    assert span.attributes["test.boolean"] is True
    assert span.attributes["operation.type"] == "llm_request"

    # Verify the trace ID matches the expected format
    expected_trace_id = f"tr-{encode_trace_id(otel_trace_id)}"
    assert trace_id == expected_trace_id


def test_get_trace_for_otel_nested_spans(mlflow_server: str, is_async):
    """Parent/child relationships are preserved for nested OTel spans."""
    experiment = mlflow.set_experiment("otel-nested-spans-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "nested-test-service")

    # Create nested spans
    with tracer.start_as_current_span("parent-span") as parent_span:
        parent_span.set_attribute("span.level", "parent")

        with tracer.start_as_current_span("child-span") as child_span:
            child_span.set_attribute("span.level", "child")
            child_span.set_attribute("child.operation", "process_data")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    assert len(traces) > 0, "No traces found in the database"

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert len(retrieved_trace.data.spans) == 2

    spans_by_name = {span.name: span for span in retrieved_trace.data.spans}

    assert "parent-span" in spans_by_name
    assert "child-span" in spans_by_name

    parent_span = spans_by_name["parent-span"]
    child_span = spans_by_name["child-span"]

    assert parent_span.attributes["span.level"] == "parent"
    assert parent_span.parent_id is None  # Root span has no parent

    assert child_span.attributes["span.level"] == "child"
    assert child_span.attributes["child.operation"] == "process_data"
    assert child_span.parent_id == parent_span.span_id  # Child should reference parent


def test_get_trace_with_otel_span_events(mlflow_server: str, is_async):
    """Span events added via the OTel SDK are retrievable from the stored trace."""
    experiment = mlflow.set_experiment("otel-events-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "events-test-service")

    # Create span with events using OTel SDK
    with tracer.start_as_current_span("span-with-events") as span:
        span.add_event("test_event", attributes={"event.type": "processing"})

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert len(retrieved_trace.data.spans) == 1
    retrieved_span = retrieved_trace.data.spans[0]

    assert retrieved_span.name == "span-with-events"
    assert len(retrieved_span.events) == 1
    event = retrieved_span.events[0]
    assert event.name == "test_event"
    assert event.attributes["event.type"] == "processing"


def test_get_trace_nonexistent_otel_trace(mlflow_server: str):
    """``get_trace`` returns None for a well-formed but unknown trace ID."""
    # Create a fake trace ID in OTel format
    fake_otel_trace_id = uuid.uuid4().hex
    fake_trace_id = f"tr-{fake_otel_trace_id}"

    # MLflow get_trace returns None for non-existent traces
    trace = mlflow.get_trace(fake_trace_id)
    assert trace is None


def test_get_trace_with_otel_span_status(mlflow_server: str, is_async):
    """An OTel ERROR status and its description survive the round-trip."""
    experiment = mlflow.set_experiment("otel-status-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "status-test-service")

    # Create span with error status using OTel SDK
    with tracer.start_as_current_span("error-span") as span:
        span.set_status(Status(StatusCode.ERROR, "Something went wrong"))

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert len(retrieved_trace.data.spans) == 1
    retrieved_span = retrieved_trace.data.spans[0]

    assert retrieved_span.name == "error-span"
    assert retrieved_span.status.status_code == SpanStatusCode.ERROR
    assert "Something went wrong" in retrieved_span.status.description


def test_set_trace_tag_on_otel_trace(mlflow_server: str, is_async):
    """Tags can be set on, and read back from, an OTel-ingested trace."""
    experiment = mlflow.set_experiment("otel-tag-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "tag-test-service")

    with tracer.start_as_current_span("tagged-span") as span:
        span.set_attribute("test.attribute", "value")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    trace_id = traces[0].info.trace_id

    mlflow.set_trace_tag(trace_id, "environment", "test")
    mlflow.set_trace_tag(trace_id, "version", "1.0.0")

    retrieved_trace = mlflow.get_trace(trace_id)
    assert retrieved_trace.info.tags["environment"] == "test"
    assert retrieved_trace.info.tags["version"] == "1.0.0"


def test_log_expectation_on_otel_trace(mlflow_server: str, is_async):
    """A human expectation assessment can be logged against an OTel trace."""
    experiment = mlflow.set_experiment("otel-expectation-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "expectation-test-service")

    # Create a span that represents a question-answer scenario
    with tracer.start_as_current_span("qa-span") as span:
        span.set_attribute("question", "What is MLflow?")
        span.set_attribute("answer", "MLflow is an open-source ML platform")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    trace_id = traces[0].info.trace_id

    expectation_source = AssessmentSource(
        source_type=AssessmentSourceType.HUMAN, source_id="test_user@example.com"
    )

    logged_assessment = mlflow.log_expectation(
        trace_id=trace_id,
        name="expected_answer",
        value="MLflow is an open-source machine learning platform",
        source=expectation_source,
        metadata={"confidence": "high", "reviewed_by": "expert"},
    )
    expectation = mlflow.get_assessment(
        trace_id=trace_id, assessment_id=logged_assessment.assessment_id
    )
    assert expectation.name == "expected_answer"
    assert expectation.value == "MLflow is an open-source machine learning platform"
    assert expectation.source.source_type == AssessmentSourceType.HUMAN
    assert expectation.metadata["confidence"] == "high"


def test_log_feedback_on_otel_trace(mlflow_server: str, is_async):
    """LLM-judge and human feedback assessments can be logged and retrieved."""
    experiment = mlflow.set_experiment("otel-feedback-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "feedback-test-service")

    # Create a span representing a model prediction
    with tracer.start_as_current_span("prediction-span") as span:
        span.set_attribute("model", "gpt-4")
        span.set_attribute("prediction", "The weather is sunny")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) > 0, "No traces found in the database"
    trace_id = traces[0].info.trace_id

    llm_source = AssessmentSource(
        source_type=AssessmentSourceType.LLM_JUDGE, source_id="gpt-4o-mini"
    )

    logged_quality = mlflow.log_feedback(
        trace_id=trace_id,
        name="quality_score",
        value=8.5,
        source=llm_source,
        metadata={"scale": "1-10", "criterion": "accuracy"},
    )
    feedback = mlflow.get_assessment(trace_id=trace_id, assessment_id=logged_quality.assessment_id)
    assert feedback.name == "quality_score"
    assert feedback.value == 8.5
    assert feedback.source.source_type == AssessmentSourceType.LLM_JUDGE

    human_source = AssessmentSource(
        source_type=AssessmentSourceType.HUMAN, source_id="reviewer@example.com"
    )

    logged_approval = mlflow.log_feedback(
        trace_id=trace_id,
        name="approved",
        value=True,
        source=human_source,
        metadata={"review_date": "2024-01-15"},
    )
    feedback = mlflow.get_assessment(trace_id=trace_id, assessment_id=logged_approval.assessment_id)
    assert feedback.name == "approved"
    assert feedback.value is True
    assert feedback.source.source_type == AssessmentSourceType.HUMAN


def test_multiple_assessments_on_otel_trace(mlflow_server: str, is_async):
    """Tags plus a mix of expectation/feedback assessments coexist on one trace."""
    experiment = mlflow.set_experiment("otel-multi-assessment-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "multi-assessment-test-service")

    # Create a complex trace with nested spans
    with tracer.start_as_current_span("conversation") as parent_span:
        parent_span.set_attribute("user_query", "Explain quantum computing")

        with tracer.start_as_current_span("retrieval") as retrieval_span:
            retrieval_span.set_attribute("documents_found", 5)

        with tracer.start_as_current_span("generation") as generation_span:
            generation_span.set_attribute("model", "gpt-4")
            generation_span.set_attribute(
                "response", "Quantum computing uses quantum mechanics..."
            )

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    trace_id = traces[0].info.trace_id

    mlflow.set_trace_tag(trace_id, "topic", "quantum_computing")
    mlflow.set_trace_tag(trace_id, "complexity", "high")

    human_source = AssessmentSource(AssessmentSourceType.HUMAN, "expert@physics.edu")
    llm_source = AssessmentSource(AssessmentSourceType.LLM_JUDGE, "claude-3")

    expectation = Expectation(
        name="expected_quality",
        value="Should explain quantum superposition and entanglement",
        source=human_source,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=expectation)
    feedback_items = [
        Feedback(name="accuracy", value=9.0, source=llm_source, metadata={"max_score": "10"}),
        Feedback(name="clarity", value=8.5, source=llm_source, metadata={"max_score": "10"}),
        Feedback(
            name="helpfulness",
            value=True,
            source=human_source,
            metadata={"reviewer_expertise": "quantum_physics"},
        ),
        Feedback(
            name="contains_errors",
            value=False,
            source=human_source,
            metadata={"fact_checked": "True"},
        ),
    ]

    for feedback in feedback_items:
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    retrieved_trace = mlflow.get_trace(trace_id)
    assessments = retrieved_trace.info.assessments
    assert len(assessments) == 5
    assert [a.name for a in assessments] == [
        "expected_quality",
        "accuracy",
        "clarity",
        "helpfulness",
        "contains_errors",
    ]

    assert retrieved_trace.info.tags["topic"] == "quantum_computing"
    assert retrieved_trace.info.tags["complexity"] == "high"

    assert len(retrieved_trace.data.spans) == 3
    span_names = {span.name for span in retrieved_trace.data.spans}
    assert span_names == {"conversation", "retrieval", "generation"}

    tagged_traces = mlflow.search_traces(
        locations=[experiment_id],
        filter_string='tags.topic = "quantum_computing"',
        return_type="list",
    )
    assert len(tagged_traces) == 1
    assert tagged_traces[0].info.trace_id == trace_id


def test_span_kind_translation(mlflow_server: str, is_async):
    """OpenInference/Traceloop span-kind attributes map to MLflow span types."""
    experiment = mlflow.set_experiment("span-kind-translation-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "span-kind-translation-test-service")

    # Three separate root spans -> three separate traces
    with tracer.start_as_current_span("llm-call") as span:
        span.set_attribute(OpenInferenceTranslator.SPAN_KIND_ATTRIBUTE_KEY, "LLM")

    with tracer.start_as_current_span("retriever-call") as span:
        span.set_attribute(OpenInferenceTranslator.SPAN_KIND_ATTRIBUTE_KEY, "RETRIEVER")

    with tracer.start_as_current_span("tool-call") as span:
        span.set_attribute(TraceloopTranslator.SPAN_KIND_ATTRIBUTE_KEY, "tool")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    assert len(traces) == 3
    for trace_info in traces:
        retrieved_trace = mlflow.get_trace(trace_info.info.trace_id)
        for span in retrieved_trace.data.spans:
            if span.name == "llm-call":
                assert span.span_type == "LLM"
            elif span.name == "retriever-call":
                assert span.span_type == "RETRIEVER"
            elif span.name == "tool-call":
                assert span.span_type == "TOOL"


@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_span_inputs_outputs_translation(
    mlflow_server: str, is_async, translator: type[OtelSchemaTranslator]
):
    """Schema-specific input/output attributes populate span inputs/outputs and previews."""
    experiment = mlflow.set_experiment("span-inputs-outputs-translation-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(
        mlflow_server, experiment_id, "span-inputs-outputs-translation-test-service"
    )

    with tracer.start_as_current_span("llm-call") as span:
        span.set_attribute(translator.INPUT_VALUE_KEYS[0], "Hello, world!")
        span.set_attribute(translator.OUTPUT_VALUE_KEYS[0], "Bye!")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) == 1
    retrieved_trace = mlflow.get_trace(traces[0].info.trace_id)
    assert retrieved_trace.data.spans[0].inputs == "Hello, world!"
    assert retrieved_trace.data.spans[0].outputs == "Bye!"
    assert retrieved_trace.info.request_preview == '"Hello, world!"'
    assert retrieved_trace.info.response_preview == '"Bye!"'


@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_span_token_usage_translation(
    mlflow_server: str, is_async, translator: type[OtelSchemaTranslator]
):
    """Schema-specific token-count attributes roll up into trace token usage."""
    experiment = mlflow.set_experiment("span-token-usage-translation-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(
        mlflow_server, experiment_id, "span-token-usage-translation-test-service"
    )

    with tracer.start_as_current_span("llm-call") as span:
        span.set_attribute(translator.INPUT_TOKEN_KEY, 100)
        span.set_attribute(translator.OUTPUT_TOKEN_KEY, 50)

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) > 0
    for trace_info in traces:
        assert trace_info.info.token_usage == {
            "input_tokens": 100,
            "output_tokens": 50,
            "total_tokens": 150,
        }
        retrieved_trace = mlflow.get_trace(trace_info.info.trace_id)
        assert (
            retrieved_trace.data.spans[0].attributes[SpanAttributeKey.CHAT_USAGE]
            == trace_info.info.token_usage
        )


@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_aggregated_token_usage_from_multiple_spans(
    mlflow_server: str, is_async, translator: type[OtelSchemaTranslator]
):
    """Token counts from all spans in a trace are summed into the trace totals."""
    experiment = mlflow.set_experiment("aggregated-token-usage-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "token-aggregation-service")

    with tracer.start_as_current_span("parent-llm-call") as parent:
        parent.set_attribute(translator.INPUT_TOKEN_KEY, 100)
        parent.set_attribute(translator.OUTPUT_TOKEN_KEY, 50)

        with tracer.start_as_current_span("child-llm-call-1") as child1:
            child1.set_attribute(translator.INPUT_TOKEN_KEY, 200)
            child1.set_attribute(translator.OUTPUT_TOKEN_KEY, 75)

        with tracer.start_as_current_span("child-llm-call-2") as child2:
            child2.set_attribute(translator.INPUT_TOKEN_KEY, 150)
            child2.set_attribute(translator.OUTPUT_TOKEN_KEY, 100)

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    # 100+200+150 input, 50+75+100 output across the three spans
    assert retrieved_trace.info.token_usage is not None
    assert retrieved_trace.info.token_usage["input_tokens"] == 450
    assert retrieved_trace.info.token_usage["output_tokens"] == 225
    assert retrieved_trace.info.token_usage["total_tokens"] == 675