/ tests / tracing / utils / test_utils.py
test_utils.py
  1  import json
  2  import logging
  3  from unittest import mock
  4  from unittest.mock import Mock, patch
  5  
  6  import litellm
  7  import pytest
  8  from opentelemetry import trace as trace_api
  9  from pydantic import ValidationError
 10  
 11  import mlflow
 12  from mlflow.entities import (
 13      LiveSpan,
 14      SpanType,
 15  )
 16  from mlflow.entities.span import SpanType
 17  from mlflow.entities.trace_location import UCSchemaLocation
 18  from mlflow.exceptions import MlflowException
 19  from mlflow.tracing import set_span_chat_tools
 20  from mlflow.tracing.constant import (
 21      TRACE_ID_V4_PREFIX,
 22      CostKey,
 23      SpanAttributeKey,
 24      TokenUsageKey,
 25  )
 26  from mlflow.tracing.utils import (
 27      _calculate_percentile,
 28      aggregate_cost_from_spans,
 29      aggregate_usage_from_spans,
 30      calculate_cost_by_model_and_token_usage,
 31      capture_function_input_args,
 32      construct_full_inputs,
 33      dump_span_attribute_value,
 34      encode_span_id,
 35      encode_trace_id,
 36      generate_trace_id_v4,
 37      generate_trace_id_v4_from_otel_trace_id,
 38      get_active_spans_table_name,
 39      get_otel_attribute,
 40      maybe_get_request_id,
 41      parse_trace_id_v4,
 42  )
 43  from mlflow.version import IS_TRACING_SDK_ONLY
 44  
 45  from tests.tracing.helper import create_mock_otel_span
 46  
 47  
 48  def test_capture_function_input_args_does_not_raise():
 49      # Exception during inspecting inputs: trace should be logged without inputs field
 50      with patch("inspect.signature", side_effect=ValueError("Some error")) as mock_input_args:
 51          args = capture_function_input_args(lambda: None, (), {})
 52  
 53      assert args is None
 54      assert mock_input_args.call_count > 0
 55  
 56  
 57  def test_duplicate_span_names():
 58      span_names = ["red", "red", "blue", "red", "green", "blue"]
 59  
 60      spans = [
 61          LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=span_name), trace_id="tr-123")
 62          for i, span_name in enumerate(span_names)
 63      ]
 64  
 65      assert [span.name for span in spans] == span_names
 66      # Check if the span order is preserved
 67      assert [span.span_id for span in spans] == [encode_span_id(i) for i in [0, 1, 2, 3, 4, 5]]
 68  
 69  
 70  def test_aggregate_usage_from_spans():
 71      spans = [
 72          LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=f"span_{i}"), trace_id="tr-123")
 73          for i in range(3)
 74      ]
 75      spans[0].set_attribute(
 76          SpanAttributeKey.CHAT_USAGE,
 77          {
 78              TokenUsageKey.INPUT_TOKENS: 10,
 79              TokenUsageKey.OUTPUT_TOKENS: 20,
 80              TokenUsageKey.TOTAL_TOKENS: 30,
 81          },
 82      )
 83      spans[1].set_attribute(
 84          SpanAttributeKey.CHAT_USAGE,
 85          {TokenUsageKey.OUTPUT_TOKENS: 15, TokenUsageKey.TOTAL_TOKENS: 15},
 86      )
 87      spans[2].set_attribute(
 88          SpanAttributeKey.CHAT_USAGE,
 89          {
 90              TokenUsageKey.INPUT_TOKENS: 5,
 91              TokenUsageKey.OUTPUT_TOKENS: 10,
 92              TokenUsageKey.TOTAL_TOKENS: 15,
 93          },
 94      )
 95  
 96      usage = aggregate_usage_from_spans(spans)
 97      assert usage == {
 98          TokenUsageKey.INPUT_TOKENS: 15,
 99          TokenUsageKey.OUTPUT_TOKENS: 45,
100          TokenUsageKey.TOTAL_TOKENS: 60,
101      }
102  
103  
def test_aggregate_usage_from_spans_skips_descendant_usage():
    """Usage on a span nested below a span that already reports usage is not counted again.

    Tree: root(1) -> child(2) -> grandchild(3), plus an unrelated top-level span(4).
    Only root and independent contribute to the totals; grandchild's usage is skipped,
    presumably because its ancestor's numbers already include it.
    """

    def make_span(span_id, name, parent_id=None):
        # Only forward parent_id when one is given, mirroring how the spans are built elsewhere.
        extra = {} if parent_id is None else {"parent_id": parent_id}
        return LiveSpan(
            create_mock_otel_span("trace_id", span_id=span_id, name=name, **extra),
            trace_id="tr-123",
        )

    root = make_span(1, "root")
    child = make_span(2, "child", parent_id=1)
    grandchild = make_span(3, "grandchild", parent_id=2)
    independent = make_span(4, "independent")

    root.set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 20,
            TokenUsageKey.TOTAL_TOKENS: 30,
        },
    )
    grandchild.set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 10,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )
    independent.set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 3,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 9,
        },
    )

    aggregated = aggregate_usage_from_spans([root, child, grandchild, independent])

    # root (10/20/30) + independent (3/6/9); grandchild is excluded.
    assert aggregated == {
        TokenUsageKey.INPUT_TOKENS: 13,
        TokenUsageKey.OUTPUT_TOKENS: 26,
        TokenUsageKey.TOTAL_TOKENS: 39,
    }
