# test_save.py
"""Tests for the MLflow dspy flavor: saving, logging, loading, serving, and streaming."""

# Fix: `importlib.metadata` is a submodule and is NOT implicitly imported by
# `import importlib`; accessing it relied on another import pulling it in first.
import importlib.metadata
import json
from unittest import mock

import dspy
import dspy.teleprompt
import pytest
from dspy.utils.dummies import DummyLM, dummy_rm
from packaging.version import Version

import mlflow
from mlflow.models import Model, ModelSignature
from mlflow.types.schema import ColSpec, Schema

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_logged_code_paths,
    _mlflow_major_version_string,
    expect_status_code,
    pyfunc_serve_and_score_model,
)

# Installed dspy version, resolved once; several tests and markers branch on it.
_DSPY_VERSION = Version(importlib.metadata.version("dspy"))

_DSPY_UNDER_2_6 = _DSPY_VERSION < Version("2.6.0")

_DSPY_2_6_23_OR_OLDER = _DSPY_VERSION <= Version("2.6.23")
skip_if_2_6_23_or_older = pytest.mark.skipif(
    _DSPY_2_6_23_OR_OLDER,
    reason="Streaming API is only supported in dspy 2.6.24 or later.",
)

# dspy renamed the chain-of-thought output field from "rationale" to "reasoning" in 2.6.
_REASONING_KEYWORD = "rationale" if _DSPY_UNDER_2_6 else "reasoning"


@pytest.fixture
def dummy_model():
    """A DummyLM that answers "4", "6", "8", "10" in order, with a fixed reasoning field."""
    return DummyLM([
        {"answer": answer, _REASONING_KEYWORD: "reason"} for answer in ["4", "6", "8", "10"]
    ])


class CoT(dspy.Module):
    """Minimal chain-of-thought program used by most save/load tests."""

    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.prog(question=question)


class NumericalCoT(dspy.Module):
    """CoT variant with a non-string (int) output; used to exercise unstreamable schemas."""

    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer: int")

    def forward(self, question):
        return self.prog(question=question).answer


@pytest.fixture(autouse=True)
def reset_dspy_settings():
    """Reset global dspy settings after every test so state does not leak across tests."""
    yield

    dspy.settings.configure(lm=None, rm=None)


# Parametrization for `use_dspy_model_save=True`; the param itself is skipped on
# dspy versions that do not support the flag.
use_dspy_model_save_param = pytest.param(
    True,
    marks=pytest.mark.skipif(
        # Consistency: use the module-level _DSPY_VERSION rather than re-parsing
        # dspy.__version__ here.
        _DSPY_VERSION <= Version("3.1.0"),
        reason="dspy<=3.1.0 does not support 'use_dspy_model_save' param.",
    ),
)


@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_basic_save(use_dspy_model_save):
    """Log and reload an uncompiled model; global lm settings should be restored on load."""
    # Defensive skip; in practice unreachable because the True param is already
    # skipped for dspy<=3.1.0 (which subsumes <=2.6.0).
    if use_dspy_model_save and _DSPY_VERSION <= Version("2.6.0"):
        pytest.skip("'use_dspy_model_save' = True does not support dspy <= 2.6.0")

    dspy_model = CoT()
    dspy.settings.configure(lm=dspy.LM(model="openai/gpt-4o-mini", max_tokens=250))

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name="model",
            use_dspy_model_save=use_dspy_model_save,
        )

    # Clear the lm setting to test the loading logic.
    dspy.settings.configure(lm=None)

    loaded_model = mlflow.dspy.load_model(model_info.model_uri)

    # Check that the global settings is popped back.
    assert dspy.settings.lm.model == "openai/gpt-4o-mini"
    assert isinstance(loaded_model, CoT)


