wrapper.py
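"""MLflow pyfunc wrappers for DSPy models.

`DspyModelWrapper` adapts a general DSPy program (`dspy.Module`) to the MLflow
pyfunc interface; `DspyChatModelWrapper` does the same for programs logged with
the 'llm/v1/chat' task.
"""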

import importlib.metadata
import json
from dataclasses import asdict, is_dataclass
from typing import TYPE_CHECKING, Any

from packaging.version import Version

if TYPE_CHECKING:
    import dspy

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.pyfunc import PythonModel
from mlflow.types.schema import DataType, Schema

_INVALID_SIZE_MESSAGE = (
    "Dspy model doesn't support batch inference or empty input. Please provide a single input."
)


class DspyModelWrapper(PythonModel):
    """MLflow PyFunc wrapper class for Dspy models.

    This wrapper serves two purposes:

    - It stores the Dspy model along with the dspy global settings, which are required for
      seamless saving and loading.
    - It provides a `predict` method so that it can be loaded as an MLflow pyfunc, which is
      used at serving time.
    """

    def __init__(
        self,
        model: "dspy.Module",
        dspy_settings: dict[str, Any],
        model_config: dict[str, Any] | None = None,
    ):
        self.model = model
        self.dspy_settings = dspy_settings
        self.model_config = model_config or {}
        self.output_schema: Schema | None = None

    def predict(self, inputs: Any, params: dict[str, Any] | None = None):
        import dspy

        converted_inputs = self._get_model_input(inputs)

        with dspy.context(**self.dspy_settings):
            if isinstance(converted_inputs, dict):
                # We pass a dict as keyword args and don't allow DSPy models
                # to receive a single dict.
                result = self.model(**converted_inputs)
            else:
                result = self.model(converted_inputs)

        if isinstance(result, dspy.Prediction):
            return result.toDict()
        return result

    def predict_stream(self, inputs: Any, params: dict[str, Any] | None = None):
        import dspy

        converted_inputs = self._get_model_input(inputs)

        self._validate_streaming()

        # Attach one listener per output field so that each field's tokens are streamed.
        stream_listeners = [
            dspy.streaming.StreamListener(signature_field_name=spec.name)
            for spec in self.output_schema
        ]
        stream_model = dspy.streamify(
            self.model,
            stream_listeners=stream_listeners,
            async_streaming=False,
            include_final_prediction_in_output_stream=False,
        )

        if isinstance(converted_inputs, dict):
            outputs = stream_model(**converted_inputs)
        else:
            outputs = stream_model(converted_inputs)

        with dspy.context(**self.dspy_settings):
            for output in outputs:
                if is_dataclass(output):
                    yield asdict(output)
                elif isinstance(output, dspy.Prediction):
                    yield output.toDict()
                else:
                    yield output

    def _get_model_input(self, inputs: Any) -> str | dict[str, Any]:
        """Convert the PythonModel input into the DSPy program input.

        Examples of expected conversions:

        - str -> str
        - dict -> dict
        - np.ndarray with one element -> single element
        - pd.DataFrame with one row and string column names -> single-row dict
        - pd.DataFrame with one row and non-string column names -> single element
        - list -> raises an exception
        - np.ndarray with more than one element -> raises an exception
        - pd.DataFrame with more than one row -> raises an exception
        """
        import numpy as np
        import pandas as pd

        supported_input_types = (np.ndarray, pd.DataFrame, str, dict)
        if not isinstance(inputs, supported_input_types):
            raise MlflowException(
                f"`inputs` must be one of: {[x.__name__ for x in supported_input_types]}, "
                f"but received type: {type(inputs)}.",
                INVALID_PARAMETER_VALUE,
            )
        if isinstance(inputs, pd.DataFrame):
            if len(inputs) != 1:
                raise MlflowException(
                    _INVALID_SIZE_MESSAGE,
                    INVALID_PARAMETER_VALUE,
                )
            if all(isinstance(col, str) for col in inputs.columns):
                inputs = inputs.to_dict(orient="records")[0]
            else:
                inputs = inputs.values[0]
        if isinstance(inputs, np.ndarray):
            if len(inputs) != 1:
                raise MlflowException(
                    _INVALID_SIZE_MESSAGE,
                    INVALID_PARAMETER_VALUE,
                )
            inputs = inputs[0]

        return inputs

    def _validate_streaming(self):
        if Version(importlib.metadata.version("dspy")) <= Version("2.6.23"):
            raise MlflowException(
                "Streaming API is only supported in dspy 2.6.24 or later. "
                "Please upgrade your dspy version."
            )

        if self.output_schema is None:
            raise MlflowException(
                "Output schema of the DSPy model is not set. Please log your DSPy "
                "model with `signature` or `input_example` to use the streaming API.",
                error_code=INVALID_PARAMETER_VALUE,
            )

        if any(spec.type != DataType.string for spec in self.output_schema):
            raise MlflowException(
                "All output fields must be strings to use the streaming API. "
                f"Got {self.output_schema}.",
                error_code=INVALID_PARAMETER_VALUE,
            )
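

# Usage sketch (illustrative only; the program and LM below are placeholders,
# not part of this module):
#
#   import dspy
#
#   program = dspy.Predict("question -> answer")
#   wrapper = DspyModelWrapper(
#       program, dspy_settings={"lm": dspy.LM("openai/gpt-4o-mini")}
#   )
#
#   # A dict input is expanded into keyword arguments for the program:
#   wrapper.predict({"question": "What is MLflow?"})  # -> {"answer": "..."}
#
#   # Streaming (requires dspy >= 2.6.24 and an all-string output schema):
#   for chunk in wrapper.predict_stream({"question": "What is MLflow?"}):
#       ...  # chunks are yielded as dicts where possible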


class DspyChatModelWrapper(DspyModelWrapper):
    """MLflow PyFunc wrapper class for Dspy chat models."""

    def predict(self, inputs: Any, params: dict[str, Any] | None = None):
        import dspy

        converted_inputs = self._get_model_input(inputs)

        # `dspy.settings` cannot be shared across threads, so we are setting the context at
        # every predict call.
        with dspy.context(**self.dspy_settings):
            outputs = self.model(converted_inputs)

        choices = []
        if isinstance(outputs, str):
            choices.append(self._construct_chat_message("assistant", outputs))
        elif isinstance(outputs, dict):
            role = outputs.get("role", "assistant")
            choices.append(self._construct_chat_message(role, json.dumps(outputs)))
        elif isinstance(outputs, dspy.Prediction):
            choices.append(
                self._construct_chat_message("assistant", json.dumps(outputs.toDict()))
            )
        elif isinstance(outputs, list):
            for output in outputs:
                if isinstance(output, dict):
                    role = output.get("role", "assistant")
                    choices.append(self._construct_chat_message(role, json.dumps(output)))
                elif isinstance(output, dspy.Prediction):
                    choices.append(
                        self._construct_chat_message("assistant", json.dumps(output.toDict()))
                    )
                else:
                    raise MlflowException(
                        f"Unsupported output type: {type(output)}. To log a DSPy model with "
                        "task 'llm/v1/chat', the DSPy model must return a dict, a "
                        "dspy.Prediction, or a list of dicts or dspy.Prediction.",
                        INVALID_PARAMETER_VALUE,
                    )
        else:
            raise MlflowException(
                f"Unsupported output type: {type(outputs)}. To log a DSPy model with task "
                "'llm/v1/chat', the DSPy model must return a dict, a dspy.Prediction, or a "
                "list of dicts or dspy.Prediction.",
                INVALID_PARAMETER_VALUE,
            )

        return {"choices": choices}
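
    # The dict returned by `predict` above follows the OpenAI-style chat completion
    # shape built by `_construct_chat_message` below, e.g.:
    #
    #   {
    #       "choices": [
    #           {
    #               "index": 0,
    #               "message": {"role": "assistant", "content": "..."},
    #               "finish_reason": "stop",
    #           }
    #       ]
    #   }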

    def predict_stream(self, inputs: Any, params: dict[str, Any] | None = None):
        raise NotImplementedError(
            "Streaming is not supported for DSPy models with task 'llm/v1/chat'."
        )

    def _get_model_input(self, inputs: Any) -> str | list[dict[str, Any]]:
        import pandas as pd

        if isinstance(inputs, dict):
            return inputs["messages"]
        if isinstance(inputs, pd.DataFrame):
            return inputs.messages[0]

        raise MlflowException(
            f"Unsupported input type: {type(inputs)}. To log a DSPy model with task "
            "'llm/v1/chat', the input must be a dict or a pandas DataFrame.",
            INVALID_PARAMETER_VALUE,
        )

    def _construct_chat_message(self, role: str, content: str) -> dict[str, Any]:
        return {
            "index": 0,
            "message": {
                "role": role,
                "content": content,
            },
            "finish_reason": "stop",
        }
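

# Chat usage sketch (illustrative only; `my_chat_program` and `settings` are
# placeholders, not part of this module):
#
#   chat_wrapper = DspyChatModelWrapper(my_chat_program, dspy_settings=settings)
#   chat_wrapper.predict({"messages": [{"role": "user", "content": "Hi!"}]})
#
#   # The "messages" list is passed straight to the DSPy program, and the
#   # program's output is wrapped into the {"choices": [...]} payload shown above.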