# tests/pydantic_ai/test_pydanticai_fluent_tracing.py
  1  import importlib.metadata
  2  from contextlib import asynccontextmanager
  3  from unittest.mock import patch
  4  
  5  import pytest
  6  from packaging.version import Version
  7  from pydantic_ai import Agent, RunContext
  8  from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart
  9  from pydantic_ai.models.instrumented import InstrumentedModel
 10  from pydantic_ai.usage import Usage
 11  
 12  import mlflow
 13  import mlflow.pydantic_ai  # ensure the integration module is importable
 14  from mlflow.entities import SpanType
 15  from mlflow.tracing.constant import SpanAttributeKey
 16  
 17  from tests.tracing.helper import get_traces
 18  
# Expected final outputs produced by the mocked model responses below.
_FINAL_ANSWER_WITHOUT_TOOL = "Paris"
_FINAL_ANSWER_WITH_TOOL = "winner"

# Installed pydantic-ai version, used to gate version-dependent test behavior.
PYDANTIC_AI_VERSION = Version(importlib.metadata.version("pydantic_ai"))
# Usage was deprecated in favor of RequestUsage in 0.7.3
IS_USAGE_DEPRECATED = PYDANTIC_AI_VERSION >= Version("0.7.3")
# run_stream_sync was added in pydantic-ai 1.10.0
HAS_RUN_STREAM_SYNC = hasattr(Agent, "run_stream_sync")
# Streaming tests require pydantic-ai >= 1.0.0 due to API changes
HAS_STABLE_STREAMING_API = PYDANTIC_AI_VERSION >= Version("1.0.0")
# In pydantic-ai >= 1.63.0, _agent_graph calls execute_tool_call directly instead of handle_call.
# _tool_manager module doesn't exist in older versions (e.g. 0.2.x).
try:
    from pydantic_ai._tool_manager import ToolManager as _ToolManager

    # Pick the tool-execution span name matching the installed version's API.
    TOOL_MANAGER_SPAN_NAME = (
        "ToolManager.execute_tool_call"
        if hasattr(_ToolManager, "execute_tool_call")
        else "ToolManager.handle_call"
    )
except ImportError:
    # Versions without _tool_manager always routed tool calls through handle_call.
    TOOL_MANAGER_SPAN_NAME = "ToolManager.handle_call"
 41  
 42  
 43  def _make_dummy_response_without_tool():
 44      if IS_USAGE_DEPRECATED:
 45          from pydantic_ai.usage import RequestUsage
 46  
 47      parts = [TextPart(content=_FINAL_ANSWER_WITHOUT_TOOL)]
 48      resp = ModelResponse(parts=parts)
 49      if IS_USAGE_DEPRECATED:
 50          usage = RequestUsage(input_tokens=1, output_tokens=1)
 51      else:
 52          usage = Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2)
 53  
 54      if PYDANTIC_AI_VERSION >= Version("0.2.0"):
 55          return ModelResponse(parts=parts, usage=usage)
 56      else:
 57          return resp, usage
 58  
 59  
 60  def _make_dummy_response_with_tool():
 61      if IS_USAGE_DEPRECATED:
 62          from pydantic_ai.usage import RequestUsage
 63  
 64      call_parts = [ToolCallPart(tool_name="roulette_wheel", args={"square": 18})]
 65      final_parts = [TextPart(content=_FINAL_ANSWER_WITH_TOOL)]
 66  
 67      if IS_USAGE_DEPRECATED:
 68          usage_call = RequestUsage(input_tokens=10, output_tokens=20)
 69          usage_final = RequestUsage(input_tokens=100, output_tokens=200)
 70      else:
 71          usage_call = Usage(requests=0, request_tokens=10, response_tokens=20, total_tokens=30)
 72          usage_final = Usage(requests=1, request_tokens=100, response_tokens=200, total_tokens=300)
 73  
 74      if PYDANTIC_AI_VERSION >= Version("0.2.0"):
 75          call_resp = ModelResponse(parts=call_parts, usage=usage_call)
 76          final_resp = ModelResponse(parts=final_parts, usage=usage_final)
 77          yield call_resp
 78          yield final_resp
 79  
 80      else:
 81          call_resp = ModelResponse(parts=call_parts)
 82          final_resp = ModelResponse(parts=final_parts)
 83          yield call_resp, usage_call
 84          yield final_resp, usage_final
 85  
 86  
 87  def _make_streaming_response_without_tool(input_tokens=10, output_tokens=5):
 88      if IS_USAGE_DEPRECATED:
 89          from pydantic_ai.usage import RequestUsage
 90  
 91          usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
 92      else:
 93          usage = Usage(
 94              requests=1,
 95              request_tokens=input_tokens,
 96              response_tokens=output_tokens,
 97              total_tokens=input_tokens + output_tokens,
 98          )
 99  
