# tests/openai/test_openai_responses_autolog.py
  1  from unittest import mock
  2  
  3  import httpx
  4  import openai
  5  import pytest
  6  from packaging.version import Version
  7  
  8  import mlflow
  9  from mlflow.entities.span import SpanType
 10  from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
 11  
 12  from tests.tracing.helper import get_traces
 13  
 14  if Version(openai.__version__) < Version("1.66.00"):
 15      pytest.skip(
 16          "OpenAI < 1.66.0 does not support the Responses API.",
 17          allow_module_level=True,
 18      )
 19  
 20  
 21  @pytest.fixture(params=[True, False], ids=["sync", "async"])
 22  def client(request, monkeypatch, mock_openai):
 23      monkeypatch.setenv("OPENAI_API_KEY", "test")
 24      monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
 25      if request.param:
 26          client = openai.OpenAI(api_key="test", base_url=mock_openai)
 27          client._is_async = False
 28          return client
 29      else:
 30          client = openai.AsyncOpenAI(api_key="test", base_url=mock_openai)
 31          client._is_async = True
 32          return client
 33  
 34  
 35  @pytest.mark.asyncio
 36  @pytest.mark.parametrize(
 37      "_input",
 38      [
 39          "Hello",
 40          [{"role": "user", "content": "Hello"}],
 41      ],
 42  )
 43  async def test_responses_autolog(client, _input):
 44      mlflow.openai.autolog()
 45  
 46      response = client.responses.create(
 47          input=_input,
 48          model="gpt-4o",
 49          temperature=0,
 50      )
 51  
 52      if client._is_async:
 53          await response
 54  
 55      traces = get_traces()
 56      assert len(traces) == 1
 57      assert traces[0].info.status == "OK"
 58      assert len(traces[0].data.spans) == 1
 59      span = traces[0].data.spans[0]
 60      assert span.span_type == SpanType.CHAT_MODEL
 61      assert span.inputs == {"input": _input, "model": "gpt-4o", "temperature": 0}
 62      assert span.outputs["id"] == "responses-123"
 63      assert span.attributes["model"] == "gpt-4o"
 64      assert span.attributes["temperature"] == 0
 65  
 66      # Token usage should be aggregated correctly
 67      assert traces[0].info.token_usage == {
 68          TokenUsageKey.INPUT_TOKENS: 36,
 69          TokenUsageKey.OUTPUT_TOKENS: 87,
 70          TokenUsageKey.TOTAL_TOKENS: 123,
 71          TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
 72      }
 73  
 74  
 75  @pytest.mark.asyncio
 76  async def test_responses_image_input_autolog(client):
 77      mlflow.openai.autolog()
 78  
 79      response = client.responses.create(
 80          input=[
 81              {
 82                  "role": "user",
 83                  "content": [
 84                      {"type": "input_text", "text": "what is in this image?"},
 85                      {
 86                          "type": "input_image",
 87                          "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
 88                      },
 89                  ],
 90              }
 91          ],
 92          model="gpt-4o",
 93          temperature=0,
 94      )
 95  
 96      if client._is_async:
 97          await response
 98  
 99      traces = get_traces()
100      assert len(traces) == 1
101      assert traces[0].info.status == "OK"
102      assert len(traces[0].data.spans) == 1
103      span = traces[0].data.spans[0]
104      assert span.span_type == SpanType.CHAT_MODEL
105  
106  
@pytest.mark.asyncio
async def test_responses_web_search_autolog(client):
    """The built-in web-search tool is recorded in the CHAT_TOOLS attribute."""
    mlflow.openai.autolog()

    result = client.responses.create(
        model="gpt-4o",
        tools=[{"type": "web_search_preview"}],
        input="What was a positive news story from today?",
    )
    if client._is_async:
        await result

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 1
    expected_tools = [{"type": "function", "function": {"name": "web_search_preview"}}]
    assert spans[0].attributes[SpanAttributeKey.CHAT_TOOLS] == expected_tools
128  
129  
@pytest.mark.asyncio
async def test_responses_file_search_autolog(client):
    """The built-in file-search tool is recorded in the CHAT_TOOLS attribute."""
    mlflow.openai.autolog()

    file_search_tool = {
        "type": "file_search",
        "vector_store_ids": ["vs_1234567890"],
        "max_num_results": 20,
    }
    result = client.responses.create(
        model="gpt-4o",
        tools=[file_search_tool],
        input="What are the attributes of an ancient brown dragon?",
    )
    if client._is_async:
        await result

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 1
    expected_tools = [{"type": "function", "function": {"name": "file_search"}}]
    assert spans[0].attributes[SpanAttributeKey.CHAT_TOOLS] == expected_tools
