/ mlflow / pydantic_ai / __init__.py
__init__.py
  1  import functools
  2  import inspect
  3  import logging
  4  import typing
  5  
  6  from mlflow.pydantic_ai.autolog import (
  7      patched_agent_init,
  8      patched_async_class_call,
  9      patched_async_stream_call,
 10      patched_class_call,
 11      patched_sync_stream_call,
 12  )
 13  from mlflow.telemetry.events import AutologgingEvent
 14  from mlflow.telemetry.track import _record_event
 15  from mlflow.utils.autologging_utils import autologging_integration, safe_patch
 16  from mlflow.utils.autologging_utils.safety import _store_patch, _wrap_patch
 17  
 18  FLAVOR_NAME = "pydantic_ai"
 19  _logger = logging.getLogger(__name__)
 20  
 21  
 22  def _is_async_context_manager_factory(func) -> bool:
 23      wrapped = getattr(func, "__wrapped__", None)
 24      return wrapped is not None and inspect.isasyncgenfunction(wrapped)
 25  
 26  
 27  def _returns_sync_streamed_result(func) -> bool:
 28      if inspect.iscoroutinefunction(func):
 29          return False
 30  
 31      try:
 32          return_annotation = inspect.signature(func).return_annotation
 33      except (ValueError, TypeError):
 34          return False
 35  
 36      if return_annotation is inspect.Signature.empty:
 37          return False
 38  
 39      # pydantic-ai uses `from __future__ import annotations`, so the return
 40      # annotation is a raw string rather than a resolved type. We match by class
 41      # name to avoid calling `get_type_hints()`, which would try to resolve *all*
 42      # parameter annotations (e.g. `AgentSpec` added in 1.71.0) and raise
 43      # NameError for any forward reference that isn't importable at call time.
 44      # `StreamedRunResultSync` is a unique pydantic-ai class name; substring
 45      # matching is sufficient and avoids fragile import-time resolution.
 46      if isinstance(return_annotation, str):
 47          return "StreamedRunResultSync" in return_annotation
 48  
 49      origin = typing.get_origin(return_annotation) or return_annotation
 50      return hasattr(origin, "stream_text") and hasattr(origin, "stream_output")
 51  
 52  
 53  def _patch_streaming_method(cls, method_name, wrapper_func):
 54      original = getattr(cls, method_name)
 55  
 56      @functools.wraps(original)
 57      def patched_method(self, *args, **kwargs):
 58          return wrapper_func(original, self, *args, **kwargs)
 59  
 60      patch = _wrap_patch(cls, method_name, patched_method)
 61      _store_patch(FLAVOR_NAME, patch)
 62  
 63  
 64  def _patch_method(cls, method_name):
 65      method = getattr(cls, method_name)
 66  
 67      if _is_async_context_manager_factory(method):
 68          _patch_streaming_method(cls, method_name, patched_async_stream_call)
 69      elif _returns_sync_streamed_result(method):
 70          _patch_streaming_method(cls, method_name, patched_sync_stream_call)
 71      elif inspect.iscoroutinefunction(method):
 72          safe_patch(FLAVOR_NAME, cls, method_name, patched_async_class_call)
 73      else:
 74          safe_patch(FLAVOR_NAME, cls, method_name, patched_class_call)
 75  
 76  
 77  def _tool_manager_uses_execute_tool_call() -> bool:
 78      """Return True if ToolManager has execute_tool_call (pydantic-ai >= 1.63.0).
 79  
 80      In pydantic-ai >= 1.63.0, _agent_graph._call_tool() calls
 81      tool_manager.execute_tool_call() directly rather than handle_call(), so we
 82      must patch execute_tool_call to capture the TOOL span.
 83      """
 84      try:
 85          from pydantic_ai._tool_manager import ToolManager
 86  
 87          return hasattr(ToolManager, "execute_tool_call")
 88      except ImportError:
 89          return False
 90  
 91  
 92  @autologging_integration(FLAVOR_NAME)
 93  def autolog(log_traces: bool = True, disable: bool = False, silent: bool = False):
 94      """
 95      Enable (or disable) autologging for Pydantic_AI.
 96  
 97      Args:
 98          log_traces: If True, capture spans for agent + model calls.
 99          disable:   If True, disable the autologging patches.
100          silent:    If True, suppress MLflow warnings/info.
101      """
102      # Base methods that exist in all supported versions
103      agent_methods = ["run", "run_sync", "run_stream"]
104  
105      try:
106          from pydantic_ai import Agent
107  
108          # run_stream_sync was added in pydantic-ai 1.10.0
109          if hasattr(Agent, "run_stream_sync"):
110              agent_methods.append("run_stream_sync")
111      except ImportError:
112          pass
113  
114      class_map = {
115          "pydantic_ai.Agent": agent_methods,
116          "pydantic_ai.models.instrumented.InstrumentedModel": [
117              "request",
118              "request_stream",
119          ],
120          # In pydantic-ai >= 1.63.0, _agent_graph calls execute_tool_call directly,
121          # bypassing handle_call. Patch execute_tool_call when available; fall back to
122          # handle_call for older versions where execute_tool_call doesn't exist.
123          "pydantic_ai._tool_manager.ToolManager": ["execute_tool_call"]
124          if _tool_manager_uses_execute_tool_call()
125          else ["handle_call"],
126          "pydantic_ai.mcp.MCPServer": ["call_tool", "list_tools"],
127      }
128  
129      try:
130          from pydantic_ai import Tool
131  
132          # Tool.run method is removed in recent versions
133          if hasattr(Tool, "run"):
134              class_map["pydantic_ai.Tool"] = ["run"]
135      except ImportError:
136          pass
137  
138      # Patch Agent.__init__ to auto-enable instrument=True so LLM spans
139      # are captured without requiring users to explicitly set it
140      try:
141          from pydantic_ai import Agent
142  
143          original_init = Agent.__init__
144  
145          @functools.wraps(original_init)
146          def patched_init(self, *args, **kwargs):
147              return patched_agent_init(original_init, self, *args, **kwargs)
148  
149          patch = _wrap_patch(Agent, "__init__", patched_init)
150          _store_patch(FLAVOR_NAME, patch)
151      except (ImportError, AttributeError) as e:
152          _logger.error("Error patching Agent.__init__: %s", e)
153  
154      for cls_path, methods in class_map.items():
155          module_name, class_name = cls_path.rsplit(".", 1)
156          try:
157              module = __import__(module_name, fromlist=[class_name])
158              cls = getattr(module, class_name)
159          except (ImportError, AttributeError) as e:
160              _logger.error("Error importing %s: %s", cls_path, e)
161              continue
162  
163          for method in methods:
164              try:
165                  _patch_method(cls, method)
166              except AttributeError as e:
167                  _logger.error("Error patching %s.%s: %s", cls_path, method, e)
168  
169      _record_event(
170          AutologgingEvent, {"flavor": FLAVOR_NAME, "log_traces": log_traces, "disable": disable}
171      )