/ tests / tracing / test_otel_loading.py
test_otel_loading.py
  1  import uuid
  2  from pathlib import Path
  3  
  4  import pytest
  5  from opentelemetry import trace as otel_trace
  6  from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
  7  from opentelemetry.sdk.resources import Resource as OTelSDKResource
  8  from opentelemetry.sdk.trace import TracerProvider
  9  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 10  from opentelemetry.trace import Status, StatusCode
 11  from opentelemetry.util._once import Once
 12  
 13  import mlflow
 14  from mlflow.entities import SpanStatusCode
 15  from mlflow.entities.assessment import AssessmentSource, Expectation, Feedback
 16  from mlflow.entities.assessment_source import AssessmentSourceType
 17  from mlflow.server import handlers
 18  from mlflow.server.fastapi_app import app
 19  from mlflow.server.handlers import initialize_backend_stores
 20  from mlflow.tracing.constant import SpanAttributeKey
 21  from mlflow.tracing.otel.translation.base import OtelSchemaTranslator
 22  from mlflow.tracing.otel.translation.genai_semconv import GenAiTranslator
 23  from mlflow.tracing.otel.translation.open_inference import OpenInferenceTranslator
 24  from mlflow.tracing.otel.translation.traceloop import TraceloopTranslator
 25  from mlflow.tracing.provider import _get_trace_exporter
 26  from mlflow.tracing.utils import encode_trace_id
 27  from mlflow.tracing.utils.otlp import MLFLOW_EXPERIMENT_ID_HEADER
 28  from mlflow.tracking._tracking_service.utils import _use_tracking_uri
 29  from mlflow.version import IS_TRACING_SDK_ONLY
 30  
 31  from tests.helper_functions import get_safe_port
 32  from tests.tracking.integration_test_utils import ServerThread
 33  
# The tracing-only SDK distribution ships without the MLflow server
# components exercised below, so skip the entire module in that build.
if IS_TRACING_SDK_ONLY:
    pytest.skip("OTel get_trace tests require full MLflow server", allow_module_level=True)
 36  
 37  
 38  @pytest.fixture
 39  def mlflow_server(tmp_path: Path, db_uri: str):
 40      artifact_uri = tmp_path.joinpath("artifacts").as_uri()
 41  
 42      # Force-reset backend stores before each test
 43      handlers._tracking_store = None
 44      handlers._model_registry_store = None
 45      initialize_backend_stores(db_uri, default_artifact_root=artifact_uri)
 46  
 47      with ServerThread(app, get_safe_port()) as url:
 48          yield url
 49  
 50  
@pytest.fixture(autouse=True)
def tracking_uri_setup(mlflow_server):
    # Point the MLflow client at the in-process test server for every test.
    with _use_tracking_uri(mlflow_server):
        yield
 55  
 56  
 57  @pytest.fixture(params=[True, False])
 58  def is_async(request, monkeypatch):
 59      monkeypatch.setenv("MLFLOW_ASYNC_TRACE_LOGGING", "true" if request.param else "false")
 60  
 61  
 62  def _flush_async_logging():
 63      exporter = _get_trace_exporter()
 64      assert hasattr(exporter, "_async_queue"), "Async queue is not initialized"
 65      exporter._async_queue.flush(terminate=True)
 66  
 67  
 68  def create_tracer(mlflow_server: str, experiment_id: str, service_name: str = "test-service"):
 69      resource = OTelSDKResource.create({"service.name": service_name, "service.version": "1.0.0"})
 70      tracer_provider = TracerProvider(resource=resource)
 71  
 72      exporter = OTLPSpanExporter(
 73          endpoint=f"{mlflow_server}/v1/traces",
 74          headers={MLFLOW_EXPERIMENT_ID_HEADER: experiment_id},
 75          timeout=10,
 76      )
 77  
 78      span_processor = SimpleSpanProcessor(exporter)
 79      tracer_provider.add_span_processor(span_processor)
 80  
 81      # Reset the global tracer provider
 82      otel_trace._TRACER_PROVIDER_SET_ONCE = Once()
 83      otel_trace._TRACER_PROVIDER = None
 84      otel_trace.set_tracer_provider(tracer_provider)
 85  
 86      return otel_trace.get_tracer(__name__)
 87  
 88  
 89  def test_get_trace_for_otel_sent_span(mlflow_server: str, is_async):
 90      experiment = mlflow.set_experiment("otel-get-trace-test")
 91      experiment_id = experiment.experiment_id
 92  
 93      tracer = create_tracer(mlflow_server, experiment_id, "test-service-get-trace")
 94  
 95      # Create a span with various attributes to test conversion
 96      with tracer.start_as_current_span("otel-test-span") as span:
 97          span.set_attribute("test.string", "string-value")
 98          span.set_attribute("test.number", 42)
 99          span.set_attribute("test.boolean", True)