154  
155  
def test_aggregate_usage_from_spans_with_cached_tokens():
    """Cache read/creation token counts are summed alongside the regular usage keys.

    Cache keys are summed only over the spans that report them.
    """
    per_span_usage = [
        {
            TokenUsageKey.INPUT_TOKENS: 100,
            TokenUsageKey.OUTPUT_TOKENS: 50,
            TokenUsageKey.TOTAL_TOKENS: 150,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 80,
        },
        {
            TokenUsageKey.INPUT_TOKENS: 200,
            TokenUsageKey.OUTPUT_TOKENS: 100,
            TokenUsageKey.TOTAL_TOKENS: 300,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 120,
            TokenUsageKey.CACHE_CREATION_INPUT_TOKENS: 50,
        },
        # span without cached tokens
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 5,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    ]
    spans = [
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=idx, name=f"span_{idx}"),
            trace_id="tr-123",
        )
        for idx in range(len(per_span_usage))
    ]
    for span, usage_attr in zip(spans, per_span_usage):
        span.set_attribute(SpanAttributeKey.CHAT_USAGE, usage_attr)

    expected = {
        TokenUsageKey.INPUT_TOKENS: 310,
        TokenUsageKey.OUTPUT_TOKENS: 155,
        TokenUsageKey.TOTAL_TOKENS: 465,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 200,
        TokenUsageKey.CACHE_CREATION_INPUT_TOKENS: 50,
    }
    assert aggregate_usage_from_spans(spans) == expected
198  
199  
def test_aggregate_usage_from_spans_without_cached_tokens_omits_keys():
    """Cache token keys are absent from the aggregate when no span reports them."""
    span = LiveSpan(
        create_mock_otel_span("trace_id", span_id=0, name="span_0"), trace_id="tr-123"
    )
    span.set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 5,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )

    usage = aggregate_usage_from_spans([span])

    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 10,
        TokenUsageKey.OUTPUT_TOKENS: 5,
        TokenUsageKey.TOTAL_TOKENS: 15,
    }
    # Cached keys should not be present (not even with a zero value).
    cache_keys = (
        TokenUsageKey.CACHE_READ_INPUT_TOKENS,
        TokenUsageKey.CACHE_CREATION_INPUT_TOKENS,
    )
    for cache_key in cache_keys:
        assert cache_key not in usage
222  
223  
def test_aggregate_cost_from_spans():
    """LLM cost from sibling spans is summed per key; a key missing on a span adds nothing."""
    per_span_cost = [
        {
            CostKey.INPUT_COST: 10,
            CostKey.OUTPUT_COST: 20,
            CostKey.TOTAL_COST: 30,
        },
        # This span reports no input cost.
        {CostKey.OUTPUT_COST: 15, CostKey.TOTAL_COST: 15},
        {
            CostKey.INPUT_COST: 5,
            CostKey.OUTPUT_COST: 10,
            CostKey.TOTAL_COST: 15,
        },
    ]
    spans = [
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=idx, name=f"span_{idx}"),
            trace_id="tr-123",
        )
        for idx in range(len(per_span_cost))
    ]
    for span, cost_attr in zip(spans, per_span_cost):
        span.set_attribute(SpanAttributeKey.LLM_COST, cost_attr)

    expected = {
        CostKey.INPUT_COST: 15,
        CostKey.OUTPUT_COST: 45,
        CostKey.TOTAL_COST: 60,
    }
    assert aggregate_cost_from_spans(spans) == expected