100      return ModelResponse(parts=[TextPart(content=_FINAL_ANSWER_WITHOUT_TOOL)], usage=usage), usage
101  
102  
def _make_streaming_response_with_tool():
    """Return ``[tool-call response, final text response]`` for streaming tests."""
    if IS_USAGE_DEPRECATED:
        from pydantic_ai.usage import RequestUsage

        first_usage = RequestUsage(input_tokens=10, output_tokens=20)
        second_usage = RequestUsage(input_tokens=100, output_tokens=200)
    else:
        first_usage = Usage(requests=0, request_tokens=10, response_tokens=20, total_tokens=30)
        second_usage = Usage(requests=1, request_tokens=100, response_tokens=200, total_tokens=300)

    return [
        ModelResponse(
            parts=[ToolCallPart(tool_name="roulette_wheel", args={"square": 18})],
            usage=first_usage,
        ),
        ModelResponse(
            parts=[TextPart(content=_FINAL_ANSWER_WITH_TOOL)],
            usage=second_usage,
        ),
    ]
123  
124  
class MockStreamedResponse:
    """Minimal stand-in for pydantic-ai's streamed model response.

    Wraps a prebuilt ``ModelResponse`` and its usage; async iteration
    replays text parts one character at a time.
    """

    def __init__(self, response, usage):
        self._response = response
        self._usage = usage
        self.model_name = "openai:gpt-4o"
        self.timestamp = None

    def usage(self):
        """Return the usage object supplied at construction."""
        return self._usage

    def get(self):
        """Return the complete (non-streamed) response."""
        return self._response

    async def __aiter__(self):
        # Stream each text part character by character; parts without a
        # `content` attribute (e.g. tool calls) are emitted as one "".
        for part in self._response.parts:
            if not hasattr(part, "content"):
                yield ""
                continue
            for ch in part.content:
                yield ch
145  
146  
@pytest.fixture(autouse=True)
def clear_autolog_state():
    """Reset MLflow autologging state before each test.

    Clears every registered autologging integration config and drops all
    post-import hooks so one test's ``autolog()`` call cannot leak into
    the next.
    """
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    # Iterate the values directly rather than indexing the dict by key.
    for config in AUTOLOGGING_INTEGRATIONS.values():
        config.clear()
    mlflow.utils.import_hooks._post_import_hooks = {}
154  
155  
@pytest.fixture
def simple_agent():
    """An instrumented tool-less agent backed by openai:gpt-4o."""
    agent = Agent(
        "openai:gpt-4o",
        system_prompt="Tell me the capital of {{input}}.",
        instrument=True,
    )
    return agent
163  
164  
@pytest.fixture
def agent_with_tool():
    """An instrumented agent exposing a single ``roulette_wheel`` tool."""
    agent = Agent(
        "openai:gpt-4o",
        system_prompt=(
            "Use the roulette_wheel function to see if the "
            "customer has won based on the number they provide."
        ),
        instrument=True,
        deps_type=int,
        output_type=str,
    )

    # Registered under the name the mocked responses reference ("roulette_wheel").
    @agent.tool
    async def roulette_wheel(ctx: RunContext[int], square: int) -> str:
        """check if the square is a winner"""
        return "winner" if square == ctx.deps else "loser"

    return agent
184  
185  
def test_agent_run_sync_enable_fluent_disable_autolog(simple_agent):
    """run_sync produces one trace when autolog is on and none once disabled."""
    dummy = _make_dummy_response_without_tool()

    async def fake_request(self, *args, **kwargs):
        return dummy

    with patch.object(InstrumentedModel, "request", new=fake_request):
        mlflow.pydantic_ai.autolog(log_traces=True)
        result = simple_agent.run_sync("France")
        assert result.output == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    # Expected span tree: run_sync -> run -> model request.
    expected = [
        ("Agent.run_sync", SpanType.AGENT),
        ("Agent.run", SpanType.AGENT),
        ("InstrumentedModel.request", SpanType.LLM),
    ]
    for span, (name, span_type) in zip(spans, expected):
        assert span.name == name
        assert span.span_type == span_type
    assert spans[2].parent_id == spans[1].span_id

    # Disabling autolog must not record any additional trace.
    with patch.object(InstrumentedModel, "request", new=fake_request):
        mlflow.pydantic_ai.autolog(disable=True)
        simple_agent.run_sync("France")
    assert len(get_traces()) == 1