@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_save_compiled_model(dummy_model, use_dspy_model_save):
    """Compile with BootstrapFewShot, then verify demos survive a log/load round trip."""
    train_data = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    train_label = ["4", "6", "8", "10"]
    trainset = [
        dspy.Example(question=q, answer=a).with_inputs("question")
        for q, a in zip(train_data, train_label)
    ]

    def dummy_metric(program):
        # Always accept; we only care about the save/load mechanics here.
        return 1.0

    dspy.settings.configure(lm=dummy_model)

    dspy_model = CoT()
    optimizer = dspy.teleprompt.BootstrapFewShot(metric=dummy_metric)
    optimized_cot = optimizer.compile(dspy_model, trainset=trainset)

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            optimized_cot, name="model", use_dspy_model_save=use_dspy_model_save
        )

    # Clear the lm setting to test the loading logic.
    dspy.settings.configure(lm=None)

    loaded_model = mlflow.dspy.load_model(model_info.model_uri)

    assert isinstance(loaded_model, CoT)
    assert loaded_model.prog.predictors()[0].demos == optimized_cot.prog.predictors()[0].demos


@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_dspy_save_preserves_object_state(use_dspy_model_save):
    """A compiled RAG program keeps its demos, retriever, and global settings after reload."""

    class GenerateAnswer(dspy.Signature):
        """Answer questions with short factoid answers."""

        context = dspy.InputField(desc="may contain relevant facts")
        question = dspy.InputField()
        answer = dspy.OutputField(desc="often between 1 and 5 words")

    class RAG(dspy.Module):
        def __init__(self, num_passages=3):
            super().__init__()

            self.retrieve = dspy.Retrieve(k=num_passages)
            self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

        def forward(self, question):
            assert question == "What is 2 + 2?"
            context = self.retrieve(question).passages
            prediction = self.generate_answer(context=context, question=question)
            return dspy.Prediction(context=context, answer=prediction.answer)

    def dummy_metric(*args, **kwargs):
        return 1.0

    model = DummyLM([{"answer": answer, "reasoning": "reason"} for answer in ["4", "6", "8", "10"]])
    rm = dummy_rm(passages=["dummy1", "dummy2", "dummy3"])
    dspy.settings.configure(lm=model, rm=rm)

    train_data = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    train_label = ["4", "6", "8", "10"]
    trainset = [
        dspy.Example(question=q, answer=a).with_inputs("question").with_inputs("reasoning")
        for q, a in zip(train_data, train_label)
    ]

    dspy_model = RAG()
    optimizer = dspy.teleprompt.BootstrapFewShot(metric=dummy_metric)
    optimized_cot = optimizer.compile(dspy_model, trainset=trainset)

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            optimized_cot, name="model", use_dspy_model_save=use_dspy_model_save
        )

    original_settings = dict(dspy.settings.config)
    # Traces are execution artifacts, not persisted state; neutralize before comparing.
    original_settings["traces"] = None

    # Clear the lm setting to test the loading logic.
    dspy.settings.configure(lm=None)

    model_url = model_info.model_uri

    input_examples = {"inputs": ["What is 2 + 2?"]}
    # test that the model can be served
    response = pyfunc_serve_and_score_model(
        model_uri=model_url,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )
    expect_status_code(response, 200)

    loaded_model = mlflow.dspy.load_model(model_url)
    assert isinstance(loaded_model, RAG)
    assert loaded_model.retrieve is not None
    assert (
        loaded_model.generate_answer.predictors()[0].demos
        == optimized_cot.generate_answer.predictors()[0].demos
    )

    loaded_settings = dict(dspy.settings.config)
    loaded_settings["traces"] = None

    # lm/rm objects are compared field-wise (identity differs after reload), then
    # removed so the remaining settings can be compared with plain equality.
    assert loaded_settings["lm"].model == original_settings["lm"].model
    assert loaded_settings["lm"].model_type == original_settings["lm"].model_type
    assert loaded_settings["rm"].__dict__ == original_settings["rm"].__dict__

    del (
        loaded_settings["lm"],
        original_settings["lm"],
        loaded_settings["rm"],
        original_settings["rm"],
    )

    assert original_settings == loaded_settings


@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_load_logged_model_in_native_dspy(dummy_model, use_dspy_model_save):
    """Loading via mlflow.dspy.load_model returns a native dspy module with no data loss."""
    dspy_model = CoT()
    # Arbitrary set the demo to test saving/loading has no data loss.
    dspy_model.prog.predictors()[0].demos = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    dspy.settings.configure(lm=dummy_model)

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model, name="model", use_dspy_model_save=use_dspy_model_save
        )
    loaded_dspy_model = mlflow.dspy.load_model(model_info.model_uri)

    assert isinstance(loaded_dspy_model, CoT)
    assert loaded_dspy_model.prog.predictors()[0].demos == dspy_model.prog.predictors()[0].demos


