/ mlflow / pyfunc / loaders / responses_agent.py
responses_agent.py
  1  from typing import Any, Generator
  2  
  3  import pydantic
  4  
  5  from mlflow.exceptions import MlflowException
  6  from mlflow.models.utils import _convert_llm_ndarray_to_list
  7  from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
  8  from mlflow.pyfunc.model import _load_context_model_and_signature
  9  from mlflow.types.responses import (
 10      ResponsesAgentRequest,
 11      ResponsesAgentResponse,
 12      ResponsesAgentStreamEvent,
 13  )
 14  from mlflow.types.type_hints import model_validate
 15  
 16  
 17  def _load_pyfunc(model_path: str, model_config: dict[str, Any] | None = None):
 18      context, responses_agent, _ = _load_context_model_and_signature(model_path, model_config)
 19      return _ResponsesAgentPyfuncWrapper(responses_agent, context)
 20  
 21  
 22  class _ResponsesAgentPyfuncWrapper:
 23      """
 24      Wrapper class that converts dict inputs to pydantic objects accepted by
 25      :class:`~ResponsesAgent`.
 26      """
 27  
 28      def __init__(self, responses_agent, context):
 29          self.responses_agent = responses_agent
 30          self.context = context
 31  
 32      def get_raw_model(self):
 33          """
 34          Returns the underlying model.
 35          """
 36          return self.responses_agent
 37  
 38      def _convert_input(self, model_input) -> ResponsesAgentRequest:
 39          import pandas
 40  
 41          if isinstance(model_input, pandas.DataFrame):
 42              model_input = {
 43                  k: _convert_llm_ndarray_to_list(v[0])
 44                  for k, v in model_input.to_dict(orient="list").items()
 45              }
 46          elif not isinstance(model_input, dict):
 47              raise MlflowException(
 48                  "Unsupported model input type. Expected a dict or pandas.DataFrame, but got "
 49                  f"{type(model_input)} instead.",
 50                  error_code=INTERNAL_ERROR,
 51              )
 52          return ResponsesAgentRequest(**model_input)
 53  
 54      def _response_to_dict(self, response, pydantic_class) -> dict[str, Any]:
 55          if isinstance(response, pydantic_class):
 56              return response.model_dump(exclude_none=True)
 57          try:
 58              model_validate(pydantic_class, response)
 59          except pydantic.ValidationError as e:
 60              raise MlflowException(
 61                  message=(
 62                      f"Model returned an invalid response. Expected a {pydantic_class.__name__} "
 63                      f"object or dictionary with the same schema. Pydantic validation error: {e}"
 64                  ),
 65                  error_code=INTERNAL_ERROR,
 66              ) from e
 67          return response
 68  
 69      def predict(self, model_input: dict[str, Any], params=None) -> dict[str, Any]:
 70          """
 71          Args:
 72              model_input: A dict with the
 73                  :py:class:`ResponsesRequest <mlflow.types.responses.ResponsesRequest>` schema.
 74              params: Unused in this function, but required in the signature because
 75                  `load_model_and_predict` in `utils/_capture_modules.py` expects a params field
 76  
 77          Returns:
 78              A dict with the
 79              (:py:class:`ResponsesResponse <mlflow.types.responses.ResponsesResponse>`)
 80              schema.
 81          """
 82          request = self._convert_input(model_input)
 83          response = self.responses_agent.predict(request)
 84          return self._response_to_dict(response, ResponsesAgentResponse)
 85  
 86      def predict_stream(
 87          self, model_input: dict[str, Any], params=None
 88      ) -> Generator[dict[str, Any], None, None]:
 89          """
 90          Args:
 91              model_input: A dict with the
 92                  :py:class:`ResponsesRequest <mlflow.types.responses.ResponsesRequest>` schema.
 93              params: Unused in this function, but required in the signature because
 94                  `load_model_and_predict` in `utils/_capture_modules.py` expects a params field
 95  
 96          Returns:
 97              A generator over dicts with the
 98                  (:py:class:`ResponsesStreamEvent <mlflow.types.responses.ResponsesStreamEvent>`)
 99                  schema.
100          """
101          request = self._convert_input(model_input)
102          for response in self.responses_agent.predict_stream(request):
103              yield self._response_to_dict(response, ResponsesAgentStreamEvent)