test_openai_autolog.py
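"""Tests for MLflow OpenAI autologging.

Covers tracing of chat completions, completions, embeddings, and image generation
(sync and async), streaming response reconstruction, token usage and cost
aggregation, model linking, and W3C traceparent header injection.
"""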
import json
import re
import sys
from unittest import mock

import httpx
import openai
import pytest
from packaging.version import Version
from pydantic import BaseModel

import mlflow
from mlflow.entities.span import SpanType
from mlflow.exceptions import MlflowException
from mlflow.openai.utils.chat_schema import _parse_tools
from mlflow.tracing.constant import (
    STREAM_CHUNK_EVENT_VALUE_KEY,
    CostKey,
    SpanAttributeKey,
    TokenUsageKey,
    TraceMetadataKey,
)
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.openai.mock_openai import EMPTY_CHOICES, LIST_CONTENT
from tests.tracing.helper import get_traces, skip_when_testing_trace_sdk

MOCK_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        },
    }
]


@pytest.fixture(params=[True, False], ids=["sync", "async"])
def client(request, monkeypatch, mock_openai):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
    if request.param:
        client = openai.OpenAI(api_key="test", base_url=mock_openai)
        client._is_async = False
        return client
    else:
        client = openai.AsyncOpenAI(api_key="test", base_url=mock_openai)
        client._is_async = True
        return client


@pytest.fixture
def completion_models():
    return [
        mlflow.openai.log_model(
            "gpt-4o-mini",
            "completions",
            name="model",
            temperature=temp,
            prompt="Say {text}",
            pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
        )
        for temp in [0.1, 0.2, 0.3]
    ]


@pytest.fixture
def embedding_models():
    float_model = mlflow.openai.log_model(
        "text-embedding-ada-002",
        "embeddings",
        name="model",
        encoding_format="float",
        pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
    )
    base64_model = mlflow.openai.log_model(
        "text-embedding-ada-002",
        "embeddings",
        name="model",
        encoding_format="base64",
        pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
    )
    return [float_model, base64_model]


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog(client, mock_litellm_cost):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]
    response = client.chat.completions.create(
        messages=messages,
        model="gpt-4o-mini",
        temperature=0,
    )

    if client._is_async:
        await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == {"messages": messages, "model": "gpt-4o-mini", "temperature": 0}
    assert span.outputs["id"] == "chatcmpl-123"
    assert span.attributes["model"] == "gpt-4o-mini"
    assert span.attributes["temperature"] == 0

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 9,
        TokenUsageKey.OUTPUT_TOKENS: 12,
        TokenUsageKey.TOTAL_TOKENS: 21,
    }
    assert span.model_name == "gpt-4o-mini"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "openai"

    if not IS_TRACING_SDK_ONLY:
        # Verify cost is calculated (9 input tokens * 1.0 + 12 output tokens * 2.0)
        assert span.llm_cost == {
            "input_cost": 9.0,
            "output_cost": 24.0,
            "total_cost": 33.0,
        }

    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 9,
        TokenUsageKey.OUTPUT_TOKENS: 12,
        TokenUsageKey.TOTAL_TOKENS: 21,
    }
    if not IS_TRACING_SDK_ONLY:
        assert trace.info.cost == {
            CostKey.INPUT_COST: 9.0,
            CostKey.OUTPUT_COST: 24.0,
            CostKey.TOTAL_COST: 33.0,
        }


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog_with_cached_tokens(client, mock_litellm_cost):
    mlflow.openai.autolog()

    mock_response = {
        "id": "chatcmpl-cached",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello"},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 50,
            "completion_tokens": 20,
            "total_tokens": 70,
            "prompt_tokens_details": {"cached_tokens": 30, "audio_tokens": 0},
            "completion_tokens_details": {"reasoning_tokens": 0},
        },
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"

        async def send_patch(self, request, *args, **kwargs):
            return httpx.Response(status_code=200, request=request, json=mock_response)

    else:
        patch_target = "httpx.Client.send"

        def send_patch(self, request, *args, **kwargs):
            return httpx.Response(status_code=200, request=request, json=mock_response)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
            temperature=0,
        )
        if client._is_async:
            response = await response

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }

    assert traces[0].info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog_under_current_active_span(client):
    # If a user has an active span, autologging should create a child span under it.
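    # Three calls are made below, so the trace should contain the parent span plus three children.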
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]
    with mlflow.start_span(name="parent"):
        for _ in range(3):
            response = client.chat.completions.create(
                messages=messages,
                model="gpt-4o-mini",
                temperature=0,
            )

            if client._is_async:
                await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 4
    parent_span = trace.data.spans[0]
    assert parent_span.name == "parent"
    child_span = trace.data.spans[1]
    assert child_span.name == ("AsyncCompletions" if client._is_async else "Completions")
    assert child_span.inputs == {"messages": messages, "model": "gpt-4o-mini", "temperature": 0}
    assert child_span.outputs["id"] == "chatcmpl-123"
    assert child_span.parent_id == parent_span.span_id

    # Token usage should be aggregated correctly
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 27,
        TokenUsageKey.OUTPUT_TOKENS: 36,
        TokenUsageKey.TOTAL_TOKENS: 63,
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("include_usage", [True, False])
async def test_chat_completions_autolog_streaming(client, include_usage):
    mlflow.openai.autolog()

    stream_options_supported = Version(openai.__version__) >= Version("1.26")

    if not stream_options_supported and include_usage:
        pytest.skip("OpenAI SDK version does not support usage tracking in streaming")

    messages = [{"role": "user", "content": "test"}]

    input_params = {
        "messages": messages,
        "model": "gpt-4o-mini",
        "temperature": 0,
        "stream": True,
    }
    if stream_options_supported:
        input_params["stream_options"] = {"include_usage": include_usage}

    stream = client.chat.completions.create(**input_params)

    if client._is_async:
        async for _ in await stream:
            pass
    else:
        for _ in stream:
            pass

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == input_params

    # Reconstructed response from streaming chunks
    assert isinstance(span.outputs, dict)
    assert span.outputs["id"] == "chatcmpl-123"
    assert span.outputs["object"] == "chat.completion"
    assert span.outputs["model"] == "gpt-4o-mini"
    assert span.outputs["system_fingerprint"] == "fp_44709d6fcb"
    assert "choices" in span.outputs
    assert span.outputs["choices"][0]["message"]["role"] == "assistant"
    assert span.outputs["choices"][0]["message"]["content"] == "Hello world"

    # Usage should be preserved when include_usage=True
    if include_usage:
        assert "usage" in span.outputs
        assert span.outputs["usage"]["prompt_tokens"] == 9
        assert span.outputs["usage"]["completion_tokens"] == 12
        assert span.outputs["usage"]["total_tokens"] == 21

    stream_event_data = trace.data.spans[0].events
    assert stream_event_data[0].name == "mlflow.chunk.item.0"
    chunk_1 = json.loads(stream_event_data[0].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_1["id"] == "chatcmpl-123"
    assert chunk_1["choices"][0]["delta"]["content"] == "Hello"
    assert stream_event_data[1].name == "mlflow.chunk.item.1"
    chunk_2 = json.loads(stream_event_data[1].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
assert chunk_2["id"] == "chatcmpl-123" 329 assert chunk_2["choices"][0]["delta"]["content"] == " world" 330 331 if include_usage: 332 assert trace.info.token_usage == { 333 TokenUsageKey.INPUT_TOKENS: 9, 334 TokenUsageKey.OUTPUT_TOKENS: 12, 335 TokenUsageKey.TOTAL_TOKENS: 21, 336 } 337 338 339 @pytest.mark.asyncio 340 async def test_chat_completions_autolog_tracing_error(client): 341 mlflow.openai.autolog() 342 messages = [{"role": "user", "content": "test"}] 343 with pytest.raises(openai.UnprocessableEntityError, match="Input should be less"): # noqa: PT012 344 response = client.chat.completions.create( 345 messages=messages, 346 model="gpt-4o-mini", 347 temperature=5.0, 348 ) 349 350 if client._is_async: 351 await response 352 353 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 354 assert trace.info.status == "ERROR" 355 356 assert len(trace.data.spans) == 1 357 span = trace.data.spans[0] 358 assert span.name == "AsyncCompletions" if client._is_async else "Completions" 359 assert span.inputs["messages"][0]["content"] == "test" 360 assert span.outputs is None 361 362 assert span.events[0].name == "exception" 363 assert span.events[0].attributes["exception.type"] == "UnprocessableEntityError" 364 365 366 @pytest.mark.asyncio 367 async def test_chat_completions_autolog_tracing_error_with_parent_span(client): 368 mlflow.openai.autolog() 369 370 if client._is_async: 371 372 @mlflow.trace 373 async def create_completions(text: str) -> str: 374 try: 375 response = await client.chat.completions.create( 376 messages=[{"role": "user", "content": text}], 377 model="gpt-4o-mini", 378 temperature=5.0, 379 ) 380 return response.choices[0].delta.content 381 except openai.OpenAIError as e: 382 raise MlflowException("Failed to create completions") from e 383 384 with pytest.raises(MlflowException, match="Failed to create completions"): 385 await create_completions("test") 386 387 else: 388 389 @mlflow.trace 390 def create_completions(text: str) -> str: 391 try: 392 response = client.chat.completions.create( 393 messages=[{"role": "user", "content": text}], 394 model="gpt-4o-mini", 395 temperature=5.0, 396 ) 397 return response.choices[0].delta.content 398 except openai.OpenAIError as e: 399 raise MlflowException("Failed to create completions") from e 400 401 with pytest.raises(MlflowException, match="Failed to create completions"): 402 create_completions("test") 403 404 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 405 assert trace.info.status == "ERROR" 406 407 assert len(trace.data.spans) == 2 408 parent_span = trace.data.spans[0] 409 assert parent_span.name == "create_completions" 410 assert parent_span.inputs == {"text": "test"} 411 assert parent_span.outputs is None 412 assert parent_span.status.status_code == "ERROR" 413 assert parent_span.events[0].name == "exception" 414 assert parent_span.events[0].attributes["exception.type"] == "MlflowException" 415 assert parent_span.events[0].attributes["exception.message"] == "Failed to create completions" 416 417 child_span = trace.data.spans[1] 418 assert child_span.name == "AsyncCompletions" if client._is_async else "Completions" 419 assert child_span.inputs["messages"][0]["content"] == "test" 420 assert child_span.outputs is None 421 assert child_span.status.status_code == "ERROR" 422 assert child_span.events[0].name == "exception" 423 assert child_span.events[0].attributes["exception.type"] == "UnprocessableEntityError" 424 425 426 @pytest.mark.asyncio 427 async def test_chat_completions_streaming_empty_choices(client): 428 
    mlflow.openai.autolog()
    stream = client.chat.completions.create(
        messages=[{"role": "user", "content": EMPTY_CHOICES}],
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    # Ensure the stream has a chunk with empty choices
    assert chunks[0].choices == []

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "OK"


@pytest.mark.asyncio
async def test_chat_completions_streaming_with_list_content(client):
    # Test streaming with Databricks-style list content in chunks.
    mlflow.openai.autolog()
    stream = client.chat.completions.create(
        messages=[{"role": "user", "content": LIST_CONTENT}],
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    assert len(chunks) == 2
    assert chunks[0].choices[0].delta.content == [{"type": "text", "text": "Hello"}]
    assert chunks[1].choices[0].delta.content == [{"type": "text", "text": " world"}]

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL

    # Verify the reconstructed message content is correct (text extracted from list)
    assert isinstance(span.outputs, dict)
    assert span.outputs["choices"][0]["message"]["content"] == "Hello world"


@pytest.mark.asyncio
async def test_completions_autolog(client):
    mlflow.openai.autolog()

    response = client.completions.create(
        prompt="test",
        model="gpt-4o-mini",
        temperature=0,
    )

    if client._is_async:
        await response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.LLM
    assert span.inputs == {"prompt": "test", "model": "gpt-4o-mini", "temperature": 0}
    assert span.outputs["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert span.model_name == "gpt-4o-mini"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "openai"
    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata


@pytest.mark.asyncio
async def test_completions_autolog_streaming_empty_choices(client):
    mlflow.openai.autolog()
    stream = client.completions.create(
        prompt=EMPTY_CHOICES,
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    # Ensure the stream has a chunk with empty choices
    assert chunks[0].choices == []

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "OK"


@pytest.mark.asyncio
async def test_completions_autolog_streaming(client):
    mlflow.openai.autolog()

    stream = client.completions.create(
        prompt="test",
        model="gpt-4o-mini",
        temperature=0,
        stream=True,
    )
    if client._is_async:
        async for _ in await stream:
            pass
    else:
        for _ in stream:
            pass

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.LLM
    assert span.inputs == {
        "prompt": "test",
        "model": "gpt-4o-mini",
        "temperature": 0,
        "stream": True,
    }
    assert span.outputs == "Hello world"  # aggregated string of streaming response

    stream_event_data = trace.data.spans[0].events

    assert stream_event_data[0].name == "mlflow.chunk.item.0"
    chunk_1 = json.loads(stream_event_data[0].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_1["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert chunk_1["choices"][0]["text"] == "Hello"
    assert stream_event_data[1].name == "mlflow.chunk.item.1"
    chunk_2 = json.loads(stream_event_data[1].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_2["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert chunk_2["choices"][0]["text"] == " world"


@pytest.mark.asyncio
async def test_embeddings_autolog(client):
    mlflow.openai.autolog()

    response = client.embeddings.create(
        input="test",
        model="text-embedding-ada-002",
    )

    if client._is_async:
        await response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.EMBEDDING
    assert span.inputs == {"input": "test", "model": "text-embedding-ada-002"}
    assert span.outputs["data"][0]["embedding"] == list(range(1536))
    assert span.model_name == "text-embedding-ada-002"

    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_use_active_run_id(client):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]

    async def _call_create():
        response = client.chat.completions.create(messages=messages, model="gpt-4o-mini")
        if client._is_async:
            await response
        return response

    with mlflow.start_run() as run_1:
        await _call_create()

    with mlflow.start_run() as run_2:
        await _call_create()
        await _call_create()

    with mlflow.start_run() as run_3:
        mlflow.openai.autolog()
        await _call_create()

    traces = get_traces()[::-1]  # reverse order to sort by timestamp in ascending order
    assert len(traces) == 4

    assert traces[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_1.info.run_id
    assert traces[1].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_2.info.run_id
    assert traces[2].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_2.info.run_id
    assert traces[3].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_3.info.run_id


@pytest.mark.asyncio
async def test_autolog_raw_response(client):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]

    resp = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",
        messages=messages,
        tools=MOCK_TOOLS,
    )

    if client._is_async:
        resp = await resp

    resp = resp.parse()  # ensure the raw response is returned

    assert resp.choices[0].message.content == '[{"role": "user", "content": "test"}]'
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert isinstance(span.outputs, dict)
    assert (
span.outputs["choices"][0]["message"]["content"] == '[{"role": "user", "content": "test"}]' 642 ) 643 assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == MOCK_TOOLS 644 assert span.model_name == "gpt-4o-mini" 645 646 assert trace.info.token_usage == { 647 TokenUsageKey.INPUT_TOKENS: 9, 648 TokenUsageKey.OUTPUT_TOKENS: 12, 649 TokenUsageKey.TOTAL_TOKENS: 21, 650 } 651 652 653 @pytest.mark.asyncio 654 async def test_autolog_raw_response_stream(client): 655 mlflow.openai.autolog() 656 657 messages = [{"role": "user", "content": "test"}] 658 659 resp = client.chat.completions.with_raw_response.create( 660 model="gpt-4o-mini", 661 messages=messages, 662 tools=MOCK_TOOLS, 663 stream=True, 664 ) 665 666 if client._is_async: 667 resp = await resp 668 669 resp = resp.parse() # ensure the raw response is returned 670 671 if client._is_async: 672 chunks = [c.choices[0].delta.content async for c in resp] 673 else: 674 chunks = [c.choices[0].delta.content for c in resp] 675 assert chunks == ["Hello", " world"] 676 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 677 assert len(trace.data.spans) == 1 678 span = trace.data.spans[0] 679 assert span.span_type == SpanType.CHAT_MODEL 680 assert span.model_name == "gpt-4o-mini" 681 682 # Reconstructed response from streaming chunks 683 assert isinstance(span.outputs, dict) 684 assert span.outputs["id"] == "chatcmpl-123" 685 assert span.outputs["object"] == "chat.completion" 686 assert span.outputs["model"] == "gpt-4o-mini" 687 assert span.outputs["choices"][0]["message"]["content"] == "Hello world" 688 assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == MOCK_TOOLS 689 690 691 @pytest.mark.skipif( 692 Version(openai.__version__) < Version("1.40"), reason="Requires OpenAI SDK >= 1.40" 693 ) 694 @pytest.mark.asyncio 695 async def test_response_format(client): 696 mlflow.openai.autolog() 697 698 class Person(BaseModel): 699 name: str 700 age: int 701 702 mock_response = { 703 "id": "chatcmpl-Ax4UAd5xf32KjgLkS1SEEY9oorI9m", 704 "object": "chat.completion", 705 "created": 1738641958, 706 "model": "gpt-4o-2024-08-06", 707 "choices": [ 708 { 709 "index": 0, 710 "message": { 711 "role": "assistant", 712 "content": '{"name":"Angelo","age":42}', 713 "refusal": None, 714 }, 715 "logprobs": None, 716 "finish_reason": "stop", 717 } 718 ], 719 "usage": { 720 "prompt_tokens": 68, 721 "completion_tokens": 11, 722 "total_tokens": 79, 723 "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, 724 "completion_tokens_details": { 725 "reasoning_tokens": 0, 726 "audio_tokens": 0, 727 "accepted_prediction_tokens": 0, 728 "rejected_prediction_tokens": 0, 729 }, 730 }, 731 "service_tier": "default", 732 "system_fingerprint": "fp_50cad350e4", 733 } 734 735 if client._is_async: 736 patch_target = "httpx.AsyncClient.send" 737 738 async def send_patch(self, request, *args, **kwargs): 739 return httpx.Response( 740 status_code=200, 741 request=request, 742 json=mock_response, 743 ) 744 745 else: 746 patch_target = "httpx.Client.send" 747 748 def send_patch(self, request, *args, **kwargs): 749 return httpx.Response( 750 status_code=200, 751 request=request, 752 json=mock_response, 753 ) 754 755 with mock.patch(patch_target, send_patch): 756 response = client.beta.chat.completions.parse( 757 messages=[ 758 {"role": "system", "content": "Extract info from text"}, 759 {"role": "user", "content": "I am Angelo and I am 42."}, 760 ], 761 model="gpt-4o", 762 temperature=0, 763 response_format=Person, 764 ) 765 766 if client._is_async: 767 response = await response 768 
    assert response.choices[0].message.parsed == Person(name="Angelo", age=42)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.outputs["choices"][0]["message"]["content"] == '{"name":"Angelo","age":42}'
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.model_name == "gpt-4o"

    assert trace.info.trace_metadata.get(TraceMetadataKey.TOKEN_USAGE) == json.dumps({
        TokenUsageKey.INPUT_TOKENS: 68,
        TokenUsageKey.OUTPUT_TOKENS: 11,
        TokenUsageKey.TOTAL_TOKENS: 79,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    })


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_chat_completions(client, completion_models):
    mlflow.openai.autolog()

    for model_info in completion_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.chat.completions.create(
            messages=[{"role": "user", "content": f"test {model_info.model_id}"}],
            model=model_dict["model"],
            temperature=model_dict["temperature"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(completion_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["messages"][0]["content"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_completions(client, completion_models):
    mlflow.openai.autolog()

    for model_info in completion_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.completions.create(
            prompt=f"test {model_info.model_id}",
            model=model_dict["model"],
            temperature=model_dict["temperature"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(completion_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["prompt"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_embeddings(client, embedding_models):
    mlflow.openai.autolog()

    for model_info in embedding_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.embeddings.create(
            input=f"test {model_info.model_id}",
            model=model_dict["model"],
            encoding_format=model_dict["encoding_format"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["input"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_loaded_model_embeddings_pyfunc(
    monkeypatch, mock_openai, embedding_models
):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)

    mlflow.openai.autolog()

    for model_info in embedding_models:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        assert mlflow.get_active_model_id() == model_info.model_id
        pyfunc_model.predict(f"test {model_info.model_id}")

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["input"] == [f"test {model_id}"]
        assert span.model_name == "text-embedding-ada-002"


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_active_model(monkeypatch, mock_openai, embedding_models):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)

    model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(model_id=model.model_id)
    mlflow.openai.autolog()

    for model_info in embedding_models:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        pyfunc_model.predict(model_info.model_id)

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id
        logged_model_id = span.inputs["input"][0]
        assert logged_model_id != model.model_id
        assert span.model_name == "text-embedding-ada-002"


@pytest.mark.asyncio
async def test_images_generate_autolog(client):
    mlflow.openai.autolog()

    # Disable tracing header injection: safe_patch rejects the extra_headers
    # dict as a "new input" because it's not an ExceptionSafe-wrapped object.
    # This is a known limitation shared with other non-chat endpoints.
    openai_autolog_module = sys.modules["mlflow.openai.autolog"]
    with mock.patch.object(openai_autolog_module, "_inject_tracing_headers"):
        response = client.images.generate(
            model="dall-e-3",
            prompt="a white siamese cat",
            n=1,
            response_format="b64_json",
        )

        if client._is_async:
            await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.TOOL
    assert span.inputs["prompt"] == "a white siamese cat"
    assert span.outputs["data"][0]["revised_prompt"] == "a test image"


@pytest.mark.parametrize(
    "sentinel",
    [None, 42, object()],
)
def test_parse_tools_handles_openai_not_given_sentinel(sentinel):
    assert _parse_tools({"tools": sentinel}) == []


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_model_loading_set_active_model_id_without_fetching_logged_model(
    client, completion_models
):
    mlflow.openai.autolog()

    model_info = completion_models[0]
    with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")):
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.chat.completions.create(
            messages=[{"role": "user", "content": f"test {model_info.model_id}"}],
            model=model_dict["model"],
            temperature=model_dict["temperature"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    model_id = traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID]
    assert model_id is not None
    assert span.inputs["messages"][0]["content"] == f"test {model_id}"
    assert span.model_name == model_dict["model"]


@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Requires OpenAI SDK >= 1.66"
)
@skip_when_testing_trace_sdk
def test_reconstruct_response_from_stream():
    from openai.types.responses import (
        ResponseOutputItemDoneEvent,
        ResponseOutputMessage,
        ResponseOutputText,
    )

    from mlflow.openai.autolog import _reconstruct_response_from_stream
    from mlflow.types.responses_helpers import OutputItem

    content1 = ResponseOutputText(annotations=[], text="Hello", type="output_text")
    content2 = ResponseOutputText(annotations=[], text=" world", type="output_text")

    message1 = ResponseOutputMessage(
        id="test-1", content=[content1], role="assistant", status="completed", type="message"
    )

    message2 = ResponseOutputMessage(
        id="test-2", content=[content2], role="assistant", status="completed", type="message"
    )

    chunk1 = ResponseOutputItemDoneEvent(
        item=message1, output_index=0, sequence_number=1, type="response.output_item.done"
    )

    chunk2 = ResponseOutputItemDoneEvent(
        item=message2, output_index=1, sequence_number=2, type="response.output_item.done"
    )

    chunks = [chunk1, chunk2]

    result = _reconstruct_response_from_stream(chunks)

    assert result.output == [
        OutputItem(**chunk1.item.to_dict()),
        OutputItem(**chunk2.item.to_dict()),
    ]


@pytest.mark.asyncio
async def test_tracing_headers_injected(client):
    mlflow.openai.autolog()

    captured_request = {}
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "hi"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"
        original_send = httpx.AsyncClient.send

        async def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return await original_send(self, request, *args, **kwargs)

    else:
        patch_target = "httpx.Client.send"
        original_send = httpx.Client.send

        def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return original_send(self, request, *args, **kwargs)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
        )
        if client._is_async:
            response = await response

    # Verify traceparent header was injected
    assert "traceparent" in captured_request["headers"]
    traceparent = captured_request["headers"]["traceparent"]
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", traceparent)

    # Verify the traceparent points to the LLM span
    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    span_ctx = span._span.get_span_context()
    trace_id_hex = format(span_ctx.trace_id, "032x")
    span_id_hex = format(span_ctx.span_id, "016x")
    assert traceparent.startswith(f"00-{trace_id_hex}-{span_id_hex}-")


@pytest.mark.asyncio
async def test_tracing_headers_preserve_user_headers(client):
    mlflow.openai.autolog()

    captured_request = {}
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "hi"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"
        original_send = httpx.AsyncClient.send

        async def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return await original_send(self, request, *args, **kwargs)

    else:
        patch_target = "httpx.Client.send"
        original_send = httpx.Client.send

        def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return original_send(self, request, *args, **kwargs)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
            extra_headers={"X-Custom": "my-value"},
"my-value"}, 1123 ) 1124 if client._is_async: 1125 response = await response 1126 1127 # User-provided headers should be preserved alongside traceparent 1128 assert "traceparent" in captured_request["headers"] 1129 assert captured_request["headers"].get("x-custom") == "my-value" 1130 1131 1132 @pytest.mark.asyncio 1133 @pytest.mark.skipif( 1134 Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66" 1135 ) 1136 async def test_chat_completions_autolog_streaming_with_cached_tokens(client, mock_litellm_cost): 1137 mlflow.openai.autolog() 1138 1139 mock_chunk = { 1140 "id": "chatcmpl-stream-cached", 1141 "object": "chat.completion.chunk", 1142 "created": 1677652288, 1143 "model": "gpt-4o-mini", 1144 "choices": [], 1145 "usage": { 1146 "prompt_tokens": 50, 1147 "completion_tokens": 20, 1148 "total_tokens": 70, 1149 "prompt_tokens_details": {"cached_tokens": 30, "audio_tokens": 0}, 1150 "completion_tokens_details": {"reasoning_tokens": 0}, 1151 }, 1152 } 1153 1154 if client._is_async: 1155 patch_target = "httpx.AsyncClient.send" 1156 1157 async def send_patch(self, request, *args, **kwargs): 1158 content = f"data: {json.dumps(mock_chunk)}\n\ndata: [DONE]\n\n".encode() 1159 return httpx.Response(status_code=200, request=request, content=content) 1160 1161 else: 1162 patch_target = "httpx.Client.send" 1163 1164 def send_patch(self, request, *args, **kwargs): 1165 content = f"data: {json.dumps(mock_chunk)}\n\ndata: [DONE]\n\n".encode() 1166 return httpx.Response(status_code=200, request=request, content=content) 1167 1168 with mock.patch(patch_target, send_patch): 1169 stream = client.chat.completions.create( 1170 messages=[{"role": "user", "content": "test"}], 1171 model="gpt-4o-mini", 1172 stream=True, 1173 stream_options={"include_usage": True}, 1174 ) 1175 if client._is_async: 1176 async for _ in await stream: 1177 pass 1178 else: 1179 for _ in stream: 1180 pass 1181 1182 traces = get_traces() 1183 assert len(traces) == 1 1184 span = traces[0].data.spans[0] 1185 1186 assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == { 1187 TokenUsageKey.INPUT_TOKENS: 50, 1188 TokenUsageKey.OUTPUT_TOKENS: 20, 1189 TokenUsageKey.TOTAL_TOKENS: 70, 1190 TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30, 1191 }