def test_serving_logged_model(dummy_model):
    """A logged model with an explicit signature can be served and scored over HTTP."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, question):
            assert question == "What is 2 + 2?"
            return self.prog(question=question)

    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    input_examples = {"inputs": ["What is 2 + 2?"]}
    input_schema = Schema([ColSpec("string")])
    output_schema = Schema([ColSpec("string")])
    signature = ModelSignature(inputs=input_schema, outputs=output_schema)

    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name=artifact_path,
            signature=signature,
            input_example=["What is 2 + 2?"],
        )
    model_uri = model_info.model_uri
    dspy.settings.configure(lm=None)

    response = pyfunc_serve_and_score_model(
        model_uri=model_uri,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    assert _REASONING_KEYWORD in json_response["predictions"]
    assert "answer" in json_response["predictions"]


def test_log_model_multi_inputs(dummy_model):
    """A program with two input fields round-trips through pyfunc predict and serving."""

    class MultiInputCoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question, hint -> answer")

        def forward(self, question, hint):
            assert question == "What is 2 + 2?"
            assert hint == "Hint: 2 + 2 = ?"
            return self.prog(question=question, hint=hint)

    dspy_model = MultiInputCoT()

    dspy.settings.configure(lm=dummy_model)

    input_example = {"question": "What is 2 + 2?", "hint": "Hint: 2 + 2 = ?"}

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name="model",
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    # NOTE(review): "6" is the DummyLM's answer here — presumably earlier responses in the
    # canned sequence were consumed during logging; confirm against DummyLM call order.
    assert loaded_model.predict(input_example) == {"answer": "6", _REASONING_KEYWORD: "reason"}

    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=json.dumps({"inputs": [input_example]}),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    assert _REASONING_KEYWORD in json_response["predictions"]
    assert "answer" in json_response["predictions"]


def test_save_chat_model_with_string_output(dummy_model):
    """A model logged with task="llm/v1/chat" wraps a string output in the chat schema."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, inputs):
            # DSPy chat model's inputs is a list of dict with keys roles (optional) and content.
            # And here we output a single string.
            return self.prog(question=inputs[0]["content"]).answer

    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    input_examples = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}

    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name=artifact_path,
            task="llm/v1/chat",
            input_example=input_examples,
        )
    loaded_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri)
    response = loaded_pyfunc.predict(input_examples)

    assert "choices" in response
    assert len(response["choices"]) == 1
    assert "message" in response["choices"][0]
    # The content should just be a string.
    assert response["choices"][0]["message"]["content"] == "4"


