# responses_agent.py
from typing import Any, Generator

import pydantic

from mlflow.exceptions import MlflowException
from mlflow.models.utils import _convert_llm_ndarray_to_list
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.pyfunc.model import _load_context_model_and_signature
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from mlflow.types.type_hints import model_validate


def _load_pyfunc(model_path: str, model_config: dict[str, Any] | None = None):
    """
    Load the saved agent found at ``model_path`` and wrap it for pyfunc serving.

    Args:
        model_path: Location handed to ``_load_context_model_and_signature``.
        model_config: Optional configuration dict forwarded to the same loader.

    Returns:
        A :class:`_ResponsesAgentPyfuncWrapper` around the loaded agent and its context.
    """
    # The loader returns a 3-tuple; its last element is not needed here.
    context, agent, _ = _load_context_model_and_signature(model_path, model_config)
    return _ResponsesAgentPyfuncWrapper(agent, context)


class _ResponsesAgentPyfuncWrapper:
    """
    Adapter between plain ``dict`` / ``pandas.DataFrame`` payloads and the pydantic
    request/response objects that a :class:`~ResponsesAgent` works with.
    """

    def __init__(self, responses_agent, context):
        self.responses_agent = responses_agent
        self.context = context

    def get_raw_model(self):
        """
        Return the wrapped agent object itself.
        """
        return self.responses_agent

    def _convert_input(self, model_input) -> ResponsesAgentRequest:
        """
        Build a ``ResponsesAgentRequest`` from a dict or a ``pandas.DataFrame``.

        For a DataFrame only the first entry of every column is used; each such entry
        is passed through ``_convert_llm_ndarray_to_list`` before the request is built.

        Raises:
            MlflowException: If ``model_input`` is neither a dict nor a DataFrame.
        """
        import pandas

        if isinstance(model_input, pandas.DataFrame):
            fields = {}
            for column, values in model_input.to_dict(orient="list").items():
                # values[0] is the first row's cell for this column.
                fields[column] = _convert_llm_ndarray_to_list(values[0])
            return ResponsesAgentRequest(**fields)

        if isinstance(model_input, dict):
            return ResponsesAgentRequest(**model_input)

        raise MlflowException(
            "Unsupported model input type. Expected a dict or pandas.DataFrame, but got "
            f"{type(model_input)} instead.",
            error_code=INTERNAL_ERROR,
        )

    def _response_to_dict(self, response, pydantic_class) -> dict[str, Any]:
        """
        Turn an agent result into a dict that conforms to ``pydantic_class``.

        An instance of ``pydantic_class`` is dumped with ``None``-valued fields left out.
        Any other value is validated against ``pydantic_class`` and, when valid, handed
        back untouched.

        Raises:
            MlflowException: If validation against ``pydantic_class`` fails.
        """
        if not isinstance(response, pydantic_class):
            # Validation is a check only: its result is discarded and the caller
            # receives the very object the agent produced.
            try:
                model_validate(pydantic_class, response)
            except pydantic.ValidationError as e:
                raise MlflowException(
                    message=(
                        f"Model returned an invalid response. Expected a {pydantic_class.__name__} "
                        f"object or dictionary with the same schema. Pydantic validation error: {e}"
                    ),
                    error_code=INTERNAL_ERROR,
                ) from e
            return response

        return response.model_dump(exclude_none=True)

    def predict(self, model_input: dict[str, Any], params=None) -> dict[str, Any]:
        """
        Run a single, non-streaming prediction.

        Args:
            model_input: A dict following the ``ResponsesAgentRequest`` schema
                (a ``pandas.DataFrame`` is accepted as well, see ``_convert_input``).
            params: Unused in this function, but required in the signature because
                `load_model_and_predict` in `utils/_capture_modules.py` expects a params field

        Returns:
            A dict following the ``ResponsesAgentResponse`` schema.
        """
        agent_request = self._convert_input(model_input)
        agent_output = self.responses_agent.predict(agent_request)
        return self._response_to_dict(agent_output, ResponsesAgentResponse)

    def predict_stream(
        self, model_input: dict[str, Any], params=None
    ) -> Generator[dict[str, Any], None, None]:
        """
        Run a streaming prediction, yielding one dict per event.

        Args:
            model_input: A dict following the ``ResponsesAgentRequest`` schema
                (a ``pandas.DataFrame`` is accepted as well, see ``_convert_input``).
            params: Unused in this function, but required in the signature because
                `load_model_and_predict` in `utils/_capture_modules.py` expects a params field

        Returns:
            A generator over dicts following the ``ResponsesAgentStreamEvent`` schema.
        """
        agent_request = self._convert_input(model_input)
        for event in self.responses_agent.predict_stream(agent_request):
            yield self._response_to_dict(event, ResponsesAgentStreamEvent)