256  
257  
def test_aggregate_cost_from_spans_skips_descendant_cost():
    """Cost on a span nested below a span that already reports cost is not counted again.

    Tree: root(1) -> child(2) -> grandchild(3), plus an unrelated top-level span(4).
    Only root and independent contribute to the totals; grandchild's cost is skipped,
    presumably because its ancestor's numbers already include it.
    """

    def make_span(span_id, name, parent_id=None):
        # Only forward parent_id when one is given, mirroring how the spans are built elsewhere.
        extra = {} if parent_id is None else {"parent_id": parent_id}
        return LiveSpan(
            create_mock_otel_span("trace_id", span_id=span_id, name=name, **extra),
            trace_id="tr-123",
        )

    root = make_span(1, "root")
    child = make_span(2, "child", parent_id=1)
    grandchild = make_span(3, "grandchild", parent_id=2)
    independent = make_span(4, "independent")

    root.set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 10,
            CostKey.OUTPUT_COST: 20,
            CostKey.TOTAL_COST: 30,
        },
    )
    grandchild.set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 5,
            CostKey.OUTPUT_COST: 10,
            CostKey.TOTAL_COST: 15,
        },
    )
    independent.set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 3,
            CostKey.OUTPUT_COST: 6,
            CostKey.TOTAL_COST: 9,
        },
    )

    aggregated = aggregate_cost_from_spans([root, child, grandchild, independent])

    # root (10/20/30) + independent (3/6/9); grandchild is excluded.
    assert aggregated == {
        CostKey.INPUT_COST: 13,
        CostKey.OUTPUT_COST: 26,
        CostKey.TOTAL_COST: 39,
    }
308  
309  
def test_maybe_get_request_id():
    """The request id is returned only when the active prediction context is an evaluation."""
    # No prediction context has been set by this test yet.
    assert maybe_get_request_id(is_evaluate=True) is None

    try:
        from mlflow.pyfunc.context import Context, set_prediction_context
    except ImportError:
        pytest.skip("Skipping the rest of tests as mlflow.pyfunc module is not available.")

    evaluation_context = Context(request_id="eval", is_evaluate=True)
    with set_prediction_context(evaluation_context):
        assert maybe_get_request_id(is_evaluate=True) == "eval"

    # A non-evaluation context must not leak its request id.
    serving_context = Context(request_id="non_eval", is_evaluate=False)
    with set_prediction_context(serving_context):
        assert maybe_get_request_id(is_evaluate=True) is None
323  
324  
def test_set_chat_tools_validation():
    """A tool whose ``type`` is not supported is rejected by ChatTool validation."""
    invalid_tool = {
        "type": "unsupported_function",
        "unsupported_function": {"name": "test"},
    }

    @mlflow.trace(span_type=SpanType.CHAT_MODEL)
    def dummy_call(tools):
        set_span_chat_tools(mlflow.get_current_active_span(), tools)

    with pytest.raises(ValidationError, match="validation error for ChatTool"):
        dummy_call([invalid_tool])
343  
344  
@pytest.mark.parametrize(
    ("enum_values", "param_type"),
    [
        ([1, 2, 3, 4, 5], "integer"),
        (["option1", "option2", "option3"], "string"),
        ([1.1, 2.5, 3.7], "number"),
        ([True, False], "boolean"),
        (["mixed", 42, True, 3.14], "string"),  # Mixed types with string base type
    ],
)
def test_openai_parse_tools_enum_validation(enum_values, param_type):
    """``_parse_tools`` accepts ``enum`` values of each JSON scalar type and keeps them intact."""
    from mlflow.openai.utils.chat_schema import _parse_tools

    parameters_schema = {
        "type": "object",
        "properties": {"option": {"type": param_type, "enum": enum_values}},
        "required": ["option"],
    }
    tool_definition = {
        "type": "function",
        "function": {
            "name": "select_option",
            "description": "Select an option from the given choices",
            "parameters": parameters_schema,
        },
    }

    # Mirrors the OpenAI autologging input shape that used to raise a ValidationError.
    parsed = _parse_tools({"tools": [tool_definition]})

    assert len(parsed) == 1
    tool = parsed[0]
    assert tool.function.name == "select_option"
    assert tool.function.parameters.properties["option"].enum == enum_values