217  
218  
@pytest.mark.asyncio
async def test_agent_run_enable_fluent_disable_autolog(simple_agent):
    """Async run produces the expected Agent.run -> LLM span tree."""
    dummy = _make_dummy_response_without_tool()

    async def fake_request(self, *args, **kwargs):
        return dummy

    with patch.object(InstrumentedModel, "request", new=fake_request):
        mlflow.pydantic_ai.autolog(log_traces=True)
        result = await simple_agent.run("France")
        assert result.output == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    root, llm = spans[0], spans[1]
    assert (root.name, root.span_type) == ("Agent.run", SpanType.AGENT)
    assert (llm.name, llm.span_type) == ("InstrumentedModel.request", SpanType.LLM)
    assert llm.parent_id == root.span_id
243  
244  
def test_agent_run_sync_enable_disable_fluent_autolog_with_tool(agent_with_tool):
    """run_sync with a tool call yields agent, LLM, tool, and final LLM spans."""
    responses = _make_dummy_response_with_tool()

    async def fake_request(self, *args, **kwargs):
        return next(responses)

    with patch.object(InstrumentedModel, "request", new=fake_request):
        mlflow.pydantic_ai.autolog(log_traces=True)
        result = agent_with_tool.run_sync("Put my money on square eighteen", deps=18)
        assert result.output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 5

    assert spans[0].name == "Agent.run_sync"
    assert spans[0].span_type == SpanType.AGENT
    assert spans[1].name == "Agent.run"
    assert spans[1].span_type == SpanType.AGENT

    # The first LLM call, the tool execution, and the follow-up LLM call
    # are all children of Agent.run.
    llm_call, tool_call, llm_final = spans[2], spans[3], spans[4]
    assert llm_call.name == "InstrumentedModel.request"
    assert llm_call.span_type == SpanType.LLM
    assert llm_call.parent_id == spans[1].span_id

    assert tool_call.span_type == SpanType.TOOL
    assert tool_call.parent_id == spans[1].span_id

    assert llm_final.name == "InstrumentedModel.request"
    assert llm_final.span_type == SpanType.LLM
    assert llm_final.parent_id == spans[1].span_id
283  
@pytest.mark.asyncio
async def test_agent_run_enable_disable_fluent_autolog_with_tool(agent_with_tool):
    """Async run with a tool call yields agent, LLM, tool, and final LLM spans."""
    responses = _make_dummy_response_with_tool()

    async def fake_request(self, *args, **kwargs):
        return next(responses)

    with patch.object(InstrumentedModel, "request", new=fake_request):
        mlflow.pydantic_ai.autolog(log_traces=True)
        result = await agent_with_tool.run("Put my money on square eighteen", deps=18)
        assert result.output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 4

    root = spans[0]
    assert root.name == "Agent.run"
    assert root.span_type == SpanType.AGENT

    # LLM call, tool execution, and final LLM call all hang off Agent.run.
    llm_call, tool_call, llm_final = spans[1], spans[2], spans[3]
    assert llm_call.name == "InstrumentedModel.request"
    assert llm_call.span_type == SpanType.LLM
    assert llm_call.parent_id == root.span_id

    assert tool_call.span_type == SpanType.TOOL
    assert tool_call.parent_id == root.span_id

    assert llm_final.name == "InstrumentedModel.request"
    assert llm_final.span_type == SpanType.LLM
    assert llm_final.parent_id == root.span_id
319  
320  
@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.asyncio
async def test_agent_run_stream_creates_trace(simple_agent):
    """run_stream records an agent span, an LLM span, and token usage."""
    response, usage = _make_streaming_response_without_tool(input_tokens=10, output_tokens=5)

    @asynccontextmanager
    async def fake_request_stream(self, *args, **kwargs):
        yield MockStreamedResponse(response, usage)

    with patch.object(InstrumentedModel, "request_stream", new=fake_request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)
        async with simple_agent.run_stream("France") as result:
            assert await result.get_output() == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 2

    root, llm = spans[0], spans[1]
    assert (root.name, root.span_type) == ("Agent.run_stream", SpanType.AGENT)
    assert (llm.name, llm.span_type) == ("InstrumentedModel.request_stream", SpanType.LLM)
    assert llm.parent_id == root.span_id

    # Token usage is aggregated onto the root agent span.
    usage_attr = root.attributes.get(SpanAttributeKey.CHAT_USAGE)
    assert usage_attr is not None
    assert usage_attr.get("input_tokens") == 10
    assert usage_attr.get("output_tokens") == 5
    assert usage_attr.get("total_tokens") == 15