100          span.set_attribute("operation.type", "llm_request")
101  
102          # Capture the OTel trace ID
103          otel_trace_id = span.get_span_context().trace_id
104          assert span.get_span_context().is_valid
105          assert otel_trace_id != 0
106  
107      if is_async:
108          _flush_async_logging()
109  
110      traces = mlflow.search_traces(
111          locations=[experiment_id], include_spans=False, return_type="list"
112      )
113  
114      assert len(traces) > 0, "No traces found in the database"
115  
116      trace_id = traces[0].info.trace_id
117      retrieved_trace = mlflow.get_trace(trace_id)
118  
119      assert retrieved_trace.info.trace_id == trace_id
120      assert retrieved_trace.info.trace_location.mlflow_experiment.experiment_id == experiment_id
121  
122      assert len(retrieved_trace.data.spans) == 1
123      span = retrieved_trace.data.spans[0]
124  
125      assert span.name == "otel-test-span"
126      assert span.trace_id == trace_id
127      # OTel spans default to UNSET status if not explicitly set
128      assert span.status.status_code == SpanStatusCode.UNSET
129  
130      # Verify attributes were converted correctly
131      assert span.attributes["test.string"] == "string-value"
132      assert span.attributes["test.number"] == 42
133      assert span.attributes["test.boolean"] is True
134      assert span.attributes["operation.type"] == "llm_request"
135  
136      # Verify the trace ID matches the expected format
137      expected_trace_id = f"tr-{encode_trace_id(otel_trace_id)}"
138      assert trace_id == expected_trace_id
139  
140  
def test_get_trace_for_otel_nested_spans(mlflow_server: str, is_async):
    """Parent/child OTel spans keep their hierarchy after ingestion."""
    experiment_id = mlflow.set_experiment("otel-nested-spans-test").experiment_id
    tracer = create_tracer(mlflow_server, experiment_id, "nested-test-service")

    # Emit a two-level span tree.
    with tracer.start_as_current_span("parent-span") as outer:
        outer.set_attribute("span.level", "parent")
        with tracer.start_as_current_span("child-span") as inner:
            inner.set_attribute("span.level", "child")
            inner.set_attribute("child.operation", "process_data")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) > 0, "No traces found in the database"

    fetched = mlflow.get_trace(traces[0].info.trace_id)
    assert len(fetched.data.spans) == 2

    by_name = {s.name: s for s in fetched.data.spans}
    assert "parent-span" in by_name
    assert "child-span" in by_name

    parent, child = by_name["parent-span"], by_name["child-span"]

    assert parent.attributes["span.level"] == "parent"
    assert parent.parent_id is None  # Root span has no parent

    assert child.attributes["span.level"] == "child"
    assert child.attributes["child.operation"] == "process_data"
    assert child.parent_id == parent.span_id  # Child should reference parent