381  
382  
def test_construct_full_inputs_simple_function():
    """Only explicitly passed arguments are captured; extra keyword args nest under ``kwargs``."""

    def func(a, b, c=3, d=4, **kwargs):
        pass

    # (positional args, keyword args, expected captured inputs)
    cases = [
        ((1, 2), {}, {"a": 1, "b": 2}),
        ((1, 2), {"c": 30}, {"a": 1, "b": 2, "c": 30}),
        (
            (1, 2),
            {"c": 30, "d": 40, "e": 50},
            {"a": 1, "b": 2, "c": 30, "d": 40, "kwargs": {"e": 50}},
        ),
    ]
    for call_args, call_kwargs, expected in cases:
        assert construct_full_inputs(func, *call_args, **call_kwargs) == expected

    def no_args_func():
        pass

    assert construct_full_inputs(no_args_func) == {}

    class Sample:
        def func(self, a, b, c=3, d=4, **kwargs):
            pass

    # For a bound method, ``self`` must not appear in the captured inputs.
    assert construct_full_inputs(Sample().func, 1, 2) == {"a": 1, "b": 2}
408  
409  
def test_calculate_percentile():
    """``_calculate_percentile`` linearly interpolates between neighbouring sorted values."""
    # Empty input falls back to 0.0.
    assert _calculate_percentile([], 0.5) == 0.0

    single = [100]
    pair = [10, 20]
    odd = [10, 20, 30, 40, 50]
    even = [10, 20, 30, 40]
    one_to_hundred = list(range(1, 101))

    # (data, percentile, expected value)
    expectations = [
        # A single element is every percentile.
        (single, 0.25, 100),
        (single, 0.5, 100),
        (single, 0.75, 100),
        # Two elements: the midpoint is interpolated.
        (pair, 0.0, 10),
        (pair, 0.5, 15),
        (pair, 1.0, 20),
        # Odd length: quartiles land exactly on elements; 0.5 is the median.
        (odd, 0.0, 10),
        (odd, 0.25, 20),
        (odd, 0.5, 30),
        (odd, 0.75, 40),
        (odd, 1.0, 50),
        # Even length: quartiles fall between neighbouring elements.
        (even, 0.0, 10),
        (even, 0.25, 17.5),
        (even, 0.5, 25),
        (even, 0.75, 32.5),
        (even, 1.0, 40),
        # Larger dataset.
        (one_to_hundred, 0.25, 25.75),
        (one_to_hundred, 0.5, 50.5),
    ]
    for data, percentile, expected in expectations:
        assert _calculate_percentile(data, percentile) == expected
444  
445  
def test_parse_trace_id_v4():
    """V4 ids split into (location, trace id); a plain id parses with no location."""
    raw_id = "tr-original-trace-123"

    for expected_location in ("catalog.schema", "experiment_id"):
        v4_id = f"{TRACE_ID_V4_PREFIX}{expected_location}/{raw_id}"
        parsed_location, parsed_id = parse_trace_id_v4(v4_id)
        assert parsed_location == expected_location
        assert parsed_id == raw_id

    # A bare (non-v4) trace id carries no location component.
    parsed_location, parsed_id = parse_trace_id_v4(raw_id)
    assert parsed_location is None
    assert parsed_id == raw_id
462  
463  
def test_parse_trace_id_v4_invalid_format():
    """Malformed v4 trace ids are rejected with an ``Invalid trace ID format`` error."""
    malformed_suffixes = [
        "123",  # no "/" between location and trace id
        "123/",  # empty trace id after the separator
        "catalog.schema/../invalid-trace-id",  # contains a ".." path segment
        "catalog.schema/invalid-trace-id/invalid-format",  # extra trailing segment
    ]
    for suffix in malformed_suffixes:
        with pytest.raises(MlflowException, match="Invalid trace ID format"):
            parse_trace_id_v4(f"{TRACE_ID_V4_PREFIX}{suffix}")