def test_serve_chat_model(dummy_model):
    """A chat-task model can be served; the response follows the llm/v1/chat schema."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, inputs):
            return self.prog(question=inputs[0]["content"])

    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    input_examples = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}

    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name=artifact_path,
            task="llm/v1/chat",
            input_example=input_examples,
        )
    dspy.settings.configure(lm=None)

    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    assert "choices" in json_response
    assert len(json_response["choices"]) == 1
    assert "message" in json_response["choices"][0]
    assert _REASONING_KEYWORD in json_response["choices"][0]["message"]["content"]
    assert "answer" in json_response["choices"][0]["message"]["content"]


def test_code_paths_is_used():
    """code_paths are logged with the model and re-added to sys.path on load."""
    artifact_path = "model"
    dspy_model = CoT()
    with (
        mlflow.start_run(),
        mock.patch("mlflow.dspy.load._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.dspy.log_model(dspy_model, name=artifact_path, code_paths=[__file__])
        _compare_logged_code_paths(__file__, model_info.model_uri, "dspy")
        mlflow.dspy.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_additional_pip_requirements():
    """extra_pip_requirements are appended to the default requirements set."""
    expected_mlflow_version = _mlflow_major_version_string()
    artifact_path = "model"
    dspy_model = CoT()
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model, name=artifact_path, extra_pip_requirements=["dummy"]
        )

    _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "dummy"])


def test_infer_signature_from_input_examples(dummy_model):
    """A signature is inferred from a plain-string input example at log time."""
    artifact_path = "model"
    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model, name=artifact_path, input_example="what is 2 + 2?"
        )

    loaded_model = Model.load(model_info.model_uri)
    assert loaded_model.signature.inputs == Schema([ColSpec("string")])
    assert loaded_model.signature.outputs == Schema([
        ColSpec(name="answer", type="string"),
        ColSpec(name=_REASONING_KEYWORD, type="string"),
    ])


@skip_if_2_6_23_or_older
def test_predict_stream_unsupported_schema(dummy_model):
    """Models with non-string outputs are flagged unstreamable and predict_stream raises."""
    dspy_model = NumericalCoT()
    dspy.settings.configure(lm=dummy_model)

    model_info = mlflow.dspy.log_model(dspy_model, name="model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    assert not loaded_model._model_meta.flavors["python_function"]["streamable"]
    output = loaded_model.predict_stream({"question": "What is 2 + 2?"})
    with pytest.raises(
        mlflow.exceptions.MlflowException,
        match="This model does not support predict_stream method.",
    ):
        next(output)


@skip_if_2_6_23_or_older
def test_predict_stream_success(dummy_model):
    """predict_stream yields per-field chunks produced by dspy.streamify."""
    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    model_info = mlflow.dspy.log_model(
        dspy_model, name="model", input_example={"question": "what is 2 + 2?"}
    )
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    assert loaded_model._model_meta.flavors["python_function"]["streamable"]
    results = []

    def dummy_streamify(*args, **kwargs):
        # In dspy>=3, `StreamResponse` requires `is_last_chunk` argument.
        # https://github.com/stanfordnlp/dspy/pull/8587
        extra_kwargs = {"is_last_chunk": False} if _DSPY_VERSION.major >= 3 else {}
        yield dspy.streaming.StreamResponse(
            predict_name="prog.predict",
            signature_field_name="answer",
            chunk="2",
            **extra_kwargs,
        )
        extra_kwargs = {"is_last_chunk": True} if _DSPY_VERSION.major >= 3 else {}
        yield dspy.streaming.StreamResponse(
            predict_name="prog.predict",
            signature_field_name=_REASONING_KEYWORD,
            chunk="reason",
            **extra_kwargs,
        )

    with mock.patch("dspy.streamify", return_value=dummy_streamify):
        output = loaded_model.predict_stream({"question": "What is 2 + 2?"})
        for o in output:
            results.append(o)

    assert len(results) == 2
    extra_kwargs = {"is_last_chunk": False} if _DSPY_VERSION.major >= 3 else {}
    assert results[0] == {
        "predict_name": "prog.predict",
        "signature_field_name": "answer",
        "chunk": "2",
        **extra_kwargs,
    }
    extra_kwargs = {"is_last_chunk": True} if _DSPY_VERSION.major >= 3 else {}
    assert results[1] == {
        "predict_name": "prog.predict",
        "signature_field_name": _REASONING_KEYWORD,
        "chunk": "reason",
        **extra_kwargs,
    }


def test_predict_output(dummy_model):
    """predict returns a dict whether forward yields a plain dict or a dspy.Prediction."""

    class MockModelReturningNonPrediction(dspy.Module):
        def forward(self, question):
            # Return a plain dict instead of dspy.Prediction
            return {"answer": "4", "custom_field": "custom_value"}

    class MockModelReturningPrediction(dspy.Module):
        def forward(self, question):
            # Return a dspy.Prediction
            prediction = dspy.Prediction()
            prediction.answer = "4"
            prediction.custom_field = "custom_value"
            return prediction

    dspy.settings.configure(lm=dummy_model)

    non_prediction_model = MockModelReturningNonPrediction()
    model_info = mlflow.dspy.log_model(non_prediction_model, name="non_prediction_model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = loaded_model.predict("What is 2 + 2?")

    assert isinstance(result, dict)
    assert result == {"answer": "4", "custom_field": "custom_value"}

    prediction_model = MockModelReturningPrediction()
    model_info = mlflow.dspy.log_model(prediction_model, name="prediction_model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = loaded_model.predict("What is 2 + 2?")

    assert isinstance(result, dict)
    assert result == {"answer": "4", "custom_field": "custom_value"}