183  
184  
def test_get_trace_with_otel_span_events(mlflow_server: str, is_async):
    """OTel span events (name + attributes) survive ingestion into MLflow."""
    experiment = mlflow.set_experiment("otel-events-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "events-test-service")

    # Create span with events using OTel SDK
    with tracer.start_as_current_span("span-with-events") as span:
        span.add_event("test_event", attributes={"event.type": "processing"})

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert len(retrieved_trace.data.spans) == 1
    retrieved_span = retrieved_trace.data.spans[0]

    assert retrieved_span.name == "span-with-events"
    assert len(retrieved_span.events) == 1
    event = retrieved_span.events[0]
    assert event.name == "test_event"
    assert event.attributes["event.type"] == "processing"
213  
214  
def test_get_trace_nonexistent_otel_trace(mlflow_server: str):
    """Fetching a trace ID that was never logged yields None, not an error."""
    # Fabricate an ID in the same "tr-<32 hex chars>" shape MLflow produces.
    missing_trace_id = f"tr-{uuid.uuid4().hex}"

    # MLflow get_trace returns None for non-existent traces
    assert mlflow.get_trace(missing_trace_id) is None
223  
224  
def test_get_trace_with_otel_span_status(mlflow_server: str, is_async):
    """An OTel ERROR status and its description map onto the MLflow span status."""
    experiment = mlflow.set_experiment("otel-status-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "status-test-service")

    # Create span with error status using OTel SDK
    with tracer.start_as_current_span("error-span") as span:
        span.set_status(Status(StatusCode.ERROR, "Something went wrong"))

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    assert len(retrieved_trace.data.spans) == 1
    retrieved_span = retrieved_trace.data.spans[0]

    assert retrieved_span.name == "error-span"
    assert retrieved_span.status.status_code == SpanStatusCode.ERROR
    assert "Something went wrong" in retrieved_span.status.description
251  
252  
def test_set_trace_tag_on_otel_trace(mlflow_server: str, is_async):
    """Tags can be set on an OTLP-ingested trace and read back via get_trace."""
    experiment = mlflow.set_experiment("otel-tag-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "tag-test-service")

    with tracer.start_as_current_span("tagged-span") as span:
        span.set_attribute("test.attribute", "value")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"
    trace_id = traces[0].info.trace_id

    mlflow.set_trace_tag(trace_id, "environment", "test")
    mlflow.set_trace_tag(trace_id, "version", "1.0.0")

    retrieved_trace = mlflow.get_trace(trace_id)
    assert retrieved_trace.info.tags["environment"] == "test"
    assert retrieved_trace.info.tags["version"] == "1.0.0"
276  
277  
def test_log_expectation_on_otel_trace(mlflow_server: str, is_async):
    """Human expectations can be logged against an OTLP-ingested trace."""
    experiment = mlflow.set_experiment("otel-expectation-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "expectation-test-service")

    # Create a span that represents a question-answer scenario
    with tracer.start_as_current_span("qa-span") as span:
        span.set_attribute("question", "What is MLflow?")
        span.set_attribute("answer", "MLflow is an open-source ML platform")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"
    trace_id = traces[0].info.trace_id

    expectation_source = AssessmentSource(
        source_type=AssessmentSourceType.HUMAN, source_id="test_user@example.com"
    )

    logged_assessment = mlflow.log_expectation(
        trace_id=trace_id,
        name="expected_answer",
        value="MLflow is an open-source machine learning platform",
        source=expectation_source,
        metadata={"confidence": "high", "reviewed_by": "expert"},
    )
    expectation = mlflow.get_assessment(
        trace_id=trace_id, assessment_id=logged_assessment.assessment_id
    )
    assert expectation.name == "expected_answer"
    assert expectation.value == "MLflow is an open-source machine learning platform"
    assert expectation.source.source_type == AssessmentSourceType.HUMAN
    assert expectation.metadata["confidence"] == "high"
315  
316  
def test_log_feedback_on_otel_trace(mlflow_server: str, is_async):
    """Feedback from both LLM-judge and human sources attaches to an OTel trace."""
    experiment_id = mlflow.set_experiment("otel-feedback-test").experiment_id
    tracer = create_tracer(mlflow_server, experiment_id, "feedback-test-service")

    # One span standing in for a model prediction.
    with tracer.start_as_current_span("prediction-span") as pred_span:
        pred_span.set_attribute("model", "gpt-4")
        pred_span.set_attribute("prediction", "The weather is sunny")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) > 0, "No traces found in the database"
    trace_id = traces[0].info.trace_id

    # Numeric score from an LLM judge.
    judge = AssessmentSource(source_type=AssessmentSourceType.LLM_JUDGE, source_id="gpt-4o-mini")
    quality = mlflow.log_feedback(
        trace_id=trace_id,
        name="quality_score",
        value=8.5,
        source=judge,
        metadata={"scale": "1-10", "criterion": "accuracy"},
    )
    fetched = mlflow.get_assessment(trace_id=trace_id, assessment_id=quality.assessment_id)
    assert fetched.name == "quality_score"
    assert fetched.value == 8.5
    assert fetched.source.source_type == AssessmentSourceType.LLM_JUDGE

    # Boolean approval from a human reviewer.
    reviewer = AssessmentSource(
        source_type=AssessmentSourceType.HUMAN, source_id="reviewer@example.com"
    )
    approval = mlflow.log_feedback(
        trace_id=trace_id,
        name="approved",
        value=True,
        source=reviewer,
        metadata={"review_date": "2024-01-15"},
    )
    fetched = mlflow.get_assessment(trace_id=trace_id, assessment_id=approval.assessment_id)
    assert fetched.name == "approved"
    assert fetched.value is True
    assert fetched.source.source_type == AssessmentSourceType.HUMAN
368  
369  
def test_multiple_assessments_on_otel_trace(mlflow_server: str, is_async):
    """Tags, one expectation, and several feedback items all coexist on one trace."""
    experiment = mlflow.set_experiment("otel-multi-assessment-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "multi-assessment-test-service")

    # Create a complex trace with nested spans
    with tracer.start_as_current_span("conversation") as parent_span:
        parent_span.set_attribute("user_query", "Explain quantum computing")

        with tracer.start_as_current_span("retrieval") as retrieval_span:
            retrieval_span.set_attribute("documents_found", 5)

        with tracer.start_as_current_span("generation") as generation_span:
            generation_span.set_attribute("model", "gpt-4")
            generation_span.set_attribute("response", "Quantum computing uses quantum mechanics...")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"
    trace_id = traces[0].info.trace_id

    mlflow.set_trace_tag(trace_id, "topic", "quantum_computing")
    mlflow.set_trace_tag(trace_id, "complexity", "high")

    human_source = AssessmentSource(AssessmentSourceType.HUMAN, "expert@physics.edu")
    llm_source = AssessmentSource(AssessmentSourceType.LLM_JUDGE, "claude-3")

    expectation = Expectation(
        name="expected_quality",
        value="Should explain quantum superposition and entanglement",
        source=human_source,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=expectation)
    feedback_items = [
        Feedback(name="accuracy", value=9.0, source=llm_source, metadata={"max_score": "10"}),
        Feedback(name="clarity", value=8.5, source=llm_source, metadata={"max_score": "10"}),
        Feedback(
            name="helpfulness",
            value=True,
            source=human_source,
            metadata={"reviewer_expertise": "quantum_physics"},
        ),
        Feedback(
            name="contains_errors",
            value=False,
            source=human_source,
            metadata={"fact_checked": "True"},
        ),
    ]

    for feedback in feedback_items:
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    retrieved_trace = mlflow.get_trace(trace_id)
    assessments = retrieved_trace.info.assessments
    # One expectation plus four feedback items, in logging order.
    assert len(assessments) == 5
    assert [a.name for a in assessments] == [
        "expected_quality",
        "accuracy",
        "clarity",
        "helpfulness",
        "contains_errors",
    ]

    assert retrieved_trace.info.tags["topic"] == "quantum_computing"
    assert retrieved_trace.info.tags["complexity"] == "high"

    assert len(retrieved_trace.data.spans) == 3
    span_names = {span.name for span in retrieved_trace.data.spans}
    assert span_names == {"conversation", "retrieval", "generation"}

    # Tag-based search should find exactly this trace.
    tagged_traces = mlflow.search_traces(
        locations=[experiment_id],
        filter_string='tags.topic = "quantum_computing"',
        return_type="list",
    )
    assert len(tagged_traces) == 1
    assert tagged_traces[0].info.trace_id == trace_id
452  
453  
def test_span_kind_translation(mlflow_server: str, is_async):
    """Vendor span-kind attributes translate to MLflow span types on ingestion."""
    experiment_id = mlflow.set_experiment("span-kind-translation-test").experiment_id
    tracer = create_tracer(mlflow_server, experiment_id, "span-kind-translation-test-service")

    # One standalone trace per span-kind convention under test.
    with tracer.start_as_current_span("llm-call") as span:
        span.set_attribute(OpenInferenceTranslator.SPAN_KIND_ATTRIBUTE_KEY, "LLM")
    with tracer.start_as_current_span("retriever-call") as span:
        span.set_attribute(OpenInferenceTranslator.SPAN_KIND_ATTRIBUTE_KEY, "RETRIEVER")
    with tracer.start_as_current_span("tool-call") as span:
        span.set_attribute(TraceloopTranslator.SPAN_KIND_ATTRIBUTE_KEY, "tool")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) == 3

    # Expected MLflow span type by span name.
    expected_types = {"llm-call": "LLM", "retriever-call": "RETRIEVER", "tool-call": "TOOL"}
    for trace_info in traces:
        full_trace = mlflow.get_trace(trace_info.info.trace_id)
        for ingested_span in full_trace.data.spans:
            if ingested_span.name in expected_types:
                assert ingested_span.span_type == expected_types[ingested_span.name]
486  
487  
@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_span_inputs_outputs_translation(
    mlflow_server: str, is_async, translator: OtelSchemaTranslator
):
    """Each schema's input/output attribute keys map to MLflow span I/O and previews."""
    experiment_id = mlflow.set_experiment("span-inputs-outputs-translation-test").experiment_id
    tracer = create_tracer(
        mlflow_server, experiment_id, "span-inputs-outputs-translation-test-service"
    )

    # Use the first recognized key of each translator's schema.
    with tracer.start_as_current_span("llm-call") as llm_span:
        llm_span.set_attribute(translator.INPUT_VALUE_KEYS[0], "Hello, world!")
        llm_span.set_attribute(translator.OUTPUT_VALUE_KEYS[0], "Bye!")

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) == 1

    fetched = mlflow.get_trace(traces[0].info.trace_id)
    root = fetched.data.spans[0]
    assert root.inputs == "Hello, world!"
    assert root.outputs == "Bye!"
    # Previews are the JSON-serialized span I/O values.
    assert fetched.info.request_preview == '"Hello, world!"'
    assert fetched.info.response_preview == '"Bye!"'
517  
518  
@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_span_token_usage_translation(
    mlflow_server: str, is_async, translator: OtelSchemaTranslator
):
    """Each schema's token-count attributes surface as MLflow token usage."""
    experiment_id = mlflow.set_experiment("span-token-usage-translation-test").experiment_id
    tracer = create_tracer(
        mlflow_server, experiment_id, "span-token-usage-translation-test-service"
    )

    with tracer.start_as_current_span("llm-call") as llm_span:
        llm_span.set_attribute(translator.INPUT_TOKEN_KEY, 100)
        llm_span.set_attribute(translator.OUTPUT_TOKEN_KEY, 50)

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    assert len(traces) > 0

    # total_tokens is derived as input + output.
    expected_usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
    for trace_info in traces:
        assert trace_info.info.token_usage == expected_usage
        fetched = mlflow.get_trace(trace_info.info.trace_id)
        # The span-level usage attribute mirrors the trace-level summary.
        assert (
            fetched.data.spans[0].attributes[SpanAttributeKey.CHAT_USAGE]
            == trace_info.info.token_usage
        )
554  
555  
@pytest.mark.parametrize(
    "translator", [GenAiTranslator, OpenInferenceTranslator, TraceloopTranslator]
)
def test_aggregated_token_usage_from_multiple_spans(
    mlflow_server: str, is_async, translator: OtelSchemaTranslator
):
    """Token counts from all spans in a trace aggregate into the trace-level usage."""
    experiment = mlflow.set_experiment("aggregated-token-usage-test")
    experiment_id = experiment.experiment_id

    tracer = create_tracer(mlflow_server, experiment_id, "token-aggregation-service")

    with tracer.start_as_current_span("parent-llm-call") as parent:
        parent.set_attribute(translator.INPUT_TOKEN_KEY, 100)
        parent.set_attribute(translator.OUTPUT_TOKEN_KEY, 50)

        with tracer.start_as_current_span("child-llm-call-1") as child1:
            child1.set_attribute(translator.INPUT_TOKEN_KEY, 200)
            child1.set_attribute(translator.OUTPUT_TOKEN_KEY, 75)

        with tracer.start_as_current_span("child-llm-call-2") as child2:
            child2.set_attribute(translator.INPUT_TOKEN_KEY, 150)
            child2.set_attribute(translator.OUTPUT_TOKEN_KEY, 100)

    if is_async:
        _flush_async_logging()

    traces = mlflow.search_traces(
        locations=[experiment_id], include_spans=False, return_type="list"
    )
    # Guard before indexing so an ingestion failure reports clearly instead
    # of raising IndexError (matches the other tests in this module).
    assert len(traces) > 0, "No traces found in the database"

    trace_id = traces[0].info.trace_id
    retrieved_trace = mlflow.get_trace(trace_id)

    # Sums across parent + both children: 100+200+150 and 50+75+100.
    assert retrieved_trace.info.token_usage is not None
    assert retrieved_trace.info.token_usage["input_tokens"] == 450
    assert retrieved_trace.info.token_usage["output_tokens"] == 225
    assert retrieved_trace.info.token_usage["total_tokens"] == 675