476  
477  
def test_get_otel_attribute_existing_attribute():
    """JSON-encoded attribute values are decoded back into their original Python values."""
    decoded_values = {
        "test_key": {"data": "value"},
        "string_key": "simple_string",
        "number_key": 42,
        "boolean_key": True,
        "list_key": [1, 2, 3],
    }
    span = Mock(spec=trace_api.Span)
    span.attributes = {key: json.dumps(value) for key, value in decoded_values.items()}

    for key, expected in decoded_values.items():
        actual = get_otel_attribute(span, key)
        if isinstance(expected, bool):
            # Booleans must round-trip as the bool singleton, not merely compare equal.
            assert actual is expected
        else:
            assert actual == expected
504  
505  
def test_get_otel_attribute_missing_attribute():
    """Looking up a key that is absent from the span attributes returns ``None``."""
    span = Mock(spec=trace_api.Span)
    span.attributes = {}

    assert get_otel_attribute(span, "nonexistent_key") is None
513  
514  
def test_get_otel_attribute_none_attribute():
    """An attributes mapping whose ``get`` yields ``None`` results in ``None``."""
    attributes = Mock()
    attributes.get.return_value = None
    span = Mock(spec=trace_api.Span)
    span.attributes = attributes

    assert get_otel_attribute(span, "any_key") is None
523  
524  
def test_get_otel_attribute_invalid_json():
    """Attribute strings that are not valid JSON, including ``""``, decode to ``None``."""
    span = Mock(spec=trace_api.Span)
    span.attributes = {"invalid_json": "not valid json {", "empty_string": ""}

    for key in ("invalid_json", "empty_string"):
        assert get_otel_attribute(span, key) is None
538  
539  
def test_get_otel_attribute_non_string_attribute():
    """Raw non-string attribute values (not JSON strings) fail gracefully and yield ``None``."""
    span = Mock(spec=trace_api.Span)
    # In some edge cases, attributes might contain non-string values.
    span.attributes = {"number_value": 123, "boolean_value": True}

    for key in ("number_value", "boolean_value"):
        assert get_otel_attribute(span, key) is None
554  
555  
def test_generate_trace_id_v4_with_uc_schema():
    """Building a v4 id delegates to ``construct_trace_id_v4`` with the UC schema location."""
    otel_span = create_mock_otel_span(trace_id=12345, span_id=1)
    location = "catalog.schema"
    stubbed_id = "trace:/catalog.schema/abc123"

    with mock.patch(
        "mlflow.tracing.utils.construct_trace_id_v4", return_value=stubbed_id
    ) as construct_mock:
        trace_id = generate_trace_id_v4(otel_span, location)

        construct_mock.assert_called_once_with(location, mock.ANY)
        assert trace_id == stubbed_id
567  
568  
def test_get_spans_table_name_for_trace_with_destination():
    """With a UC schema destination active, the spans table name is qualified by that schema."""
    destination = UCSchemaLocation(catalog_name="catalog", schema_name="schema")

    with mock.patch("mlflow.tracing.provider._MLFLOW_TRACE_USER_DESTINATION") as destination_var:
        destination_var.get.return_value = destination

        expected_table = "catalog.schema.mlflow_experiment_trace_otel_spans"
        assert get_active_spans_table_name() == expected_table
577  
578  
def test_get_spans_table_name_for_trace_no_destination():
    """Without an active user destination there is no spans table name."""
    with mock.patch("mlflow.tracing.provider._MLFLOW_TRACE_USER_DESTINATION") as destination_var:
        destination_var.get.return_value = None

        assert get_active_spans_table_name() is None
