"""Tests for MLflow tracing utilities (span aggregation, trace IDs, cost calculation)."""

import json
import logging
from unittest import mock
from unittest.mock import Mock, patch

import litellm
import pytest
from opentelemetry import trace as trace_api
from pydantic import ValidationError

import mlflow
# NOTE: SpanType was previously imported twice (from mlflow.entities and
# mlflow.entities.span); a single import is kept here.
from mlflow.entities import (
    LiveSpan,
    SpanType,
)
from mlflow.entities.trace_location import UCSchemaLocation
from mlflow.exceptions import MlflowException
from mlflow.tracing import set_span_chat_tools
from mlflow.tracing.constant import (
    TRACE_ID_V4_PREFIX,
    CostKey,
    SpanAttributeKey,
    TokenUsageKey,
)
from mlflow.tracing.utils import (
    _calculate_percentile,
    aggregate_cost_from_spans,
    aggregate_usage_from_spans,
    calculate_cost_by_model_and_token_usage,
    capture_function_input_args,
    construct_full_inputs,
    dump_span_attribute_value,
    encode_span_id,
    encode_trace_id,
    generate_trace_id_v4,
    generate_trace_id_v4_from_otel_trace_id,
    get_active_spans_table_name,
    get_otel_attribute,
    maybe_get_request_id,
    parse_trace_id_v4,
)
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import create_mock_otel_span


def test_capture_function_input_args_does_not_raise():
    """Input-arg capture must swallow inspection errors and return None."""
    # Exception during inspecting inputs: trace should be logged without inputs field
    with patch("inspect.signature", side_effect=ValueError("Some error")) as mock_input_args:
        args = capture_function_input_args(lambda: None, (), {})

    assert args is None
    assert mock_input_args.call_count > 0


def test_duplicate_span_names():
    """Spans with duplicate names keep their creation order and distinct IDs."""
    span_names = ["red", "red", "blue", "red", "green", "blue"]

    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=span_name), trace_id="tr-123")
        for i, span_name in enumerate(span_names)
    ]

    assert [span.name for span in spans] == span_names
    # Check if the span order is preserved
    assert [span.span_id for span in spans] == [encode_span_id(i) for i in [0, 1, 2, 3, 4, 5]]


def test_aggregate_usage_from_spans():
    """Token usage is summed key-by-key across independent spans."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=f"span_{i}"), trace_id="tr-123")
        for i in range(3)
    ]
    spans[0].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 20,
            TokenUsageKey.TOTAL_TOKENS: 30,
        },
    )
    spans[1].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {TokenUsageKey.OUTPUT_TOKENS: 15, TokenUsageKey.TOTAL_TOKENS: 15},
    )
    spans[2].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 10,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )

    usage = aggregate_usage_from_spans(spans)
    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 15,
        TokenUsageKey.OUTPUT_TOKENS: 45,
        TokenUsageKey.TOTAL_TOKENS: 60,
    }


def test_aggregate_usage_from_spans_skips_descendant_usage():
    """Usage on a descendant span is ignored when an ancestor already reports usage."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=1, name="root"), trace_id="tr-123"),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=2, name="child", parent_id=1),
            trace_id="tr-123",
        ),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=3, name="grandchild", parent_id=2),
            trace_id="tr-123",
        ),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=4, name="independent"), trace_id="tr-123"
        ),
    ]

    spans[0].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 20,
            TokenUsageKey.TOTAL_TOKENS: 30,
        },
    )

    spans[2].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 10,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )

    spans[3].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 3,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 9,
        },
    )

    usage = aggregate_usage_from_spans(spans)

    # Only root (10/20/30) + independent (3/6/9) count; grandchild is skipped.
    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 13,
        TokenUsageKey.OUTPUT_TOKENS: 26,
        TokenUsageKey.TOTAL_TOKENS: 39,
    }


