# test_pydanticai_fluent_tracing.py
import importlib.metadata
from contextlib import asynccontextmanager
from unittest.mock import patch

import pytest
from packaging.version import Version
from pydantic_ai import Agent, RunContext
from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart
from pydantic_ai.models.instrumented import InstrumentedModel
from pydantic_ai.usage import Usage

import mlflow
import mlflow.pydantic_ai  # ensure the integration module is importable
from mlflow.entities import SpanType
from mlflow.tracing.constant import SpanAttributeKey

from tests.tracing.helper import get_traces

# Expected final answers produced by the mocked model responses below.
_FINAL_ANSWER_WITHOUT_TOOL = "Paris"
_FINAL_ANSWER_WITH_TOOL = "winner"

# Installed pydantic-ai version; used to gate version-dependent test paths.
PYDANTIC_AI_VERSION = Version(importlib.metadata.version("pydantic_ai"))
# Usage was deprecated in favor of RequestUsage in 0.7.3
IS_USAGE_DEPRECATED = PYDANTIC_AI_VERSION >= Version("0.7.3")
# run_stream_sync was added in pydantic-ai 1.10.0
HAS_RUN_STREAM_SYNC = hasattr(Agent, "run_stream_sync")
# Streaming tests require pydantic-ai >= 1.0.0 due to API changes
HAS_STABLE_STREAMING_API = PYDANTIC_AI_VERSION >= Version("1.0.0")
# In pydantic-ai >= 1.63.0, _agent_graph calls execute_tool_call directly instead of handle_call.
# _tool_manager module doesn't exist in older versions (e.g. 0.2.x).
try:
    from pydantic_ai._tool_manager import ToolManager as _ToolManager

    # Pick the span name that matches whichever entry point this
    # pydantic-ai version routes tool execution through.
    TOOL_MANAGER_SPAN_NAME = (
        "ToolManager.execute_tool_call"
        if hasattr(_ToolManager, "execute_tool_call")
        else "ToolManager.handle_call"
    )
except ImportError:
    # _tool_manager is absent on old versions; those always use handle_call.
    TOOL_MANAGER_SPAN_NAME = "ToolManager.handle_call"


def _make_dummy_response_without_tool():
    """Build a mocked model response whose text is ``_FINAL_ANSWER_WITHOUT_TOOL``.

    On pydantic-ai >= 0.2.0 returns a single ``ModelResponse`` with usage
    attached; on older versions returns a ``(ModelResponse, usage)`` tuple
    (presumably matching the older ``Model.request`` contract — TODO confirm).
    """
    if IS_USAGE_DEPRECATED:
        from pydantic_ai.usage import RequestUsage

    parts = [TextPart(content=_FINAL_ANSWER_WITHOUT_TOOL)]
    # `resp` (no usage attached) is only returned on the pre-0.2.0 path below.
    resp = ModelResponse(parts=parts)
    if IS_USAGE_DEPRECATED:
        usage = RequestUsage(input_tokens=1, output_tokens=1)
    else:
        usage = Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2)

    if PYDANTIC_AI_VERSION >= Version("0.2.0"):
        return ModelResponse(parts=parts, usage=usage)
    else:
        return resp, usage


def _make_dummy_response_with_tool():
    """Generator yielding two mocked model turns: a tool call, then the final answer.

    First yield: a response calling ``roulette_wheel`` with square=18.
    Second yield: the final text ``_FINAL_ANSWER_WITH_TOOL``.
    The yielded shape is version-dependent, mirroring
    ``_make_dummy_response_without_tool``.
    """
    if IS_USAGE_DEPRECATED:
        from pydantic_ai.usage import RequestUsage

    call_parts = [ToolCallPart(tool_name="roulette_wheel", args={"square": 18})]
    final_parts = [TextPart(content=_FINAL_ANSWER_WITH_TOOL)]

    if IS_USAGE_DEPRECATED:
        usage_call = RequestUsage(input_tokens=10, output_tokens=20)
        usage_final = RequestUsage(input_tokens=100, output_tokens=200)
    else:
        usage_call = Usage(requests=0, request_tokens=10, response_tokens=20, total_tokens=30)
        usage_final = Usage(requests=1, request_tokens=100, response_tokens=200, total_tokens=300)

    if PYDANTIC_AI_VERSION >= Version("0.2.0"):
        # Modern API: usage travels on the ModelResponse itself.
        call_resp = ModelResponse(parts=call_parts, usage=usage_call)
        final_resp = ModelResponse(parts=final_parts, usage=usage_final)
        yield call_resp
        yield final_resp

    else:
        # Legacy API: usage is returned alongside the response.
        call_resp = ModelResponse(parts=call_parts)
        final_resp = ModelResponse(parts=final_parts)
        yield call_resp, usage_call
        yield final_resp, usage_final