157  
158  
@pytest.mark.asyncio
async def test_responses_computer_use_autolog(client):
    """Two chained computer-use calls under a manual parent span yield 3 spans."""
    mlflow.openai.autolog()

    computer_tool_def = {
        "type": "computer_use_preview",
        "display_width": 1024,
        "display_height": 768,
        "environment": "browser",
    }

    with mlflow.start_span(name="openai_computer_use"):
        first = client.responses.create(
            model="computer-use-preview",
            input=[{"role": "user", "content": "Check the latest OpenAI news on bing.com."}],
            tools=[computer_tool_def],
        )
        if client._is_async:
            await first

        # Send the response back to the computer tool
        tool_output_message = {
            "call_id": "computer_call_1",
            "type": "computer_call_output",
            "output": {
                "type": "input_image",
                "image_url": "data:image/png;base64,screenshot_base64",
            },
        }
        second = client.responses.create(
            model="computer-use-preview",
            input=[tool_output_message],
            tools=[computer_tool_def],
        )
        if client._is_async:
            await second

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 3

    # spans[0] is the manual parent span; spans[1] is the first LLM call.
    llm_span_1 = trace.data.spans[1]
    assert llm_span_1.span_type == SpanType.CHAT_MODEL
    assert llm_span_1.inputs["model"] == "computer-use-preview"
    assert llm_span_1.outputs["id"] == "responses-123"
    assert llm_span_1.attributes[SpanAttributeKey.CHAT_TOOLS] == [
        {"type": "function", "function": {"name": "computer_use_preview"}}
    ]
210  
211  
@pytest.mark.asyncio
async def test_responses_function_calling_autolog(client):
    """User-defined function tools are normalized into the CHAT_TOOLS attribute."""
    mlflow.openai.autolog()

    tools = [
        {
            "type": "function",
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "unit"],
            },
        }
    ]

    result = client.responses.create(
        model="gpt-4o",
        tools=tools,
        input="What is the weather like in Boston today?",
        tool_choice="auto",
    )
    if client._is_async:
        await result

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 1
    span = spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs["model"] == "gpt-4o"
    assert span.outputs["id"] == "responses-123"

    # The normalized tool is the original definition minus its "type" key.
    expected_function = dict(tools[0])
    expected_function.pop("type")
    assert span.attributes[SpanAttributeKey.CHAT_TOOLS] == [
        {"type": "function", "function": expected_function}
    ]
    assert span.attributes[SpanAttributeKey.MESSAGE_FORMAT] == "openai"
257  
258  
@pytest.mark.asyncio
async def test_responses_autolog_with_cached_tokens(client):
    """Cached input tokens from usage details surface in span and trace usage."""
    mlflow.openai.autolog()

    mock_response = {
        "id": "responses-cached",
        "object": "response",
        "created": 1589478378,
        "status": "completed",
        "error": None,
        "incomplete_details": None,
        "max_output_tokens": None,
        "model": "gpt-4o",
        "output": [
            {
                "type": "message",
                "id": "test",
                "status": "completed",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "Hello"}],
            }
        ],
        "parallel_tool_calls": True,
        "previous_response_id": None,
        "reasoning": {"effort": None, "generate_summary": None},
        "store": True,
        "temperature": 1.0,
        "text": {"format": {"type": "text"}},
        "tool_choice": "auto",
        "tools": [],
        "top_p": 1.0,
        "truncation": "disabled",
        "usage": {
            "input_tokens": 100,
            "input_tokens_details": {"cached_tokens": 40},
            "output_tokens": 50,
            "output_tokens_details": {"reasoning_tokens": 0},
            "total_tokens": 150,
        },
        "user": None,
        "metadata": {},
    }

    def _canned_response(request):
        # Shared builder so the sync and async patches stay in lockstep.
        return httpx.Response(status_code=200, request=request, json=mock_response)

    if client._is_async:
        patch_target = "httpx.AsyncClient.send"

        async def send_patch(self, request, *args, **kwargs):
            return _canned_response(request)

    else:
        patch_target = "httpx.Client.send"

        def send_patch(self, request, *args, **kwargs):
            return _canned_response(request)

    with mock.patch(patch_target, send_patch):
        response = client.responses.create(
            input="Hello",
            model="gpt-4o",
            temperature=0,
        )
        if client._is_async:
            response = await response

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]

    expected_usage = {
        TokenUsageKey.INPUT_TOKENS: 100,
        TokenUsageKey.OUTPUT_TOKENS: 50,
        TokenUsageKey.TOTAL_TOKENS: 150,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 40,
    }
    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == expected_usage
    assert traces[0].info.token_usage == expected_usage
339  
340  
@pytest.mark.asyncio
async def test_responses_stream_autolog(client):
    """Streaming Responses calls are traced once the stream is fully consumed."""
    mlflow.openai.autolog()

    result = client.responses.create(
        input="Hello",
        model="gpt-4o",
        stream=True,
    )

    # Drain the stream; autolog finalizes the span at exhaustion.
    if client._is_async:
        stream = await result
        async for _ in stream:
            pass
    else:
        for _ in result:
            pass

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.outputs["id"] == "responses-123"

    # "logprobs" is only returned from certain version of OpenAI SDK
    first_content = span.outputs["output"][0]["content"]
    first_content[0].pop("logprobs", None)
    assert first_content == [
        {
            "text": "Dummy output",
            "annotations": None,
            "type": "output_text",
        }
    ]
    assert span.attributes["model"] == "gpt-4o"
    assert span.attributes["stream"] is True

    # Token usage should be aggregated correctly
    assert trace.info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 36,
        TokenUsageKey.OUTPUT_TOKENS: 87,
        TokenUsageKey.TOTAL_TOKENS: 123,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    }