def test_aggregate_usage_from_spans_with_cached_tokens():
    """Cache read/creation token counts are aggregated alongside regular usage."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=f"span_{i}"), trace_id="tr-123")
        for i in range(3)
    ]
    spans[0].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 100,
            TokenUsageKey.OUTPUT_TOKENS: 50,
            TokenUsageKey.TOTAL_TOKENS: 150,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 80,
        },
    )
    spans[1].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 200,
            TokenUsageKey.OUTPUT_TOKENS: 100,
            TokenUsageKey.TOTAL_TOKENS: 300,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 120,
            TokenUsageKey.CACHE_CREATION_INPUT_TOKENS: 50,
        },
    )
    # span without cached tokens
    spans[2].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 5,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )

    usage = aggregate_usage_from_spans(spans)
    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 310,
        TokenUsageKey.OUTPUT_TOKENS: 155,
        TokenUsageKey.TOTAL_TOKENS: 465,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 200,
        TokenUsageKey.CACHE_CREATION_INPUT_TOKENS: 50,
    }


def test_aggregate_usage_from_spans_without_cached_tokens_omits_keys():
    """Cache-token keys are absent from the aggregate when no span reports them."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=0, name="span_0"), trace_id="tr-123")
    ]
    spans[0].set_attribute(
        SpanAttributeKey.CHAT_USAGE,
        {
            TokenUsageKey.INPUT_TOKENS: 10,
            TokenUsageKey.OUTPUT_TOKENS: 5,
            TokenUsageKey.TOTAL_TOKENS: 15,
        },
    )

    usage = aggregate_usage_from_spans(spans)
    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 10,
        TokenUsageKey.OUTPUT_TOKENS: 5,
        TokenUsageKey.TOTAL_TOKENS: 15,
    }
    # Cached keys should not be present
    assert TokenUsageKey.CACHE_READ_INPUT_TOKENS not in usage
    assert TokenUsageKey.CACHE_CREATION_INPUT_TOKENS not in usage


def test_aggregate_cost_from_spans():
    """LLM cost attributes are summed key-by-key across independent spans."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=i, name=f"span_{i}"), trace_id="tr-123")
        for i in range(3)
    ]
    spans[0].set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 10,
            CostKey.OUTPUT_COST: 20,
            CostKey.TOTAL_COST: 30,
        },
    )
    spans[1].set_attribute(
        SpanAttributeKey.LLM_COST,
        {CostKey.OUTPUT_COST: 15, CostKey.TOTAL_COST: 15},
    )
    spans[2].set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 5,
            CostKey.OUTPUT_COST: 10,
            CostKey.TOTAL_COST: 15,
        },
    )

    cost = aggregate_cost_from_spans(spans)
    assert cost == {
        CostKey.INPUT_COST: 15,
        CostKey.OUTPUT_COST: 45,
        CostKey.TOTAL_COST: 60,
    }


def test_aggregate_cost_from_spans_skips_descendant_cost():
    """Cost on a descendant span is ignored when an ancestor already reports cost."""
    spans = [
        LiveSpan(create_mock_otel_span("trace_id", span_id=1, name="root"), trace_id="tr-123"),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=2, name="child", parent_id=1),
            trace_id="tr-123",
        ),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=3, name="grandchild", parent_id=2),
            trace_id="tr-123",
        ),
        LiveSpan(
            create_mock_otel_span("trace_id", span_id=4, name="independent"), trace_id="tr-123"
        ),
    ]

    spans[0].set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 10,
            CostKey.OUTPUT_COST: 20,
            CostKey.TOTAL_COST: 30,
        },
    )

    spans[2].set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 5,
            CostKey.OUTPUT_COST: 10,
            CostKey.TOTAL_COST: 15,
        },
    )

    spans[3].set_attribute(
        SpanAttributeKey.LLM_COST,
        {
            CostKey.INPUT_COST: 3,
            CostKey.OUTPUT_COST: 6,
            CostKey.TOTAL_COST: 9,
        },
    )

    cost = aggregate_cost_from_spans(spans)

    # Only root (10/20/30) + independent (3/6/9) count; grandchild is skipped.
    assert cost == {
        CostKey.INPUT_COST: 13,
        CostKey.OUTPUT_COST: 26,
        CostKey.TOTAL_COST: 39,
    }


def test_maybe_get_request_id():
    """Request ID is returned only inside an evaluate prediction context."""
    assert maybe_get_request_id(is_evaluate=True) is None

    try:
        from mlflow.pyfunc.context import Context, set_prediction_context
    except ImportError:
        pytest.skip("Skipping the rest of tests as mlflow.pyfunc module is not available.")

    with set_prediction_context(Context(request_id="eval", is_evaluate=True)):
        assert maybe_get_request_id(is_evaluate=True) == "eval"

    with set_prediction_context(Context(request_id="non_eval", is_evaluate=False)):
        assert maybe_get_request_id(is_evaluate=True) is None


def test_set_chat_tools_validation():
    """set_span_chat_tools rejects tools with an unsupported type."""
    tools = [
        {
            "type": "unsupported_function",
            "unsupported_function": {
                "name": "test",
            },
        }
    ]

    @mlflow.trace(span_type=SpanType.CHAT_MODEL)
    def dummy_call(tools):
        span = mlflow.get_current_active_span()
        set_span_chat_tools(span, tools)
        return None

    with pytest.raises(ValidationError, match="validation error for ChatTool"):
        dummy_call(tools)


@pytest.mark.parametrize(
    ("enum_values", "param_type"),
    [
        ([1, 2, 3, 4, 5], "integer"),
        (["option1", "option2", "option3"], "string"),
        ([1.1, 2.5, 3.7], "number"),
        ([True, False], "boolean"),
        (["mixed", 42, True, 3.14], "string"),  # Mixed types with string base type
    ],
)
def test_openai_parse_tools_enum_validation(enum_values, param_type):
    """_parse_tools accepts non-string enum values in tool parameter schemas."""
    from mlflow.openai.utils.chat_schema import _parse_tools

    # Simulate the exact OpenAI autologging input that was failing
    openai_inputs = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "select_option",
                    "description": "Select an option from the given choices",
                    "parameters": {
                        "type": "object",
                        "properties": {"option": {"type": param_type, "enum": enum_values}},
                        "required": ["option"],
                    },
                },
            }
        ]
    }

    # This should not raise a ValidationError - tests the actual failing code path
    parsed_tools = _parse_tools(openai_inputs)
    assert len(parsed_tools) == 1
    assert parsed_tools[0].function.name == "select_option"
    assert parsed_tools[0].function.parameters.properties["option"].enum == enum_values


def test_construct_full_inputs_simple_function():
    """construct_full_inputs maps positional/keyword args to parameter names."""

    def func(a, b, c=3, d=4, **kwargs):
        pass

    result = construct_full_inputs(func, 1, 2)
    assert result == {"a": 1, "b": 2}

    result = construct_full_inputs(func, 1, 2, c=30)
    assert result == {"a": 1, "b": 2, "c": 30}

    result = construct_full_inputs(func, 1, 2, c=30, d=40, e=50)
    assert result == {"a": 1, "b": 2, "c": 30, "d": 40, "kwargs": {"e": 50}}

    def no_args_func():
        pass

    result = construct_full_inputs(no_args_func)
    assert result == {}

    class TestClass:
        def func(self, a, b, c=3, d=4, **kwargs):
            pass

    result = construct_full_inputs(TestClass().func, 1, 2)
    assert result == {"a": 1, "b": 2}


def test_calculate_percentile():
    """_calculate_percentile interpolates linearly between sorted data points."""
    # Test empty list
    assert _calculate_percentile([], 0.5) == 0.0

    # Test single element
    assert _calculate_percentile([100], 0.25) == 100
    assert _calculate_percentile([100], 0.5) == 100
    assert _calculate_percentile([100], 0.75) == 100

    # Test two elements
    assert _calculate_percentile([10, 20], 0.0) == 10
    assert _calculate_percentile([10, 20], 0.5) == 15  # Linear interpolation
    assert _calculate_percentile([10, 20], 1.0) == 20

    # Test odd number of elements
    data = [10, 20, 30, 40, 50]
    assert _calculate_percentile(data, 0.0) == 10
    assert _calculate_percentile(data, 0.25) == 20
    assert _calculate_percentile(data, 0.5) == 30  # Median
    assert _calculate_percentile(data, 0.75) == 40
    assert _calculate_percentile(data, 1.0) == 50

    # Test even number of elements
    data = [10, 20, 30, 40]
    assert _calculate_percentile(data, 0.0) == 10
    assert _calculate_percentile(data, 0.25) == 17.5  # Between 10 and 20
    assert _calculate_percentile(data, 0.5) == 25  # Between 20 and 30
    assert _calculate_percentile(data, 0.75) == 32.5  # Between 30 and 40
    assert _calculate_percentile(data, 1.0) == 40

    # Test with larger dataset
    data = list(range(1, 101))  # 1 to 100
    assert _calculate_percentile(data, 0.25) == 25.75
    assert _calculate_percentile(data, 0.5) == 50.5


def test_parse_trace_id_v4():
    """V4 trace IDs split into (location, trace_id); V3 IDs yield a None location."""
    test_trace_id = "tr-original-trace-123"

    v4_id_uc_schema = f"{TRACE_ID_V4_PREFIX}catalog.schema/{test_trace_id}"
    location, parsed_id = parse_trace_id_v4(v4_id_uc_schema)
    assert location == "catalog.schema"
    assert parsed_id == test_trace_id

    v4_id_experiment = f"{TRACE_ID_V4_PREFIX}experiment_id/{test_trace_id}"
    location, parsed_id = parse_trace_id_v4(v4_id_experiment)
    assert location == "experiment_id"
    assert parsed_id == test_trace_id

    location, parsed_id = parse_trace_id_v4(test_trace_id)
    assert location is None
    assert parsed_id == test_trace_id


def test_parse_trace_id_v4_invalid_format():
    """Malformed v4 trace IDs raise MlflowException."""
    with pytest.raises(MlflowException, match="Invalid trace ID format"):
        parse_trace_id_v4(f"{TRACE_ID_V4_PREFIX}123")

    with pytest.raises(MlflowException, match="Invalid trace ID format"):
        parse_trace_id_v4(f"{TRACE_ID_V4_PREFIX}123/")

    with pytest.raises(MlflowException, match="Invalid trace ID format"):
        parse_trace_id_v4(f"{TRACE_ID_V4_PREFIX}catalog.schema/../invalid-trace-id")

    with pytest.raises(MlflowException, match="Invalid trace ID format"):
        parse_trace_id_v4(f"{TRACE_ID_V4_PREFIX}catalog.schema/invalid-trace-id/invalid-format")


def test_get_otel_attribute_existing_attribute():
    """get_otel_attribute JSON-decodes stored attribute values of any type."""
    # Create a mock span with attributes
    span = Mock(spec=trace_api.Span)
    span.attributes = {
        "test_key": json.dumps({"data": "value"}),
        "string_key": json.dumps("simple_string"),
        "number_key": json.dumps(42),
        "boolean_key": json.dumps(True),
        "list_key": json.dumps([1, 2, 3]),
    }

    # Test various data types
    result = get_otel_attribute(span, "test_key")
    assert result == {"data": "value"}

    result = get_otel_attribute(span, "string_key")
    assert result == "simple_string"

    result = get_otel_attribute(span, "number_key")
    assert result == 42

    result = get_otel_attribute(span, "boolean_key")
    assert result is True

    result = get_otel_attribute(span, "list_key")
    assert result == [1, 2, 3]


def test_get_otel_attribute_missing_attribute():
    """Missing attribute keys yield None."""
    # Create a mock span with empty attributes
    span = Mock(spec=trace_api.Span)
    span.attributes = {}

    result = get_otel_attribute(span, "nonexistent_key")
    assert result is None


def test_get_otel_attribute_none_attribute():
    """A None attribute value yields None."""
    # Create a mock span where attributes.get() returns None
    span = Mock(spec=trace_api.Span)
    span.attributes = Mock()
    span.attributes.get.return_value = None

    result = get_otel_attribute(span, "any_key")
    assert result is None


def test_get_otel_attribute_invalid_json():
    """Invalid JSON attribute values fail gracefully and yield None."""
    # Create a mock span with invalid JSON
    span = Mock(spec=trace_api.Span)
    span.attributes = {
        "invalid_json": "not valid json {",
        "empty_string": "",
    }

    result = get_otel_attribute(span, "invalid_json")
    assert result is None

    result = get_otel_attribute(span, "empty_string")
    assert result is None


def test_get_otel_attribute_non_string_attribute():
    """Non-string attribute values fail gracefully and yield None."""
    # In some edge cases, attributes might contain non-string values
    span = Mock(spec=trace_api.Span)
    span.attributes = {
        "number_value": 123,  # Not a JSON string
        "boolean_value": True,  # Not a JSON string
    }

    # These should fail gracefully and return None
    result = get_otel_attribute(span, "number_value")
    assert result is None

    result = get_otel_attribute(span, "boolean_value")
    assert result is None


def test_generate_trace_id_v4_with_uc_schema():
    """generate_trace_id_v4 delegates to construct_trace_id_v4 with the UC schema."""
    span = create_mock_otel_span(trace_id=12345, span_id=1)
    uc_schema = "catalog.schema"

    with mock.patch(
        "mlflow.tracing.utils.construct_trace_id_v4", return_value="trace:/catalog.schema/abc123"
    ) as mock_construct:
        result = generate_trace_id_v4(span, uc_schema)

    mock_construct.assert_called_once_with(uc_schema, mock.ANY)
    assert result == "trace:/catalog.schema/abc123"


def test_get_spans_table_name_for_trace_with_destination():
    """Active spans table name is derived from the configured UC schema destination."""
    mock_destination = UCSchemaLocation(catalog_name="catalog", schema_name="schema")

    with mock.patch("mlflow.tracing.provider._MLFLOW_TRACE_USER_DESTINATION") as mock_ctx:
        mock_ctx.get.return_value = mock_destination

        result = get_active_spans_table_name()
        assert result == "catalog.schema.mlflow_experiment_trace_otel_spans"


def test_get_spans_table_name_for_trace_no_destination():
    """No configured destination yields None for the active spans table name."""
    with mock.patch("mlflow.tracing.provider._MLFLOW_TRACE_USER_DESTINATION") as mock_ctx:
        mock_ctx.get.return_value = None

        result = get_active_spans_table_name()
        assert result is None


@pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason="mock_litellm_cost cannot affect server-side cost")
@pytest.mark.parametrize("is_databricks", [True, False])
def test_cost_not_computed_client_side(is_databricks, mock_litellm_cost):
    """Span cost is computed client-side only when targeting Databricks."""
    # Mock should_compute_cost_client_side in the span module (where it's bound at import time)
    # rather than is_databricks_uri. Mocking is_databricks_uri captures the reference in
    # mlflow_v3.py during lazy import and causes _export_spans_incrementally to skip spans.
    with (
        mock.patch(
            "mlflow.entities.span.should_compute_cost_client_side", return_value=is_databricks
        ),
        mock.patch(
            "mlflow.tracing.processor.base_mlflow.should_compute_cost_client_side",
            return_value=is_databricks,
        ),
        mock.patch(
            "mlflow.entities.span.set_span_cost_attribute", wraps=lambda span: None
        ) as mock_set_cost,
    ):
        with mlflow.start_span(name="llm_span") as span:
            span.set_attribute(SpanAttributeKey.MODEL, "gpt-5")
            span.set_attribute(
                SpanAttributeKey.CHAT_USAGE,
                {
                    TokenUsageKey.INPUT_TOKENS: 100,
                    TokenUsageKey.OUTPUT_TOKENS: 50,
                    TokenUsageKey.TOTAL_TOKENS: 150,
                },
            )
        # Cost should be computed at server side if not in Databricks
        if is_databricks:
            mock_set_cost.assert_called_once()
        else:
            mock_set_cost.assert_not_called()

    trace = mlflow.get_trace(trace_id=span.trace_id, flush=True)
    # cost should be set
    assert trace.info.cost is not None
    assert CostKey.INPUT_COST in trace.info.cost
    assert CostKey.OUTPUT_COST in trace.info.cost
    assert CostKey.TOTAL_COST in trace.info.cost


def test_generate_trace_id_v4_from_otel_trace_id():
    """A v4 trace ID built from an OTel trace ID round-trips through the parser."""
    otel_trace_id = 0x12345678901234567890123456789012
    location = "catalog.schema"

    result = generate_trace_id_v4_from_otel_trace_id(otel_trace_id, location)

    # Verify the format is trace:/<location>/<hex_trace_id>
    assert result.startswith(f"{TRACE_ID_V4_PREFIX}{location}/")

    # Extract and verify the hex trace ID part
    expected_hex_id = encode_trace_id(otel_trace_id)
    assert result == f"{TRACE_ID_V4_PREFIX}{location}/{expected_hex_id}"

    # Verify it can be parsed back
    parsed_location, parsed_id = parse_trace_id_v4(result)
    assert parsed_location == location
    assert parsed_id == expected_hex_id


def test_builtin_cost_fallback_when_litellm_unavailable():
    """Built-in pricing table is used for known models when litellm is absent."""
    with mock.patch.dict("sys.modules", {"litellm": None}):
        result = calculate_cost_by_model_and_token_usage(
            "gpt-4o", {"input_tokens": 1000, "output_tokens": 500}
        )
    assert result is not None
    assert result["input_cost"] == pytest.approx(0.0025)
    assert result["output_cost"] == pytest.approx(0.005)
    assert result["total_cost"] == pytest.approx(0.0075)


def test_builtin_cost_fallback_returns_none_for_unknown_model():
    """Unknown models yield None from the built-in cost fallback."""
    with mock.patch.dict("sys.modules", {"litellm": None}):
        result = calculate_cost_by_model_and_token_usage(
            "unknown-model", {"input_tokens": 100, "output_tokens": 50}
        )
    assert result is None


def test_builtin_cost_fallback_with_cache_tokens():
    """Cache-read tokens are discounted in the built-in cost fallback."""
    with mock.patch.dict("sys.modules", {"litellm": None}):
        result = calculate_cost_by_model_and_token_usage(
            "gpt-4o",
            {
                "input_tokens": 1000,
                "output_tokens": 500,
                "cache_read_input_tokens": 200,
            },
        )
    assert result is not None
    assert result["input_cost"] == pytest.approx(0.00225)


def test_builtin_cost_fallback_with_provider():
    """An explicit model provider is accepted by the built-in cost fallback."""
    with mock.patch.dict("sys.modules", {"litellm": None}):
        result = calculate_cost_by_model_and_token_usage(
            "gpt-4o",
            {"input_tokens": 1000, "output_tokens": 500},
            model_provider="openai",
        )
    assert result is not None
    assert result["total_cost"] == pytest.approx(0.0075)


@pytest.mark.parametrize("model_provider", ["OpenAI", "OPENAI", "openai"])
def test_builtin_cost_fallback_with_provider_case_insensitive(model_provider):
    """Provider matching in the built-in cost fallback is case-insensitive."""
    with mock.patch.dict("sys.modules", {"litellm": None}):
        result = calculate_cost_by_model_and_token_usage(
            "gpt-4o",
            {"input_tokens": 1000, "output_tokens": 500},
            model_provider=model_provider,
        )
    assert result is not None
    assert result["total_cost"] == pytest.approx(0.0075)


@pytest.mark.parametrize("model_name", ["gateway:/my-endpoint", "endpoints:/my-endpoint"])
def test_cost_skipped_for_internal_routing_uris(model_name):
    """Cost calculation is skipped for gateway/endpoint routing URIs."""
    result = calculate_cost_by_model_and_token_usage(
        model_name, {"input_tokens": 1000, "output_tokens": 500}
    )
    assert result is None


def test_litellm_provider_list_not_printed_during_cost_calculation(capsys):
    """litellm debug output is suppressed during cost calculation by default."""
    litellm.suppress_debug_info = False

    calculate_cost_by_model_and_token_usage(
        model_name="databricks-claude-sonnet-4-5",
        usage={TokenUsageKey.INPUT_TOKENS: 10, TokenUsageKey.OUTPUT_TOKENS: 5},
    )

    captured = capsys.readouterr()
    assert "Provider List" not in captured.out
    assert litellm.suppress_debug_info is False


def test_litellm_provider_list_printed_when_debug_logging(capsys):
    """litellm debug output is allowed when mlflow.tracing.utils logs at DEBUG."""
    litellm.suppress_debug_info = True

    _logger = logging.getLogger("mlflow.tracing.utils")
    original_level = _logger.level
    _logger.setLevel(logging.DEBUG)
    try:
        calculate_cost_by_model_and_token_usage(
            model_name="databricks-claude-sonnet-4-5",
            usage={TokenUsageKey.INPUT_TOKENS: 10, TokenUsageKey.OUTPUT_TOKENS: 5},
        )
    finally:
        _logger.setLevel(original_level)

    captured = capsys.readouterr()
    assert "Provider List" in captured.out
    # During the call to calculate cost, suppress was set to False
    # We are asserting that suppress is reset to the original value after
    assert litellm.suppress_debug_info is True


def test_dump_span_attribute_value_handles_circular_reference():
    """Attribute dumping falls back to repr() for values json.dumps rejects."""
    cyclic = {"name": "run_context"}
    cyclic["self"] = cyclic

    with pytest.raises(ValueError, match="Circular reference detected"):
        json.dumps(cyclic)

    # Must not raise; fall back result is a valid JSON string containing repr(value).
    result = dump_span_attribute_value(cyclic)
    loaded = json.loads(result)
    assert isinstance(loaded, str)
    assert "run_context" in loaded