def _make_streaming_response_without_tool(input_tokens=10, output_tokens=5):
    """Return ``(ModelResponse, usage)`` for the no-tool streaming tests.

    ``usage`` is returned separately as well so the caller can hand it to
    ``MockStreamedResponse`` directly.
    """
    if IS_USAGE_DEPRECATED:
        from pydantic_ai.usage import RequestUsage

        usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
    else:
        usage = Usage(
            requests=1,
            request_tokens=input_tokens,
            response_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )

    return ModelResponse(parts=[TextPart(content=_FINAL_ANSWER_WITHOUT_TOOL)], usage=usage), usage


def _make_streaming_response_with_tool():
    """Return a two-element list of streamed turns: tool call, then final answer.

    Consumed destructively (``pop(0)``) by the request_stream mocks below.
    """
    if IS_USAGE_DEPRECATED:
        from pydantic_ai.usage import RequestUsage

        usage_call = RequestUsage(input_tokens=10, output_tokens=20)
        usage_final = RequestUsage(input_tokens=100, output_tokens=200)
    else:
        usage_call = Usage(requests=0, request_tokens=10, response_tokens=20, total_tokens=30)
        usage_final = Usage(requests=1, request_tokens=100, response_tokens=200, total_tokens=300)

    call_resp = ModelResponse(
        parts=[ToolCallPart(tool_name="roulette_wheel", args={"square": 18})],
        usage=usage_call,
    )
    final_resp = ModelResponse(
        parts=[TextPart(content=_FINAL_ANSWER_WITH_TOOL)],
        usage=usage_final,
    )

    return [call_resp, final_resp]


class MockStreamedResponse:
    """Minimal stand-in for pydantic-ai's streamed-response object.

    Mimics the attributes/methods the agent run loop reads: ``usage()`` (a
    method, matching the real API), ``get()`` for the accumulated response,
    and async iteration yielding the text one character at a time.
    """

    def __init__(self, response, usage):
        self._response = response
        self._usage = usage
        self.model_name = "openai:gpt-4o"
        self.timestamp = None

    def usage(self):
        # Method (not attribute) on purpose — callers invoke it.
        return self._usage

    def get(self):
        return self._response

    async def __aiter__(self):
        # Stream text parts char-by-char; non-text parts (tool calls)
        # contribute a single empty chunk.
        for part in self._response.parts:
            if hasattr(part, "content"):
                for char in part.content:
                    yield char
            else:
                yield ""


@pytest.fixture(autouse=True)
def clear_autolog_state():
    """Reset MLflow autologging config and import hooks before every test."""
    from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS

    # Clear each integration's config dict in place (keys are kept).
    for key in AUTOLOGGING_INTEGRATIONS.keys():
        AUTOLOGGING_INTEGRATIONS[key].clear()
    mlflow.utils.import_hooks._post_import_hooks = {}


@pytest.fixture
def simple_agent():
    """Instrumented agent with no tools; the model is always mocked."""
    return Agent(
        "openai:gpt-4o",
        system_prompt="Tell me the capital of {{input}}.",
        instrument=True,
    )


@pytest.fixture
def agent_with_tool():
    """Instrumented agent exposing one tool (``roulette_wheel``), deps=int."""
    roulette_agent = Agent(
        "openai:gpt-4o",
        system_prompt=(
            "Use the roulette_wheel function to see if the "
            "customer has won based on the number they provide."
        ),
        instrument=True,
        deps_type=int,
        output_type=str,
    )

    @roulette_agent.tool
    async def roulette_wheel(ctx: RunContext[int], square: int) -> str:
        """check if the square is a winner"""
        return "winner" if square == ctx.deps else "loser"

    return roulette_agent


def test_agent_run_sync_enable_fluent_disable_autolog(simple_agent):
    """run_sync produces a 3-span trace; autolog(disable=True) stops tracing."""
    dummy = _make_dummy_response_without_tool()

    async def request(self, *args, **kwargs):
        return dummy

    with patch.object(InstrumentedModel, "request", new=request):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = simple_agent.run_sync("France")
        assert result.output == _FINAL_ANSWER_WITHOUT_TOOL

        traces = get_traces()
        assert len(traces) == 1
        spans = traces[0].data.spans

        # Expected tree: run_sync -> run -> model request.
        assert spans[0].name == "Agent.run_sync"
        assert spans[0].span_type == SpanType.AGENT

        assert spans[1].name == "Agent.run"
        assert spans[1].span_type == SpanType.AGENT

        span2 = spans[2]
        assert span2.name == "InstrumentedModel.request"
        assert span2.span_type == SpanType.LLM
        assert span2.parent_id == spans[1].span_id

    with patch.object(InstrumentedModel, "request", new=request):
        mlflow.pydantic_ai.autolog(disable=True)
        simple_agent.run_sync("France")
        # No new trace after disabling — count stays at 1.
        assert len(get_traces()) == 1