357  
358  
@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.skipif(not HAS_RUN_STREAM_SYNC, reason="run_stream_sync added in pydantic-ai 1.10.0")
def test_agent_run_stream_sync_creates_trace(simple_agent):
    """run_stream_sync records spans with inputs/outputs and token usage."""
    response, usage = _make_streaming_response_without_tool(input_tokens=10, output_tokens=5)

    @asynccontextmanager
    async def fake_request_stream(self, *args, **kwargs):
        yield MockStreamedResponse(response, usage)

    with patch.object(InstrumentedModel, "request_stream", new=fake_request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)
        result = simple_agent.run_stream_sync("France")
        output = "".join(result.stream_text())

    assert output == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 2

    root, llm = spans[0], spans[1]
    assert root.name == "Agent.run_stream_sync"
    assert root.span_type == SpanType.AGENT
    # The sync streaming entry point captures the prompt and the result.
    assert root.inputs is not None
    assert "user_prompt" in root.inputs
    assert root.outputs is not None

    assert llm.name == "InstrumentedModel.request_stream"
    assert llm.span_type == SpanType.LLM
    assert llm.parent_id == root.span_id

    usage_attr = root.attributes.get(SpanAttributeKey.CHAT_USAGE)
    assert usage_attr is not None
    assert usage_attr.get("input_tokens") == 10
    assert usage_attr.get("output_tokens") == 5
    assert usage_attr.get("total_tokens") == 15
401  
402  
@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.asyncio
async def test_agent_run_stream_with_tool(agent_with_tool):
    """run_stream with a tool call records LLM, tool, and final LLM spans."""
    sequence = _make_streaming_response_with_tool()
    # Capture the final response before any pops so that calls beyond the
    # scripted two replay it. The original fallback read `sequence[-1]` in
    # the branch where `sequence` is empty, which would raise IndexError
    # instead of replaying the last response.
    last_resp = sequence[-1]

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        resp = sequence.pop(0) if sequence else last_resp
        yield MockStreamedResponse(resp, resp.usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        async with agent_with_tool.run_stream("Put my money on square eighteen", deps=18) as result:
            output = await result.get_output()
            assert output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 4

    assert spans[0].name == "Agent.run_stream"
    assert spans[0].span_type == SpanType.AGENT

    # First LLM call (tool request), the tool execution, and the final LLM
    # call are all children of the root streaming span.
    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    assert spans[2].span_type == SpanType.TOOL
    assert spans[2].name == TOOL_MANAGER_SPAN_NAME
    assert spans[2].parent_id == spans[0].span_id

    assert spans[3].name == "InstrumentedModel.request_stream"
    assert spans[3].span_type == SpanType.LLM
    assert spans[3].parent_id == spans[0].span_id
446  
447  
@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.skipif(not HAS_RUN_STREAM_SYNC, reason="run_stream_sync added in pydantic-ai 1.10.0")
def test_agent_run_stream_sync_with_tool(agent_with_tool):
    """run_stream_sync with a tool call records LLM, tool, and final LLM spans."""
    sequence = _make_streaming_response_with_tool()
    # Capture the final response before any pops so that calls beyond the
    # scripted two replay it. The original fallback read `sequence[-1]` in
    # the branch where `sequence` is empty, which would raise IndexError
    # instead of replaying the last response.
    last_resp = sequence[-1]

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        resp = sequence.pop(0) if sequence else last_resp
        yield MockStreamedResponse(resp, resp.usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = agent_with_tool.run_stream_sync("Put my money on square eighteen", deps=18)
        output = ""
        for text in result.stream_text():
            output += text

    assert output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 4

    assert spans[0].name == "Agent.run_stream_sync"
    assert spans[0].span_type == SpanType.AGENT
    # The sync streaming entry point captures the prompt.
    assert spans[0].inputs is not None
    assert "user_prompt" in spans[0].inputs

    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    assert spans[2].span_type == SpanType.TOOL
    assert spans[2].name == TOOL_MANAGER_SPAN_NAME
    assert spans[2].parent_id == spans[0].span_id

    assert spans[3].name == "InstrumentedModel.request_stream"
    assert spans[3].span_type == SpanType.LLM
    assert spans[3].parent_id == spans[0].span_id