585  
586  
@pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="mock_litellm_cost cannot affect server-side cost")
@pytest.mark.parametrize("is_databricks", [True, False])
def test_cost_not_computed_client_side(is_databricks, mock_litellm_cost):
    """The span cost setter runs only when client-side costing is enabled; the trace always has cost.

    ``should_compute_cost_client_side`` is patched to ``is_databricks``. When it is ``True``,
    the span's cost setter must be invoked. Otherwise the setter must be skipped, and the cost
    is expected to be filled in on the server side.
    """
    # Mock should_compute_cost_client_side in the span module (where it's bound at import time)
    # rather than is_databricks_uri. Mocking is_databricks_uri captures the reference in
    # mlflow_v3.py during lazy import and causes _export_spans_incrementally to skip spans.
    client_side_in_span = mock.patch(
        "mlflow.entities.span.should_compute_cost_client_side", return_value=is_databricks
    )
    client_side_in_processor = mock.patch(
        "mlflow.tracing.processor.base_mlflow.should_compute_cost_client_side",
        return_value=is_databricks,
    )
    cost_setter = mock.patch(
        "mlflow.entities.span.set_span_cost_attribute", wraps=lambda span: None
    )

    with client_side_in_span, client_side_in_processor, cost_setter as mock_set_cost:
        with mlflow.start_span(name="llm_span") as span:
            span.set_attribute(SpanAttributeKey.MODEL, "gpt-5")
            usage = {
                TokenUsageKey.INPUT_TOKENS: 100,
                TokenUsageKey.OUTPUT_TOKENS: 50,
                TokenUsageKey.TOTAL_TOKENS: 150,
            }
            span.set_attribute(SpanAttributeKey.CHAT_USAGE, usage)

        # Cost should be computed at server side if not in Databricks
        if is_databricks:
            mock_set_cost.assert_called_once()
        else:
            mock_set_cost.assert_not_called()

    trace = mlflow.get_trace(trace_id=span.trace_id, flush=True)
    # Whichever side computed it, the trace must end up with a full cost breakdown.
    trace_cost = trace.info.cost
    assert trace_cost is not None
    for cost_key in (CostKey.INPUT_COST, CostKey.OUTPUT_COST, CostKey.TOTAL_COST):
        assert cost_key in trace_cost
627  
628  
def test_generate_trace_id_v4_from_otel_trace_id():
    """An OTel trace id renders as ``trace:/<location>/<hex_trace_id>`` and parses back."""
    otel_trace_id = 0x12345678901234567890123456789012
    location = "catalog.schema"
    hex_trace_id = encode_trace_id(otel_trace_id)
    expected_prefix = f"{TRACE_ID_V4_PREFIX}{location}/"

    v4_trace_id = generate_trace_id_v4_from_otel_trace_id(otel_trace_id, location)

    assert v4_trace_id.startswith(expected_prefix)
    assert v4_trace_id == expected_prefix + hex_trace_id

    # Round trip: the generated id splits back into its components.
    parsed_location, parsed_id = parse_trace_id_v4(v4_trace_id)
    assert parsed_location == location
    assert parsed_id == hex_trace_id
646  
647  
def test_builtin_cost_fallback_when_litellm_unavailable():
    """Without litellm, the cost of a known model comes from the built-in pricing fallback."""
    # A ``None`` entry in sys.modules makes ``import litellm`` raise ImportError.
    with mock.patch.dict("sys.modules", {"litellm": None}):
        cost = calculate_cost_by_model_and_token_usage(
            "gpt-4o", {"input_tokens": 1000, "output_tokens": 500}
        )

    assert cost is not None
    expected = {"input_cost": 0.0025, "output_cost": 0.005, "total_cost": 0.0075}
    for cost_key, amount in expected.items():
        assert cost[cost_key] == pytest.approx(amount)
657  
658  
def test_builtin_cost_fallback_returns_none_for_unknown_model():
    """The built-in fallback has no price for an unknown model, so no cost is produced."""
    usage = {"input_tokens": 100, "output_tokens": 50}
    # A ``None`` entry in sys.modules makes ``import litellm`` raise ImportError.
    with mock.patch.dict("sys.modules", {"litellm": None}):
        cost = calculate_cost_by_model_and_token_usage("unknown-model", usage)

    assert cost is None
665  
666  
def test_builtin_cost_fallback_with_cache_tokens():
    """Cache-read input tokens lower the built-in input cost below the uncached price."""
    usage = {
        "input_tokens": 1000,
        "output_tokens": 500,
        "cache_read_input_tokens": 200,
    }
    # A ``None`` entry in sys.modules makes ``import litellm`` raise ImportError.
    with mock.patch.dict("sys.modules", {"litellm": None}):
        cost = calculate_cost_by_model_and_token_usage("gpt-4o", usage)

    assert cost is not None
    assert cost["input_cost"] == pytest.approx(0.00225)