@pytest.mark.asyncio
async def test_agent_run_enable_fluent_disable_autolog(simple_agent):
    """Async run produces a 2-span trace: Agent.run -> model request."""
    dummy = _make_dummy_response_without_tool()

    async def request(self, *args, **kwargs):
        return dummy

    with patch.object(InstrumentedModel, "request", new=request):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = await simple_agent.run("France")
        assert result.output == _FINAL_ANSWER_WITHOUT_TOOL

        traces = get_traces()
        assert len(traces) == 1
        spans = traces[0].data.spans

        assert spans[0].name == "Agent.run"
        assert spans[0].span_type == SpanType.AGENT

        span1 = spans[1]
        assert span1.name == "InstrumentedModel.request"
        assert span1.span_type == SpanType.LLM
        assert span1.parent_id == spans[0].span_id


def test_agent_run_sync_enable_disable_fluent_autolog_with_tool(agent_with_tool):
    """run_sync with a tool call: 5 spans (run_sync, run, LLM, tool, LLM)."""
    sequence = _make_dummy_response_with_tool()

    async def request(self, *args, **kwargs):
        # Each model call consumes the next mocked turn.
        return next(sequence)

    with patch.object(InstrumentedModel, "request", new=request):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = agent_with_tool.run_sync("Put my money on square eighteen", deps=18)
        assert result.output == _FINAL_ANSWER_WITH_TOOL

        traces = get_traces()
        assert len(traces) == 1
        spans = traces[0].data.spans

        assert len(spans) == 5

        assert spans[0].name == "Agent.run_sync"
        assert spans[0].span_type == SpanType.AGENT

        assert spans[1].name == "Agent.run"
        assert spans[1].span_type == SpanType.AGENT

        span2 = spans[2]
        assert span2.name == "InstrumentedModel.request"
        assert span2.span_type == SpanType.LLM
        assert span2.parent_id == spans[1].span_id

        span3 = spans[3]
        assert span3.span_type == SpanType.TOOL
        assert span3.parent_id == spans[1].span_id

        span4 = spans[4]
        assert span4.name == "InstrumentedModel.request"
        assert span4.span_type == SpanType.LLM
        assert span4.parent_id == spans[1].span_id


@pytest.mark.asyncio
async def test_agent_run_enable_disable_fluent_autolog_with_tool(agent_with_tool):
    """Async run with a tool call: 4 spans (run, LLM, tool, LLM)."""
    sequence = _make_dummy_response_with_tool()

    async def request(self, *args, **kwargs):
        return next(sequence)

    with patch.object(InstrumentedModel, "request", new=request):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = await agent_with_tool.run("Put my money on square eighteen", deps=18)
        assert result.output == _FINAL_ANSWER_WITH_TOOL

        traces = get_traces()
        assert len(traces) == 1
        spans = traces[0].data.spans

        assert len(spans) == 4

        assert spans[0].name == "Agent.run"
        assert spans[0].span_type == SpanType.AGENT

        span1 = spans[1]
        assert span1.name == "InstrumentedModel.request"
        assert span1.span_type == SpanType.LLM
        assert span1.parent_id == spans[0].span_id

        span2 = spans[2]
        assert span2.span_type == SpanType.TOOL
        assert span2.parent_id == spans[0].span_id

        span3 = spans[3]
        assert span3.name == "InstrumentedModel.request"
        assert span3.span_type == SpanType.LLM
        assert span3.parent_id == spans[0].span_id


