__init__.py
import functools
import importlib
import inspect
import logging
import typing

from mlflow.pydantic_ai.autolog import (
    patched_agent_init,
    patched_async_class_call,
    patched_async_stream_call,
    patched_class_call,
    patched_sync_stream_call,
)
from mlflow.telemetry.events import AutologgingEvent
from mlflow.telemetry.track import _record_event
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
from mlflow.utils.autologging_utils.safety import _store_patch, _wrap_patch

FLAVOR_NAME = "pydantic_ai"
_logger = logging.getLogger(__name__)


def _is_async_context_manager_factory(func) -> bool:
    """Return True if ``func.__wrapped__`` exists and is an async generator function.

    Functions decorated with ``contextlib.asynccontextmanager`` have this shape:
    the decorator applies ``functools.wraps``, so the underlying async generator
    function is exposed as ``__wrapped__``.
    """
    wrapped = getattr(func, "__wrapped__", None)
    return wrapped is not None and inspect.isasyncgenfunction(wrapped)


def _returns_sync_streamed_result(func) -> bool:
    """Return True if ``func`` is a non-coroutine function whose return annotation
    looks like a synchronous streamed-run result.

    Returns False for coroutine functions, for callables whose signature cannot
    be inspected, and for callables without a return annotation.
    """
    if inspect.iscoroutinefunction(func):
        return False

    try:
        return_annotation = inspect.signature(func).return_annotation
    except (ValueError, TypeError):
        return False

    if return_annotation is inspect.Signature.empty:
        return False

    # pydantic-ai uses `from __future__ import annotations`, so the return
    # annotation is a raw string rather than a resolved type. We match by class
    # name to avoid calling `get_type_hints()`, which would try to resolve *all*
    # parameter annotations (e.g. `AgentSpec` added in 1.71.0) and raise
    # NameError for any forward reference that isn't importable at call time.
    # `StreamedRunResultSync` is a unique pydantic-ai class name; substring
    # matching is sufficient and avoids fragile import-time resolution.
    if isinstance(return_annotation, str):
        return "StreamedRunResultSync" in return_annotation

    # Resolved annotation: unwrap a parameterized generic to its origin class and
    # duck-type on the two streaming methods.
    origin = typing.get_origin(return_annotation) or return_annotation
    return hasattr(origin, "stream_text") and hasattr(origin, "stream_output")


def _patch_streaming_method(cls, method_name, wrapper_func):
    """Replace ``cls.<method_name>`` with a wrapper that delegates to ``wrapper_func``.

    The wrapper calls ``wrapper_func(original, self, *args, **kwargs)`` and returns
    its result unchanged. The patch is installed with ``_wrap_patch`` and recorded
    under ``FLAVOR_NAME`` with ``_store_patch`` (presumably so it can be reverted
    later -- confirm against ``mlflow.utils.autologging_utils.safety``).

    Raises:
        AttributeError: If ``cls`` has no attribute named ``method_name``.
    """
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def patched_method(self, *args, **kwargs):
        return wrapper_func(original, self, *args, **kwargs)

    patch = _wrap_patch(cls, method_name, patched_method)
    _store_patch(FLAVOR_NAME, patch)


def _patch_method(cls, method_name):
    """Patch ``cls.<method_name>`` with the wrapper matching its calling style.

    Dispatch order: async context-manager factories and sync streaming methods
    are patched through ``_patch_streaming_method``; remaining coroutine
    functions and plain functions are patched through ``safe_patch``.

    Raises:
        AttributeError: If ``cls`` has no attribute named ``method_name``.
    """
    method = getattr(cls, method_name)

    if _is_async_context_manager_factory(method):
        _patch_streaming_method(cls, method_name, patched_async_stream_call)
    elif _returns_sync_streamed_result(method):
        _patch_streaming_method(cls, method_name, patched_sync_stream_call)
    elif inspect.iscoroutinefunction(method):
        safe_patch(FLAVOR_NAME, cls, method_name, patched_async_class_call)
    else:
        safe_patch(FLAVOR_NAME, cls, method_name, patched_class_call)