679  
680  
def test_builtin_cost_fallback_with_provider():
    """Passing an explicit provider still resolves the built-in price for the model."""
    usage = {"input_tokens": 1000, "output_tokens": 500}
    # A ``None`` entry in sys.modules makes ``import litellm`` raise ImportError.
    with mock.patch.dict("sys.modules", {"litellm": None}):
        cost = calculate_cost_by_model_and_token_usage("gpt-4o", usage, model_provider="openai")

    assert cost is not None
    assert cost["total_cost"] == pytest.approx(0.0075)
690  
691  
@pytest.mark.parametrize("model_provider", ["OpenAI", "OPENAI", "openai"])
def test_builtin_cost_fallback_with_provider_case_insensitive(model_provider):
    """Provider names are matched case-insensitively by the built-in pricing fallback."""
    usage = {"input_tokens": 1000, "output_tokens": 500}
    # A ``None`` entry in sys.modules makes ``import litellm`` raise ImportError.
    with mock.patch.dict("sys.modules", {"litellm": None}):
        cost = calculate_cost_by_model_and_token_usage(
            "gpt-4o", usage, model_provider=model_provider
        )

    assert cost is not None
    assert cost["total_cost"] == pytest.approx(0.0075)
702  
703  
@pytest.mark.parametrize("model_name", ["gateway:/my-endpoint", "endpoints:/my-endpoint"])
def test_cost_skipped_for_internal_routing_uris(model_name):
    """Gateway/endpoint routing URIs are not priced, so no cost is returned."""
    usage = {"input_tokens": 1000, "output_tokens": 500}

    assert calculate_cost_by_model_and_token_usage(model_name, usage) is None
710  
711  
def test_litellm_provider_list_not_printed_during_cost_calculation(capsys, monkeypatch):
    """Cost calculation must not let litellm print its "Provider List" message to stdout.

    ``suppress_debug_info`` starts as ``False``, so litellm is allowed to print debug output
    unless the helper suppresses it. The flag is set through ``monkeypatch`` so the global
    litellm setting is restored after this test. A plain assignment would leak it into
    every test that runs afterwards.
    """
    monkeypatch.setattr(litellm, "suppress_debug_info", False)

    calculate_cost_by_model_and_token_usage(
        model_name="databricks-claude-sonnet-4-5",
        usage={TokenUsageKey.INPUT_TOKENS: 10, TokenUsageKey.OUTPUT_TOKENS: 5},
    )

    captured = capsys.readouterr()
    assert "Provider List" not in captured.out
    # The helper must hand back the caller's original setting once it returns.
    assert litellm.suppress_debug_info is False
723  
724  
def test_litellm_provider_list_printed_when_debug_logging(capsys, monkeypatch):
    """With DEBUG logging on ``mlflow.tracing.utils``, litellm's debug output is shown.

    ``suppress_debug_info`` starts as ``True`` and is set through ``monkeypatch`` so the global
    litellm flag is restored after this test instead of leaking into later tests.
    """
    monkeypatch.setattr(litellm, "suppress_debug_info", True)

    utils_logger = logging.getLogger("mlflow.tracing.utils")
    original_level = utils_logger.level
    utils_logger.setLevel(logging.DEBUG)
    try:
        calculate_cost_by_model_and_token_usage(
            model_name="databricks-claude-sonnet-4-5",
            usage={TokenUsageKey.INPUT_TOKENS: 10, TokenUsageKey.OUTPUT_TOKENS: 5},
        )
    finally:
        # Restore the logger even if the call fails, so other tests see the usual level.
        utils_logger.setLevel(original_level)

    captured = capsys.readouterr()
    assert "Provider List" in captured.out
    # During the call to calculate cost, suppress was set to False.
    # It must be reset to the caller's original value afterwards.
    assert litellm.suppress_debug_info is True
745  
def test_dump_span_attribute_value_handles_circular_reference():
    """A self-referencing value is serialized through a fallback instead of raising."""
    cyclic = {"name": "run_context"}
    cyclic["self"] = cyclic

    # Sanity check: plain json.dumps cannot serialize this object.
    with pytest.raises(ValueError, match="Circular reference detected"):
        json.dumps(cyclic)

    # Must not raise; fall back result is a valid JSON string containing repr(value).
    decoded = json.loads(dump_span_attribute_value(cyclic))
    assert isinstance(decoded, str)
    assert "run_context" in decoded