tests/openai/test_openai_autolog.py
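"""Tests for MLflow's OpenAI autologging: span creation, token usage and cost
tracking, streaming reconstruction, trace/run/model linking, and tracing-header
injection, exercised against both the sync and async OpenAI clients."""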
import json
import re
import sys
from unittest import mock

import httpx
import openai
import pytest
from packaging.version import Version
from pydantic import BaseModel

import mlflow
from mlflow.entities.span import SpanType
from mlflow.exceptions import MlflowException
from mlflow.openai.utils.chat_schema import _parse_tools
from mlflow.tracing.constant import (
    STREAM_CHUNK_EVENT_VALUE_KEY,
    CostKey,
    SpanAttributeKey,
    TokenUsageKey,
    TraceMetadataKey,
)
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.openai.mock_openai import EMPTY_CHOICES, LIST_CONTENT
from tests.tracing.helper import get_traces, skip_when_testing_trace_sdk

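# Tool definition in the OpenAI function-calling tool schema, shared by the
# raw-response tests below.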
MOCK_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        },
    }
]


@pytest.fixture(params=[True, False], ids=["sync", "async"])
def client(request, monkeypatch, mock_openai):
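    # `_is_async` is a test-only marker stashed on the client so the shared test
    # bodies know whether to await the response; it is not an OpenAI SDK attribute.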
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
    if request.param:
        client = openai.OpenAI(api_key="test", base_url=mock_openai)
        client._is_async = False
        return client
    else:
        client = openai.AsyncOpenAI(api_key="test", base_url=mock_openai)
        client._is_async = True
        return client


@pytest.fixture
def completion_models():
    return [
        mlflow.openai.log_model(
            "gpt-4o-mini",
            "completions",
            name="model",
            temperature=temp,
            prompt="Say {text}",
            pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
        )
        for temp in [0.1, 0.2, 0.3]
    ]


@pytest.fixture
def embedding_models():
    float_model = mlflow.openai.log_model(
        "text-embedding-ada-002",
        "embeddings",
        name="model",
        encoding_format="float",
        pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
    )
    base64_model = mlflow.openai.log_model(
        "text-embedding-ada-002",
        "embeddings",
        name="model",
        encoding_format="base64",
        pip_requirements=["mlflow"],  # Hard-coded to speed up model logging
    )
    return [float_model, base64_model]


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog(client, mock_litellm_cost):
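    # mock_litellm_cost (shared fixture) is assumed to pin per-token prices at
    # 1.0 (input) and 2.0 (output); the cost assertions below rely on that.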
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]
    response = client.chat.completions.create(
        messages=messages,
        model="gpt-4o-mini",
        temperature=0,
    )

    if client._is_async:
        await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == {"messages": messages, "model": "gpt-4o-mini", "temperature": 0}
    assert span.outputs["id"] == "chatcmpl-123"
    assert span.attributes["model"] == "gpt-4o-mini"
    assert span.attributes["temperature"] == 0

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 9,
        TokenUsageKey.OUTPUT_TOKENS: 12,
        TokenUsageKey.TOTAL_TOKENS: 21,
    }
    assert span.model_name == "gpt-4o-mini"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "openai"

    if not IS_TRACING_SDK_ONLY:
        # Verify cost is calculated (9 input tokens * 1.0 + 12 output tokens * 2.0)
        assert span.llm_cost == {
            "input_cost": 9.0,
            "output_cost": 24.0,
            "total_cost": 33.0,
        }

    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 9,
        TokenUsageKey.OUTPUT_TOKENS: 12,
        TokenUsageKey.TOTAL_TOKENS: 21,
    }
    if not IS_TRACING_SDK_ONLY:
        assert trace.info.cost == {
            CostKey.INPUT_COST: 9.0,
            CostKey.OUTPUT_COST: 24.0,
            CostKey.TOTAL_COST: 33.0,
        }


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog_with_cached_tokens(client, mock_litellm_cost):
    mlflow.openai.autolog()

    mock_response = {
        "id": "chatcmpl-cached",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello"},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 50,
            "completion_tokens": 20,
            "total_tokens": 70,
            "prompt_tokens_details": {"cached_tokens": 30, "audio_tokens": 0},
            "completion_tokens_details": {"reasoning_tokens": 0},
        },
    }

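    # Patch httpx at the transport layer so the canned response above (which
    # includes prompt_tokens_details) is returned instead of the mock server's default.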
    if client._is_async:
        patch_target = "httpx.AsyncClient.send"

        async def send_patch(self, request, *args, **kwargs):
            return httpx.Response(status_code=200, request=request, json=mock_response)

    else:
        patch_target = "httpx.Client.send"

        def send_patch(self, request, *args, **kwargs):
            return httpx.Response(status_code=200, request=request, json=mock_response)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
            temperature=0,
        )
        if client._is_async:
            response = await response

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }

    assert traces[0].info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }


@pytest.mark.asyncio
async def test_chat_completions_autolog_under_current_active_span(client):
    # If a user has an active span, autologging should create a child span under it.
    # (No cost assertions here, so the cost-tracking version skip is not needed.)
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]
    with mlflow.start_span(name="parent"):
        for _ in range(3):
            response = client.chat.completions.create(
                messages=messages,
                model="gpt-4o-mini",
                temperature=0,
            )

            if client._is_async:
                await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 4
    parent_span = trace.data.spans[0]
    assert parent_span.name == "parent"
    child_span = trace.data.spans[1]
    # Parenthesize the conditional: `==` binds tighter than `if ... else`, so the
    # unparenthesized form degenerates to asserting a truthy string for sync clients.
    assert child_span.name == ("AsyncCompletions" if client._is_async else "Completions")
    assert child_span.inputs == {"messages": messages, "model": "gpt-4o-mini", "temperature": 0}
    assert child_span.outputs["id"] == "chatcmpl-123"
    assert child_span.parent_id == parent_span.span_id

    # Token usage should be aggregated correctly
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 27,
        TokenUsageKey.OUTPUT_TOKENS: 36,
        TokenUsageKey.TOTAL_TOKENS: 63,
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("include_usage", [True, False])
async def test_chat_completions_autolog_streaming(client, include_usage):
    mlflow.openai.autolog()

    stream_options_supported = Version(openai.__version__) >= Version("1.26")

    if not stream_options_supported and include_usage:
        pytest.skip("OpenAI SDK version does not support usage tracking in streaming")

    messages = [{"role": "user", "content": "test"}]

    input_params = {
        "messages": messages,
        "model": "gpt-4o-mini",
        "temperature": 0,
        "stream": True,
    }
    if stream_options_supported:
        input_params["stream_options"] = {"include_usage": include_usage}

    stream = client.chat.completions.create(**input_params)

    if client._is_async:
        async for _ in await stream:
            pass
    else:
        for _ in stream:
            pass

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == input_params

    # Reconstructed response from streaming chunks
    assert isinstance(span.outputs, dict)
    assert span.outputs["id"] == "chatcmpl-123"
    assert span.outputs["object"] == "chat.completion"
    assert span.outputs["model"] == "gpt-4o-mini"
    assert span.outputs["system_fingerprint"] == "fp_44709d6fcb"
    assert "choices" in span.outputs
    assert span.outputs["choices"][0]["message"]["role"] == "assistant"
    assert span.outputs["choices"][0]["message"]["content"] == "Hello world"

    # Usage should be preserved when include_usage=True
    if include_usage:
        assert "usage" in span.outputs
        assert span.outputs["usage"]["prompt_tokens"] == 9
        assert span.outputs["usage"]["completion_tokens"] == 12
        assert span.outputs["usage"]["total_tokens"] == 21

    stream_event_data = trace.data.spans[0].events
    assert stream_event_data[0].name == "mlflow.chunk.item.0"
    chunk_1 = json.loads(stream_event_data[0].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_1["id"] == "chatcmpl-123"
    assert chunk_1["choices"][0]["delta"]["content"] == "Hello"
    assert stream_event_data[1].name == "mlflow.chunk.item.1"
    chunk_2 = json.loads(stream_event_data[1].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_2["id"] == "chatcmpl-123"
    assert chunk_2["choices"][0]["delta"]["content"] == " world"

    if include_usage:
        assert trace.info.token_usage == {
            TokenUsageKey.INPUT_TOKENS: 9,
            TokenUsageKey.OUTPUT_TOKENS: 12,
            TokenUsageKey.TOTAL_TOKENS: 21,
        }


@pytest.mark.asyncio
async def test_chat_completions_autolog_tracing_error(client):
    mlflow.openai.autolog()
    messages = [{"role": "user", "content": "test"}]
    with pytest.raises(openai.UnprocessableEntityError, match="Input should be less"):  # noqa: PT012
        response = client.chat.completions.create(
            messages=messages,
            model="gpt-4o-mini",
            temperature=5.0,
        )

        if client._is_async:
            await response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "ERROR"

    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.name == ("AsyncCompletions" if client._is_async else "Completions")
    assert span.inputs["messages"][0]["content"] == "test"
    assert span.outputs is None

    assert span.events[0].name == "exception"
    assert span.events[0].attributes["exception.type"] == "UnprocessableEntityError"


@pytest.mark.asyncio
async def test_chat_completions_autolog_tracing_error_with_parent_span(client):
    mlflow.openai.autolog()

    if client._is_async:

        @mlflow.trace
        async def create_completions(text: str) -> str:
            try:
                response = await client.chat.completions.create(
                    messages=[{"role": "user", "content": text}],
                    model="gpt-4o-mini",
                    temperature=5.0,
                )
                # Non-streaming responses carry `message`, not `delta` (this line is
                # unreachable in the test, since the create call raises first).
                return response.choices[0].message.content
            except openai.OpenAIError as e:
                raise MlflowException("Failed to create completions") from e

        with pytest.raises(MlflowException, match="Failed to create completions"):
            await create_completions("test")

    else:

        @mlflow.trace
        def create_completions(text: str) -> str:
            try:
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": text}],
                    model="gpt-4o-mini",
                    temperature=5.0,
                )
                return response.choices[0].message.content
            except openai.OpenAIError as e:
                raise MlflowException("Failed to create completions") from e

        with pytest.raises(MlflowException, match="Failed to create completions"):
            create_completions("test")

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "ERROR"

    assert len(trace.data.spans) == 2
    parent_span = trace.data.spans[0]
    assert parent_span.name == "create_completions"
    assert parent_span.inputs == {"text": "test"}
    assert parent_span.outputs is None
    assert parent_span.status.status_code == "ERROR"
    assert parent_span.events[0].name == "exception"
    assert parent_span.events[0].attributes["exception.type"] == "MlflowException"
    assert parent_span.events[0].attributes["exception.message"] == "Failed to create completions"

    child_span = trace.data.spans[1]
    assert child_span.name == ("AsyncCompletions" if client._is_async else "Completions")
    assert child_span.inputs["messages"][0]["content"] == "test"
    assert child_span.outputs is None
    assert child_span.status.status_code == "ERROR"
    assert child_span.events[0].name == "exception"
    assert child_span.events[0].attributes["exception.type"] == "UnprocessableEntityError"


@pytest.mark.asyncio
async def test_chat_completions_streaming_empty_choices(client):
    mlflow.openai.autolog()
    stream = client.chat.completions.create(
        messages=[{"role": "user", "content": EMPTY_CHOICES}],
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    # Ensure the stream has a chunk with empty choices
    assert chunks[0].choices == []

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "OK"


@pytest.mark.asyncio
async def test_chat_completions_streaming_with_list_content(client):
    # Test streaming with Databricks-style list content in chunks.
    mlflow.openai.autolog()
    stream = client.chat.completions.create(
        messages=[{"role": "user", "content": LIST_CONTENT}],
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    assert len(chunks) == 2
    assert chunks[0].choices[0].delta.content == [{"type": "text", "text": "Hello"}]
    assert chunks[1].choices[0].delta.content == [{"type": "text", "text": " world"}]

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL

    # Verify the reconstructed message content is correct (text extracted from list)
    assert isinstance(span.outputs, dict)
    assert span.outputs["choices"][0]["message"]["content"] == "Hello world"


@pytest.mark.asyncio
async def test_completions_autolog(client):
    mlflow.openai.autolog()

    response = client.completions.create(
        prompt="test",
        model="gpt-4o-mini",
        temperature=0,
    )

    if client._is_async:
        await response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.LLM
    assert span.inputs == {"prompt": "test", "model": "gpt-4o-mini", "temperature": 0}
    assert span.outputs["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert span.model_name == "gpt-4o-mini"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "openai"
    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata


@pytest.mark.asyncio
async def test_completions_autolog_streaming_empty_choices(client):
    mlflow.openai.autolog()
    stream = client.completions.create(
        prompt=EMPTY_CHOICES,
        model="gpt-4o-mini",
        stream=True,
    )

    chunks = [chunk async for chunk in await stream] if client._is_async else list(stream)

    # Ensure the stream has a chunk with empty choices
    assert chunks[0].choices == []

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.status == "OK"


@pytest.mark.asyncio
async def test_completions_autolog_streaming(client):
    mlflow.openai.autolog()

    stream = client.completions.create(
        prompt="test",
        model="gpt-4o-mini",
        temperature=0,
        stream=True,
    )
    if client._is_async:
        async for _ in await stream:
            pass
    else:
        for _ in stream:
            pass

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.LLM
    assert span.inputs == {
        "prompt": "test",
        "model": "gpt-4o-mini",
        "temperature": 0,
        "stream": True,
    }
    assert span.outputs == "Hello world"  # aggregated string of streaming response

    stream_event_data = trace.data.spans[0].events

    assert stream_event_data[0].name == "mlflow.chunk.item.0"
    chunk_1 = json.loads(stream_event_data[0].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_1["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert chunk_1["choices"][0]["text"] == "Hello"
    assert stream_event_data[1].name == "mlflow.chunk.item.1"
    chunk_2 = json.loads(stream_event_data[1].attributes[STREAM_CHUNK_EVENT_VALUE_KEY])
    assert chunk_2["id"] == "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"
    assert chunk_2["choices"][0]["text"] == " world"


@pytest.mark.asyncio
async def test_embeddings_autolog(client):
    mlflow.openai.autolog()

    response = client.embeddings.create(
        input="test",
        model="text-embedding-ada-002",
    )

    if client._is_async:
        await response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.EMBEDDING
    assert span.inputs == {"input": "test", "model": "text-embedding-ada-002"}
    assert span.outputs["data"][0]["embedding"] == list(range(1536))
    assert span.model_name == "text-embedding-ada-002"

    assert TraceMetadataKey.SOURCE_RUN not in trace.info.request_metadata


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_use_active_run_id(client):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]

    async def _call_create():
        response = client.chat.completions.create(messages=messages, model="gpt-4o-mini")
        if client._is_async:
            response = await response
        return response

    with mlflow.start_run() as run_1:
        await _call_create()

    with mlflow.start_run() as run_2:
        await _call_create()
        await _call_create()

    with mlflow.start_run() as run_3:
        mlflow.openai.autolog()
        await _call_create()

    traces = get_traces()[::-1]  # reverse order to sort by timestamp in ascending order
    assert len(traces) == 4

    assert traces[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_1.info.run_id
    assert traces[1].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_2.info.run_id
    assert traces[2].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_2.info.run_id
    assert traces[3].info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_3.info.run_id


@pytest.mark.asyncio
async def test_autolog_raw_response(client):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]

    resp = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",
        messages=messages,
        tools=MOCK_TOOLS,
    )

    if client._is_async:
        resp = await resp

    resp = resp.parse()  # parse the raw HTTP response into the typed completion object

    assert resp.choices[0].message.content == '[{"role": "user", "content": "test"}]'
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert isinstance(span.outputs, dict)
    assert (
        span.outputs["choices"][0]["message"]["content"] == '[{"role": "user", "content": "test"}]'
    )
    assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == MOCK_TOOLS
    assert span.model_name == "gpt-4o-mini"

    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 9,
        TokenUsageKey.OUTPUT_TOKENS: 12,
        TokenUsageKey.TOTAL_TOKENS: 21,
    }


@pytest.mark.asyncio
async def test_autolog_raw_response_stream(client):
    mlflow.openai.autolog()

    messages = [{"role": "user", "content": "test"}]

    resp = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",
        messages=messages,
        tools=MOCK_TOOLS,
        stream=True,
    )

    if client._is_async:
        resp = await resp

    resp = resp.parse()  # parse the raw HTTP response into the typed stream object

    if client._is_async:
        chunks = [c.choices[0].delta.content async for c in resp]
    else:
        chunks = [c.choices[0].delta.content for c in resp]
    assert chunks == ["Hello", " world"]
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.model_name == "gpt-4o-mini"

    # Reconstructed response from streaming chunks
    assert isinstance(span.outputs, dict)
    assert span.outputs["id"] == "chatcmpl-123"
    assert span.outputs["object"] == "chat.completion"
    assert span.outputs["model"] == "gpt-4o-mini"
    assert span.outputs["choices"][0]["message"]["content"] == "Hello world"
    assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == MOCK_TOOLS


@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.40"), reason="Requires OpenAI SDK >= 1.40"
)
@pytest.mark.asyncio
async def test_response_format(client):
    mlflow.openai.autolog()

    class Person(BaseModel):
        name: str
        age: int

    mock_response = {
        "id": "chatcmpl-Ax4UAd5xf32KjgLkS1SEEY9oorI9m",
        "object": "chat.completion",
        "created": 1738641958,
        "model": "gpt-4o-2024-08-06",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": '{"name":"Angelo","age":42}',
                    "refusal": None,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 68,
            "completion_tokens": 11,
            "total_tokens": 79,
            "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
            "completion_tokens_details": {
                "reasoning_tokens": 0,
                "audio_tokens": 0,
                "accepted_prediction_tokens": 0,
                "rejected_prediction_tokens": 0,
            },
        },
        "service_tier": "default",
        "system_fingerprint": "fp_50cad350e4",
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"

        async def send_patch(self, request, *args, **kwargs):
            return httpx.Response(
                status_code=200,
                request=request,
                json=mock_response,
            )

    else:
        patch_target = "httpx.Client.send"

        def send_patch(self, request, *args, **kwargs):
            return httpx.Response(
                status_code=200,
                request=request,
                json=mock_response,
            )

    with mock.patch(patch_target, send_patch):
        response = client.beta.chat.completions.parse(
            messages=[
                {"role": "system", "content": "Extract info from text"},
                {"role": "user", "content": "I am Angelo and I am 42."},
            ],
            model="gpt-4o",
            temperature=0,
            response_format=Person,
        )

        if client._is_async:
            response = await response

    assert response.choices[0].message.parsed == Person(name="Angelo", age=42)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.outputs["choices"][0]["message"]["content"] == '{"name":"Angelo","age":42}'
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.model_name == "gpt-4o"

    # Compare the parsed usage dict rather than a JSON-serialized string, which
    # would depend on key ordering.
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 68,
        TokenUsageKey.OUTPUT_TOKENS: 11,
        TokenUsageKey.TOTAL_TOKENS: 79,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    }


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_chat_completions(client, completion_models):
    mlflow.openai.autolog()

    for model_info in completion_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.chat.completions.create(
            messages=[{"role": "user", "content": f"test {model_info.model_id}"}],
            model=model_dict["model"],
            temperature=model_dict["temperature"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(completion_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["messages"][0]["content"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_completions(client, completion_models):
    mlflow.openai.autolog()

    for model_info in completion_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.completions.create(
            prompt=f"test {model_info.model_id}",
            model=model_dict["model"],
            temperature=model_dict["temperature"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(completion_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["prompt"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_embeddings(client, embedding_models):
    mlflow.openai.autolog()

    for model_info in embedding_models:
        model_dict = mlflow.openai.load_model(model_info.model_uri)
        resp = client.embeddings.create(
            input=f"test {model_info.model_id}",
            model=model_dict["model"],
            encoding_format=model_dict["encoding_format"],
        )
        if client._is_async:
            await resp

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["input"] == f"test {model_id}"
        assert span.model_name == model_dict["model"]


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_loaded_model_embeddings_pyfunc(
    monkeypatch, mock_openai, embedding_models
):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)

    mlflow.openai.autolog()

    for model_info in embedding_models:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        assert mlflow.get_active_model_id() == model_info.model_id
        pyfunc_model.predict(f"test {model_info.model_id}")

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        model_id = trace.info.request_metadata[TraceMetadataKey.MODEL_ID]
        assert model_id is not None
        assert span.inputs["input"] == [f"test {model_id}"]
        assert span.model_name == "text-embedding-ada-002"


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_active_model(monkeypatch, mock_openai, embedding_models):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)

    model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(model_id=model.model_id)
    mlflow.openai.autolog()

    for model_info in embedding_models:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        pyfunc_model.predict(model_info.model_id)

    traces = get_traces()
    assert len(traces) == len(embedding_models)
    for trace in traces:
        span = trace.data.spans[0]
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id
        logged_model_id = span.inputs["input"][0]
        assert logged_model_id != model.model_id
        assert span.model_name == "text-embedding-ada-002"


@pytest.mark.asyncio
async def test_images_generate_autolog(client):
    mlflow.openai.autolog()

    # Disable tracing header injection: safe_patch rejects the extra_headers
    # dict as a "new input" because it is not an ExceptionSafe-wrapped object.
    # This is a known limitation shared with other non-chat endpoints.
    openai_autolog_module = sys.modules["mlflow.openai.autolog"]
    with mock.patch.object(openai_autolog_module, "_inject_tracing_headers"):
        response = client.images.generate(
            model="dall-e-3",
            prompt="a white siamese cat",
            n=1,
            response_format="b64_json",
        )

        if client._is_async:
            await response

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.TOOL
    assert span.inputs["prompt"] == "a white siamese cat"
    assert span.outputs["data"][0]["revised_prompt"] == "a test image"


@pytest.mark.parametrize(
    "sentinel",
    [None, openai.NOT_GIVEN],
)
def test_parse_tools_handles_openai_not_given_sentinel(sentinel):
    # openai.NOT_GIVEN is the SDK's sentinel for omitted keyword arguments;
    # _parse_tools should treat it, like None, as "no tools".
    assert _parse_tools({"tools": sentinel}) == []


@skip_when_testing_trace_sdk
@pytest.mark.asyncio
async def test_model_loading_set_active_model_id_without_fetching_logged_model(
    client, completion_models
):
    mlflow.openai.autolog()

    model_info = completion_models[0]
    with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")):
        model_dict = mlflow.openai.load_model(model_info.model_uri)
    resp = client.chat.completions.create(
        messages=[{"role": "user", "content": f"test {model_info.model_id}"}],
        model=model_dict["model"],
        temperature=model_dict["temperature"],
    )
    if client._is_async:
        await resp

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    model_id = traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID]
    assert model_id is not None
    assert span.inputs["messages"][0]["content"] == f"test {model_id}"
    assert span.model_name == model_dict["model"]


@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Requires OpenAI SDK >= 1.66"
)
@skip_when_testing_trace_sdk
def test_reconstruct_response_from_stream():
    from openai.types.responses import (
        ResponseOutputItemDoneEvent,
        ResponseOutputMessage,
        ResponseOutputText,
    )

    from mlflow.openai.autolog import _reconstruct_response_from_stream
    from mlflow.types.responses_helpers import OutputItem

    content1 = ResponseOutputText(annotations=[], text="Hello", type="output_text")
    content2 = ResponseOutputText(annotations=[], text=" world", type="output_text")

    message1 = ResponseOutputMessage(
        id="test-1", content=[content1], role="assistant", status="completed", type="message"
    )

    message2 = ResponseOutputMessage(
        id="test-2", content=[content2], role="assistant", status="completed", type="message"
    )

    chunk1 = ResponseOutputItemDoneEvent(
        item=message1, output_index=0, sequence_number=1, type="response.output_item.done"
    )

    chunk2 = ResponseOutputItemDoneEvent(
        item=message2, output_index=1, sequence_number=2, type="response.output_item.done"
    )

    chunks = [chunk1, chunk2]

    result = _reconstruct_response_from_stream(chunks)

    assert result.output == [
        OutputItem(**chunk1.item.to_dict()),
        OutputItem(**chunk2.item.to_dict()),
    ]


@pytest.mark.asyncio
async def test_tracing_headers_injected(client):
    mlflow.openai.autolog()

    captured_request = {}
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "hi"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"
        original_send = httpx.AsyncClient.send

        async def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return await original_send(self, request, *args, **kwargs)

    else:
        patch_target = "httpx.Client.send"
        original_send = httpx.Client.send

        def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return original_send(self, request, *args, **kwargs)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
        )
        if client._is_async:
            response = await response

    # Verify a traceparent header was injected (W3C Trace Context format:
    # "00-<trace-id>-<span-id>-<flags>")
    assert "traceparent" in captured_request["headers"]
    traceparent = captured_request["headers"]["traceparent"]
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", traceparent)

    # Verify the traceparent points to the LLM span
    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    span_ctx = span._span.get_span_context()
    trace_id_hex = format(span_ctx.trace_id, "032x")
    span_id_hex = format(span_ctx.span_id, "016x")
    assert traceparent.startswith(f"00-{trace_id_hex}-{span_id_hex}-")


@pytest.mark.asyncio
async def test_tracing_headers_preserve_user_headers(client):
    mlflow.openai.autolog()

    captured_request = {}
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "hi"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"
        original_send = httpx.AsyncClient.send

        async def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return await original_send(self, request, *args, **kwargs)

    else:
        patch_target = "httpx.Client.send"
        original_send = httpx.Client.send

        def send_patch(self, request, *args, **kwargs):
            if "chat/completions" in str(request.url):
                captured_request["headers"] = dict(request.headers)
                return httpx.Response(status_code=200, request=request, json=mock_response)
            return original_send(self, request, *args, **kwargs)

    with mock.patch(patch_target, send_patch):
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
            extra_headers={"X-Custom": "my-value"},
        )
        if client._is_async:
            response = await response

    # User-provided headers should be preserved alongside traceparent
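    # (httpx normalizes header names to lowercase, hence the "x-custom" lookup)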
    assert "traceparent" in captured_request["headers"]
    assert captured_request["headers"].get("x-custom") == "my-value"


@pytest.mark.asyncio
@pytest.mark.skipif(
    Version(openai.__version__) < Version("1.66"), reason="Cost tracking does not work before 1.66"
)
async def test_chat_completions_autolog_streaming_with_cached_tokens(client, mock_litellm_cost):
    mlflow.openai.autolog()

    mock_chunk = {
        "id": "chatcmpl-stream-cached",
        "object": "chat.completion.chunk",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [],
        "usage": {
            "prompt_tokens": 50,
            "completion_tokens": 20,
            "total_tokens": 70,
            "prompt_tokens_details": {"cached_tokens": 30, "audio_tokens": 0},
            "completion_tokens_details": {"reasoning_tokens": 0},
        },
    }

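    # The patched transport returns a server-sent-events body: one usage chunk
    # followed by the "[DONE]" terminator, matching OpenAI's streaming wire format.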
    if client._is_async:
        patch_target = "httpx.AsyncClient.send"

        async def send_patch(self, request, *args, **kwargs):
            content = f"data: {json.dumps(mock_chunk)}\n\ndata: [DONE]\n\n".encode()
            return httpx.Response(status_code=200, request=request, content=content)

    else:
        patch_target = "httpx.Client.send"

        def send_patch(self, request, *args, **kwargs):
            content = f"data: {json.dumps(mock_chunk)}\n\ndata: [DONE]\n\n".encode()
            return httpx.Response(status_code=200, request=request, content=content)

    with mock.patch(patch_target, send_patch):
        stream = client.chat.completions.create(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o-mini",
            stream=True,
            stream_options={"include_usage": True},
        )
        if client._is_async:
            async for _ in await stream:
                pass
        else:
            for _ in stream:
                pass

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }