# mlflow/dspy/wrapper.py
  1  import importlib.metadata
  2  import json
  3  from dataclasses import asdict, is_dataclass
  4  from typing import TYPE_CHECKING, Any
  5  
  6  from packaging.version import Version
  7  
  8  if TYPE_CHECKING:
  9      import dspy
 10  
 11  from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
 12  from mlflow.protos.databricks_pb2 import (
 13      INVALID_PARAMETER_VALUE,
 14  )
 15  from mlflow.pyfunc import PythonModel
 16  from mlflow.types.schema import DataType, Schema
 17  
# Raised when a caller passes batched (len != 1) or empty input; the wrapper
# only forwards a single example to the DSPy program per `predict` call.
_INVALID_SIZE_MESSAGE = (
    "Dspy model doesn't support batch inference or empty input. Please provide a single input."
)
 21  
 22  
 23  class DspyModelWrapper(PythonModel):
 24      """MLflow PyFunc wrapper class for Dspy models.
 25  
 26      This wrapper serves two purposes:
 27          - It stores the Dspy model along with dspy global settings, which are required for seamless
 28              saving and loading.
 29          - It provides a `predict` method so that it can be loaded as an MLflow pyfunc, which is
 30              used at serving time.
 31      """
 32  
 33      def __init__(
 34          self,
 35          model: "dspy.Module",
 36          dspy_settings: dict[str, Any],
 37          model_config: dict[str, Any] | None = None,
 38      ):
 39          self.model = model
 40          self.dspy_settings = dspy_settings
 41          self.model_config = model_config or {}
 42          self.output_schema: Schema | None = None
 43  
 44      def predict(self, inputs: Any, params: dict[str, Any] | None = None):
 45          import dspy
 46  
 47          converted_inputs = self._get_model_input(inputs)
 48  
 49          with dspy.context(**self.dspy_settings):
 50              if isinstance(converted_inputs, dict):
 51                  # We pass a dict as keyword args and don't allow DSPy models
 52                  # to receive a single dict.
 53                  result = self.model(**converted_inputs)
 54              else:
 55                  result = self.model(converted_inputs)
 56  
 57              if isinstance(result, dspy.Prediction):
 58                  return result.toDict()
 59              else:
 60                  return result
 61  
 62      def predict_stream(self, inputs: Any, params=None):
 63          import dspy
 64  
 65          converted_inputs = self._get_model_input(inputs)
 66  
 67          self._validate_streaming()
 68  
 69          stream_listeners = [
 70              dspy.streaming.StreamListener(signature_field_name=spec.name)
 71              for spec in self.output_schema
 72          ]
 73          stream_model = dspy.streamify(
 74              self.model,
 75              stream_listeners=stream_listeners,
 76              async_streaming=False,
 77              include_final_prediction_in_output_stream=False,
 78          )
 79  
 80          if isinstance(converted_inputs, dict):
 81              outputs = stream_model(**converted_inputs)
 82          else:
 83              outputs = stream_model(converted_inputs)
 84  
 85          with dspy.context(**self.dspy_settings):
 86              for output in outputs:
 87                  if is_dataclass(output):
 88                      yield asdict(output)
 89                  elif isinstance(output, dspy.Prediction):
 90                      yield output.toDict()
 91                  else:
 92                      yield output
 93  
 94      def _get_model_input(self, inputs: Any) -> str | dict[str, Any]:
 95          """Convert the PythonModel input into the DSPy program input
 96  
 97          Examples of expected conversions:
 98          - str -> str
 99          - dict -> dict
100          - np.ndarray with one element -> single element
101          - pd.DataFrame with one row and string column -> single row dict
102          - pd.DataFrame with one row and non-string column -> single element
103          - list -> raises an exception
104          - np.ndarray with more than one element -> raises an exception
105          - pd.DataFrame with more than one row -> raises an exception
106          """
107          import numpy as np
108          import pandas as pd
109  
110          supported_input_types = (np.ndarray, pd.DataFrame, str, dict)
111          if not isinstance(inputs, supported_input_types):
112              raise MlflowException(
113                  f"`inputs` must be one of: {[x.__name__ for x in supported_input_types]}, but "
114                  f"received type: {type(inputs)}.",
115                  INVALID_PARAMETER_VALUE,
116              )
117          if isinstance(inputs, pd.DataFrame):
118              if len(inputs) != 1:
119                  raise MlflowException(
120                      _INVALID_SIZE_MESSAGE,
121                      INVALID_PARAMETER_VALUE,
122                  )
123              if all(isinstance(col, str) for col in inputs.columns):
124                  inputs = inputs.to_dict(orient="records")[0]
125              else:
126                  inputs = inputs.values[0]
127          if isinstance(inputs, np.ndarray):
128              if len(inputs) != 1:
129                  raise MlflowException(
130                      _INVALID_SIZE_MESSAGE,
131                      INVALID_PARAMETER_VALUE,
132                  )
133              inputs = inputs[0]
134  
135          return inputs
136  
137      def _validate_streaming(
138          self,
139      ):
140          if Version(importlib.metadata.version("dspy")) <= Version("2.6.23"):
141              raise MlflowException(
142                  "Streaming API is only supported in dspy 2.6.24 or later. "
143                  "Please upgrade your dspy version."
144              )
145  
146          if self.output_schema is None:
147              raise MlflowException(
148                  "Output schema of the DSPy model is not set. Please log your DSPy "
149                  "model with `signature` or `input_example` to use streaming API.",
150                  error_code=INVALID_PARAMETER_VALUE,
151              )
152  
153          if any(spec.type != DataType.string for spec in self.output_schema):
154              raise MlflowException(
155                  f"All output fields must be string to use streaming API. Got {self.output_schema}.",
156                  error_code=INVALID_PARAMETER_VALUE,
157              )
158  
159  
class DspyChatModelWrapper(DspyModelWrapper):
    """MLflow PyFunc wrapper class for Dspy chat models.

    Converts DSPy program outputs into the 'llm/v1/chat' response payload
    (``{"choices": [...]}``).
    """

    def predict(self, inputs: Any, params: dict[str, Any] | None = None):
        """Run the DSPy program and wrap its output as a chat completion response.

        Args:
            inputs: A dict containing a "messages" key, or a pandas DataFrame
                with a "messages" column (see `_get_model_input`).
            params: Unused; kept for the PythonModel interface.

        Returns:
            A dict of the form ``{"choices": [...]}``, one choice per output.

        Raises:
            MlflowException: If the DSPy program returns an unsupported type.
        """
        import dspy

        converted_inputs = self._get_model_input(inputs)

        # `dspy.settings` cannot be shared across threads, so we are setting the context at every
        # predict call.
        with dspy.context(**self.dspy_settings):
            outputs = self.model(converted_inputs)

        choices = []
        if isinstance(outputs, str):
            choices.append(self._construct_chat_message("assistant", outputs))
        elif isinstance(outputs, dict):
            role = outputs.get("role", "assistant")
            choices.append(self._construct_chat_message(role, json.dumps(outputs)))
        elif isinstance(outputs, dspy.Prediction):
            choices.append(self._construct_chat_message("assistant", json.dumps(outputs.toDict())))
        elif isinstance(outputs, list):
            for output in outputs:
                if isinstance(output, dict):
                    role = output.get("role", "assistant")
                    # Bug fix: serialize the current element, not the whole list.
                    choices.append(self._construct_chat_message(role, json.dumps(output)))
                elif isinstance(output, dspy.Prediction):
                    # Bug fix: a Prediction is not a dict, so there is no role to
                    # read from it; use "assistant" and serialize this element's
                    # dict (previously called `.get` on the Prediction and
                    # `.toDict()` on the enclosing list, which would raise).
                    choices.append(
                        self._construct_chat_message("assistant", json.dumps(output.toDict()))
                    )
                else:
                    raise MlflowException(
                        f"Unsupported output type: {type(output)}. To log a DSPy model with task "
                        "'llm/v1/chat', the DSPy model must return a dict, a dspy.Prediction, or a "
                        "list of dicts or dspy.Prediction.",
                        INVALID_PARAMETER_VALUE,
                    )
        else:
            raise MlflowException(
                f"Unsupported output type: {type(outputs)}. To log a DSPy model with task "
                "'llm/v1/chat', the DSPy model must return a dict, a dspy.Prediction, or a list of "
                "dicts or dspy.Prediction.",
                INVALID_PARAMETER_VALUE,
            )

        return {"choices": choices}

    def predict_stream(self, inputs: Any, params=None):
        # Streaming is intentionally unsupported for the chat task.
        raise NotImplementedError(
            "Streaming is not supported for DSPy model with task 'llm/v1/chat'."
        )

    def _get_model_input(self, inputs: Any) -> str | list[dict[str, Any]]:
        """Extract the chat messages list from a dict or single-row DataFrame.

        Raises:
            MlflowException: If `inputs` is neither a dict nor a DataFrame.
        """
        import pandas as pd

        if isinstance(inputs, dict):
            return inputs["messages"]
        if isinstance(inputs, pd.DataFrame):
            return inputs.messages[0]

        raise MlflowException(
            f"Unsupported input type: {type(inputs)}. To log a DSPy model with task "
            "'llm/v1/chat', the input must be a dict or a pandas DataFrame.",
            INVALID_PARAMETER_VALUE,
        )

    def _construct_chat_message(self, role: str, content: str) -> dict[str, Any]:
        """Build a single OpenAI-style chat completion choice entry."""
        return {
            "index": 0,
            "message": {
                "role": role,
                "content": content,
            },
            "finish_reason": "stop",
        }