@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.asyncio
async def test_agent_run_stream_creates_trace(simple_agent):
    """run_stream: 2 spans, with token usage aggregated onto the root span."""
    response, usage = _make_streaming_response_without_tool(input_tokens=10, output_tokens=5)

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        yield MockStreamedResponse(response, usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        async with simple_agent.run_stream("France") as result:
            output = await result.get_output()
            assert output == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 2

    assert spans[0].name == "Agent.run_stream"
    assert spans[0].span_type == SpanType.AGENT

    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    # Usage attribute is recorded on the root agent span.
    usage_attr = spans[0].attributes.get(SpanAttributeKey.CHAT_USAGE)
    assert usage_attr is not None
    assert usage_attr.get("input_tokens") == 10
    assert usage_attr.get("output_tokens") == 5
    assert usage_attr.get("total_tokens") == 15


@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.skipif(not HAS_RUN_STREAM_SYNC, reason="run_stream_sync added in pydantic-ai 1.10.0")
def test_agent_run_stream_sync_creates_trace(simple_agent):
    """run_stream_sync: same shape as run_stream, plus inputs/outputs captured."""
    response, usage = _make_streaming_response_without_tool(input_tokens=10, output_tokens=5)

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        yield MockStreamedResponse(response, usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = simple_agent.run_stream_sync("France")
        output = ""
        for text in result.stream_text():
            output += text

        assert output == _FINAL_ANSWER_WITHOUT_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 2

    assert spans[0].name == "Agent.run_stream_sync"
    assert spans[0].span_type == SpanType.AGENT
    assert spans[0].inputs is not None
    assert "user_prompt" in spans[0].inputs
    assert spans[0].outputs is not None

    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    usage_attr = spans[0].attributes.get(SpanAttributeKey.CHAT_USAGE)
    assert usage_attr is not None
    assert usage_attr.get("input_tokens") == 10
    assert usage_attr.get("output_tokens") == 5
    assert usage_attr.get("total_tokens") == 15


@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.asyncio
async def test_agent_run_stream_with_tool(agent_with_tool):
    """Streaming run with a tool call: 4 spans (stream, LLM, tool, LLM)."""
    sequence = _make_streaming_response_with_tool()

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        if sequence:
            resp = sequence.pop(0)
            yield MockStreamedResponse(resp, resp.usage)
        else:
            # NOTE(review): `sequence[-1]` on an exhausted (empty) list raises
            # IndexError — presumably unreachable with exactly two mocked turns;
            # confirm before relying on this fallback.
            resp = sequence[-1]
            yield MockStreamedResponse(resp, resp.usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        async with agent_with_tool.run_stream("Put my money on square eighteen", deps=18) as result:
            output = await result.get_output()
            assert output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 4

    assert spans[0].name == "Agent.run_stream"
    assert spans[0].span_type == SpanType.AGENT

    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    assert spans[2].span_type == SpanType.TOOL
    assert spans[2].name == TOOL_MANAGER_SPAN_NAME
    assert spans[2].parent_id == spans[0].span_id

    assert spans[3].name == "InstrumentedModel.request_stream"
    assert spans[3].span_type == SpanType.LLM
    assert spans[3].parent_id == spans[0].span_id


@pytest.mark.skipif(
    not HAS_STABLE_STREAMING_API, reason="Streaming API stabilized in pydantic-ai 1.0.0"
)
@pytest.mark.skipif(not HAS_RUN_STREAM_SYNC, reason="run_stream_sync added in pydantic-ai 1.10.0")
def test_agent_run_stream_sync_with_tool(agent_with_tool):
    """Sync streaming run with a tool call: 4 spans, inputs captured on root."""
    sequence = _make_streaming_response_with_tool()

    @asynccontextmanager
    async def request_stream(self, *args, **kwargs):
        if sequence:
            resp = sequence.pop(0)
            yield MockStreamedResponse(resp, resp.usage)
        else:
            # NOTE(review): same empty-list IndexError caveat as in
            # test_agent_run_stream_with_tool's request_stream mock.
            resp = sequence[-1]
            yield MockStreamedResponse(resp, resp.usage)

    with patch.object(InstrumentedModel, "request_stream", new=request_stream):
        mlflow.pydantic_ai.autolog(log_traces=True)

        result = agent_with_tool.run_stream_sync("Put my money on square eighteen", deps=18)
        output = ""
        for text in result.stream_text():
            output += text

        assert output == _FINAL_ANSWER_WITH_TOOL

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans

    assert len(spans) == 4

    assert spans[0].name == "Agent.run_stream_sync"
    assert spans[0].span_type == SpanType.AGENT
    assert spans[0].inputs is not None
    assert "user_prompt" in spans[0].inputs

    assert spans[1].name == "InstrumentedModel.request_stream"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].parent_id == spans[0].span_id

    assert spans[2].span_type == SpanType.TOOL
    assert spans[2].name == TOOL_MANAGER_SPAN_NAME
    assert spans[2].parent_id == spans[0].span_id

    assert spans[3].name == "InstrumentedModel.request_stream"
    assert spans[3].span_type == SpanType.LLM
    assert spans[3].parent_id == spans[0].span_id