def _tool_manager_uses_execute_tool_call() -> bool:
    """Return True if ToolManager has execute_tool_call (pydantic-ai >= 1.63.0).

    In pydantic-ai >= 1.63.0, _agent_graph._call_tool() calls
    tool_manager.execute_tool_call() directly rather than handle_call(), so we
    must patch execute_tool_call to capture the TOOL span.
    """
    try:
        from pydantic_ai._tool_manager import ToolManager

        return hasattr(ToolManager, "execute_tool_call")
    except ImportError:
        return False


def _agent_methods_to_patch() -> list[str]:
    """Return the names of the ``Agent`` methods to patch.

    Always contains the base methods; ``run_stream_sync`` is appended only when
    the installed ``Agent`` class defines it. If ``pydantic_ai`` cannot be
    imported, only the base methods are returned.
    """
    # Base methods that exist in all supported versions
    agent_methods = ["run", "run_sync", "run_stream"]

    try:
        from pydantic_ai import Agent
    except ImportError:
        return agent_methods

    # run_stream_sync was added in pydantic-ai 1.10.0
    if hasattr(Agent, "run_stream_sync"):
        agent_methods.append("run_stream_sync")
    return agent_methods


def _build_class_map() -> dict[str, list[str]]:
    """Return a mapping of dotted class path to the method names to patch on it."""
    class_map = {
        "pydantic_ai.Agent": _agent_methods_to_patch(),
        "pydantic_ai.models.instrumented.InstrumentedModel": [
            "request",
            "request_stream",
        ],
        # In pydantic-ai >= 1.63.0, _agent_graph calls execute_tool_call directly,
        # bypassing handle_call. Patch execute_tool_call when available; fall back to
        # handle_call for older versions where execute_tool_call doesn't exist.
        "pydantic_ai._tool_manager.ToolManager": ["execute_tool_call"]
        if _tool_manager_uses_execute_tool_call()
        else ["handle_call"],
        "pydantic_ai.mcp.MCPServer": ["call_tool", "list_tools"],
    }

    try:
        from pydantic_ai import Tool
    except ImportError:
        return class_map

    # Tool.run method is removed in recent versions
    if hasattr(Tool, "run"):
        class_map["pydantic_ai.Tool"] = ["run"]
    return class_map


def _patch_agent_init() -> None:
    """Patch ``Agent.__init__`` so construction is routed through ``patched_agent_init``.

    Import and attribute errors are logged rather than raised, so a failure here
    does not prevent the remaining patches from being applied.
    """
    # Patch Agent.__init__ to auto-enable instrument=True so LLM spans
    # are captured without requiring users to explicitly set it
    try:
        from pydantic_ai import Agent

        original_init = Agent.__init__

        @functools.wraps(original_init)
        def patched_init(self, *args, **kwargs):
            return patched_agent_init(original_init, self, *args, **kwargs)

        patch = _wrap_patch(Agent, "__init__", patched_init)
        _store_patch(FLAVOR_NAME, patch)
    except (ImportError, AttributeError) as e:
        _logger.error("Error patching Agent.__init__: %s", e)


@autologging_integration(FLAVOR_NAME)
def autolog(log_traces: bool = True, disable: bool = False, silent: bool = False):
    """
    Enable (or disable) autologging for Pydantic_AI.

    Args:
        log_traces: If True, capture spans for agent + model calls.
        disable: If True, disable the autologging patches.
        silent: If True, suppress MLflow warnings/info.
    """
    # NOTE(review): `silent` is not read in this body, and `log_traces` / `disable`
    # are only forwarded to the telemetry event; presumably the
    # `autologging_integration` decorator consumes them -- confirm.
    class_map = _build_class_map()
    _patch_agent_init()

    for cls_path, methods in class_map.items():
        module_name, class_name = cls_path.rsplit(".", 1)
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            _logger.error("Error importing %s: %s", cls_path, e)
            continue

        for method in methods:
            try:
                _patch_method(cls, method)
            except AttributeError as e:
                _logger.error("Error patching %s.%s: %s", cls_path, method, e)

    _record_event(
        AutologgingEvent, {"flavor": FLAVOR_NAME, "log_traces": log_traces, "disable": disable}
    )