# tests/dspy/test_save.py
  1  import importlib
  2  import json
  3  from unittest import mock
  4  
  5  import dspy
  6  import dspy.teleprompt
  7  import pytest
  8  from dspy.utils.dummies import DummyLM, dummy_rm
  9  from packaging.version import Version
 10  
 11  import mlflow
 12  from mlflow.models import Model, ModelSignature
 13  from mlflow.types.schema import ColSpec, Schema
 14  
 15  from tests.helper_functions import (
 16      _assert_pip_requirements,
 17      _compare_logged_code_paths,
 18      _mlflow_major_version_string,
 19      expect_status_code,
 20      pyfunc_serve_and_score_model,
 21  )
 22  
# Installed dspy version, used throughout this module to gate version-specific tests.
_DSPY_VERSION = Version(importlib.metadata.version("dspy"))

# dspy 2.6 renamed the ChainOfThought reasoning output field ("rationale" -> "reasoning").
_DSPY_UNDER_2_6 = _DSPY_VERSION < Version("2.6.0")

_DSPY_2_6_23_OR_OLDER = _DSPY_VERSION <= Version("2.6.23")
# Reusable marker for tests that exercise the streaming (predict_stream) API.
skip_if_2_6_23_or_older = pytest.mark.skipif(
    _DSPY_2_6_23_OR_OLDER,
    reason="Streaming API is only supported in dspy 2.6.24 or later.",
)

# Name of the extra output field produced by dspy.ChainOfThought on this dspy version.
_REASONING_KEYWORD = "rationale" if _DSPY_UNDER_2_6 else "reasoning"
 34  
 35  
 36  @pytest.fixture
 37  def dummy_model():
 38      return DummyLM([
 39          {"answer": answer, _REASONING_KEYWORD: "reason"} for answer in ["4", "6", "8", "10"]
 40      ])
 41  
 42  
class CoT(dspy.Module):
    """Minimal chain-of-thought program shared by most tests in this module."""

    def __init__(self):
        super().__init__()
        # Single predictor mapping a question to an answer; ChainOfThought also
        # emits the reasoning/rationale field (see _REASONING_KEYWORD).
        self.prog = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        # Return the full prediction object (answer plus reasoning field).
        return self.prog(question=question)
 50  
 51  
class NumericalCoT(dspy.Module):
    """Chain-of-thought program with a typed (int) answer output.

    Used by test_predict_stream_unsupported_schema to obtain a model that the
    pyfunc flavor flags as non-streamable.
    """

    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer: int")

    def forward(self, question):
        # Unwrap the prediction and return only the integer answer.
        return self.prog(question=question).answer
 59  
 60  
@pytest.fixture(autouse=True)
def reset_dspy_settings():
    """Reset global dspy settings after every test to avoid cross-test leakage."""
    yield

    # Clear both the language model and the retrieval model.
    dspy.settings.configure(lm=None, rm=None)
 66  
 67  
 68  use_dspy_model_save_param = pytest.param(
 69      True,
 70      marks=pytest.mark.skipif(
 71          Version(dspy.__version__) <= Version("3.1.0"),
 72          reason="dspy<=3.1.0 does not support 'use_dspy_model_save' param.",
 73      ),
 74  )
 75  
 76  
 77  @pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
 78  def test_basic_save(use_dspy_model_save):
 79      if use_dspy_model_save and _DSPY_VERSION <= Version("2.6.0"):
 80          pytest.skip("'use_dspy_model_save' = True does not support dspy <= 2.6.0")
 81  
 82      dspy_model = CoT()
 83      dspy.settings.configure(lm=dspy.LM(model="openai/gpt-4o-mini", max_tokens=250))
 84  
 85      with mlflow.start_run():
 86          model_info = mlflow.dspy.log_model(
 87              dspy_model,
 88              name="model",
 89              use_dspy_model_save=use_dspy_model_save,
 90          )
 91  
 92      # Clear the lm setting to test the loading logic.
 93      dspy.settings.configure(lm=None)
 94  
 95      loaded_model = mlflow.dspy.load_model(model_info.model_uri)
 96  
 97      # Check that the global settings is popped back.
 98      assert dspy.settings.lm.model == "openai/gpt-4o-mini"
 99      assert isinstance(loaded_model, CoT)
100  
101  
@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_save_compiled_model(dummy_model, use_dspy_model_save):
    """A BootstrapFewShot-compiled program keeps its demos across save/load."""
    questions = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    answers = ["4", "6", "8", "10"]
    trainset = [
        dspy.Example(question=question, answer=answer).with_inputs("question")
        for question, answer in zip(questions, answers)
    ]

    def dummy_metric(program):
        return 1.0

    dspy.settings.configure(lm=dummy_model)

    optimizer = dspy.teleprompt.BootstrapFewShot(metric=dummy_metric)
    compiled = optimizer.compile(CoT(), trainset=trainset)

    with mlflow.start_run():
        info = mlflow.dspy.log_model(
            compiled, name="model", use_dspy_model_save=use_dspy_model_save
        )

    # Drop the lm so loading has to restore it from the logged settings.
    dspy.settings.configure(lm=None)

    reloaded = mlflow.dspy.load_model(info.model_uri)

    assert isinstance(reloaded, CoT)
    assert reloaded.prog.predictors()[0].demos == compiled.prog.predictors()[0].demos
137  
138  
@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_dspy_save_preserves_object_state(use_dspy_model_save):
    """Saving/loading a compiled RAG program preserves its sub-modules, demos,
    and the global dspy settings (lm and rm)."""

    class GenerateAnswer(dspy.Signature):
        """Answer questions with short factoid answers."""

        context = dspy.InputField(desc="may contain relevant facts")
        question = dspy.InputField()
        answer = dspy.OutputField(desc="often between 1 and 5 words")

    class RAG(dspy.Module):
        def __init__(self, num_passages=3):
            super().__init__()

            self.retrieve = dspy.Retrieve(k=num_passages)
            self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

        def forward(self, question):
            # Guard: the serving request below must pass this exact question through.
            assert question == "What is 2 + 2?"
            context = self.retrieve(question).passages
            prediction = self.generate_answer(context=context, question=question)
            return dspy.Prediction(context=context, answer=prediction.answer)

    def dummy_metric(*args, **kwargs):
        return 1.0

    model = DummyLM([{"answer": answer, "reasoning": "reason"} for answer in ["4", "6", "8", "10"]])
    rm = dummy_rm(passages=["dummy1", "dummy2", "dummy3"])
    dspy.settings.configure(lm=model, rm=rm)

    train_data = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    train_label = ["4", "6", "8", "10"]
    # NOTE(review): the chained with_inputs calls may replace rather than extend
    # the input keys — confirm "question" is still treated as an input here.
    trainset = [
        dspy.Example(question=q, answer=a).with_inputs("question").with_inputs("reasoning")
        for q, a in zip(train_data, train_label)
    ]

    dspy_model = RAG()
    optimizer = dspy.teleprompt.BootstrapFewShot(metric=dummy_metric)
    optimized_cot = optimizer.compile(dspy_model, trainset=trainset)

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            optimized_cot, name="model", use_dspy_model_save=use_dspy_model_save
        )

    # Snapshot the settings before loading; traces are zeroed out because they
    # are not expected to round-trip.
    original_settings = dict(dspy.settings.config)
    original_settings["traces"] = None

    # Clear the lm setting to test the loading logic.
    dspy.settings.configure(lm=None)

    model_url = model_info.model_uri

    input_examples = {"inputs": ["What is 2 + 2?"]}
    # test that the model can be served
    response = pyfunc_serve_and_score_model(
        model_uri=model_url,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )
    expect_status_code(response, 200)

    loaded_model = mlflow.dspy.load_model(model_url)
    assert isinstance(loaded_model, RAG)
    assert loaded_model.retrieve is not None
    assert (
        loaded_model.generate_answer.predictors()[0].demos
        == optimized_cot.generate_answer.predictors()[0].demos
    )

    loaded_settings = dict(dspy.settings.config)
    loaded_settings["traces"] = None

    # lm/rm objects are recreated on load, so compare by salient attributes
    # rather than identity.
    assert loaded_settings["lm"].model == original_settings["lm"].model
    assert loaded_settings["lm"].model_type == original_settings["lm"].model_type
    assert loaded_settings["rm"].__dict__ == original_settings["rm"].__dict__

    del (
        loaded_settings["lm"],
        original_settings["lm"],
        loaded_settings["rm"],
        original_settings["rm"],
    )

    # Everything else in the settings dict must match exactly.
    assert original_settings == loaded_settings
230  
231  
@pytest.mark.parametrize("use_dspy_model_save", [use_dspy_model_save_param, False])
def test_load_logged_model_in_native_dspy(dummy_model, use_dspy_model_save):
    """Demos attached to a predictor survive the save/load round trip intact."""
    program = CoT()
    demos = [
        "What is 2 + 2?",
        "What is 3 + 3?",
        "What is 4 + 4?",
        "What is 5 + 5?",
    ]
    # Arbitrarily set the demos so we can check saving/loading has no data loss.
    program.prog.predictors()[0].demos = demos
    dspy.settings.configure(lm=dummy_model)

    with mlflow.start_run():
        info = mlflow.dspy.log_model(
            program, name="model", use_dspy_model_save=use_dspy_model_save
        )
    reloaded = mlflow.dspy.load_model(info.model_uri)

    assert isinstance(reloaded, CoT)
    assert reloaded.prog.predictors()[0].demos == program.prog.predictors()[0].demos
252  
253  
def test_serving_logged_model(dummy_model):
    """A logged dspy model with an explicit signature can be served via pyfunc."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, question):
            # Guard: the served request must arrive with the original question.
            assert question == "What is 2 + 2?"
            return self.prog(question=question)

    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    input_examples = {"inputs": ["What is 2 + 2?"]}
    input_schema = Schema([ColSpec("string")])
    output_schema = Schema([ColSpec("string")])
    signature = ModelSignature(inputs=input_schema, outputs=output_schema)

    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name=artifact_path,
            signature=signature,
            input_example=["What is 2 + 2?"],
        )
        model_uri = model_info.model_uri
    # Serving must work without a globally configured lm.
    dspy.settings.configure(lm=None)

    response = pyfunc_serve_and_score_model(
        model_uri=model_uri,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    # The prediction payload carries both the answer and the reasoning field.
    assert _REASONING_KEYWORD in json_response["predictions"]
    assert "answer" in json_response["predictions"]
296  
297  
def test_log_model_multi_inputs(dummy_model):
    """Programs whose forward() takes multiple named inputs predict and serve correctly."""

    class MultiInputCoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question, hint -> answer")

        def forward(self, question, hint):
            # Guards: both fields must be routed through unchanged.
            assert question == "What is 2 + 2?"
            assert hint == "Hint: 2 + 2 = ?"
            return self.prog(question=question, hint=hint)

    dspy_model = MultiInputCoT()

    dspy.settings.configure(lm=dummy_model)

    input_example = {"question": "What is 2 + 2?", "hint": "Hint: 2 + 2 = ?"}

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name="model",
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    # "6" rather than "4": presumably the first canned DummyLM response is
    # consumed while the input_example is evaluated during logging — verify.
    assert loaded_model.predict(input_example) == {"answer": "6", _REASONING_KEYWORD: "reason"}

    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=json.dumps({"inputs": [input_example]}),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    assert _REASONING_KEYWORD in json_response["predictions"]
    assert "answer" in json_response["predictions"]
338  
339  
def test_save_chat_model_with_string_output(dummy_model):
    """llm/v1/chat task wraps a plain-string forward() output into chat choices."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, inputs):
            # DSPy chat model's inputs is a list of dict with keys roles (optional) and content.
            # And here we output a single string.
            return self.prog(question=inputs[0]["content"]).answer

    dspy.settings.configure(lm=dummy_model)
    program = CoT()

    chat_payload = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}

    with mlflow.start_run():
        info = mlflow.dspy.log_model(
            program,
            name="model",
            task="llm/v1/chat",
            input_example=chat_payload,
        )
    pyfunc_model = mlflow.pyfunc.load_model(info.model_uri)
    response = pyfunc_model.predict(chat_payload)

    assert "choices" in response
    choices = response["choices"]
    assert len(choices) == 1
    assert "message" in choices[0]
    # The content should just be a string.
    assert choices[0]["message"]["content"] == "4"
372  
373  
def test_serve_chat_model(dummy_model):
    """A model logged with task='llm/v1/chat' serves OpenAI-style chat responses."""

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought("question -> answer")

        def forward(self, inputs):
            # Chat inputs arrive as a list of message dicts; use the first
            # message's content as the question.
            return self.prog(question=inputs[0]["content"])

    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    input_examples = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}

    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model,
            name=artifact_path,
            task="llm/v1/chat",
            input_example=input_examples,
        )
    # Serving must work without a globally configured lm.
    dspy.settings.configure(lm=None)

    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=json.dumps(input_examples),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)

    json_response = json.loads(response.content)

    # Response follows the chat-completions shape; the full prediction
    # (answer + reasoning) is serialized into the message content.
    assert "choices" in json_response
    assert len(json_response["choices"]) == 1
    assert "message" in json_response["choices"][0]
    assert _REASONING_KEYWORD in json_response["choices"][0]["message"]["content"]
    assert "answer" in json_response["choices"][0]["message"]["content"]
414  
415  
def test_code_paths_is_used():
    """Files passed via code_paths are logged and installed on sys.path at load time."""
    artifact_path = "model"
    dspy_model = CoT()
    with (
        mlflow.start_run(),
        mock.patch("mlflow.dspy.load._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.dspy.log_model(dspy_model, name=artifact_path, code_paths=[__file__])
        _compare_logged_code_paths(__file__, model_info.model_uri, "dspy")
        mlflow.dspy.load_model(model_info.model_uri)
        # Loading must route through the code-path installation helper.
        add_mock.assert_called()
427  
428  
def test_additional_pip_requirements():
    """extra_pip_requirements entries are appended to the inferred requirements."""
    mlflow_requirement = _mlflow_major_version_string()
    program = CoT()
    with mlflow.start_run():
        info = mlflow.dspy.log_model(
            program, name="model", extra_pip_requirements=["dummy"]
        )

        _assert_pip_requirements(info.model_uri, [mlflow_requirement, "dummy"])
439  
440  
def test_infer_signature_from_input_examples(dummy_model):
    """Logging with only an input_example infers both input and output schemas."""
    artifact_path = "model"
    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)
    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(
            dspy_model, name=artifact_path, input_example="what is 2 + 2?"
        )

        loaded_model = Model.load(model_info.model_uri)
        assert loaded_model.signature.inputs == Schema([ColSpec("string")])
        # The output schema includes the reasoning field that ChainOfThought adds.
        assert loaded_model.signature.outputs == Schema([
            ColSpec(name="answer", type="string"),
            ColSpec(name=_REASONING_KEYWORD, type="string"),
        ])
456  
457  
@skip_if_2_6_23_or_older
def test_predict_stream_unsupported_schema(dummy_model):
    """predict_stream raises when the model's output schema is not streamable."""
    dspy_model = NumericalCoT()
    dspy.settings.configure(lm=dummy_model)

    model_info = mlflow.dspy.log_model(dspy_model, name="model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # The flavor marks this model as non-streamable...
    assert not loaded_model._model_meta.flavors["python_function"]["streamable"]
    output = loaded_model.predict_stream({"question": "What is 2 + 2?"})
    # ...and the error surfaces lazily, on the first iteration of the generator.
    with pytest.raises(
        mlflow.exceptions.MlflowException,
        match="This model does not support predict_stream method.",
    ):
        next(output)
473  
474  
@skip_if_2_6_23_or_older
def test_predict_stream_success(dummy_model):
    """predict_stream yields StreamResponse chunks converted to plain dicts."""
    dspy_model = CoT()
    dspy.settings.configure(lm=dummy_model)

    model_info = mlflow.dspy.log_model(
        dspy_model, name="model", input_example={"question": "what is 2 + 2?"}
    )
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    assert loaded_model._model_meta.flavors["python_function"]["streamable"]
    results = []

    def dummy_streamify(*args, **kwargs):
        # In dspy>=3, `StreamResponse` requires `is_last_chunk` argument.
        # https://github.com/stanfordnlp/dspy/pull/8587
        extra_kwargs = {"is_last_chunk": False} if _DSPY_VERSION.major >= 3 else {}
        yield dspy.streaming.StreamResponse(
            predict_name="prog.predict",
            signature_field_name="answer",
            chunk="2",
            **extra_kwargs,
        )
        extra_kwargs = {"is_last_chunk": True} if _DSPY_VERSION.major >= 3 else {}
        yield dspy.streaming.StreamResponse(
            predict_name="prog.predict",
            signature_field_name=_REASONING_KEYWORD,
            chunk="reason",
            **extra_kwargs,
        )

    # Replace dspy.streamify so the test fully controls what gets streamed.
    with mock.patch("dspy.streamify", return_value=dummy_streamify):
        output = loaded_model.predict_stream({"question": "What is 2 + 2?"})
        for o in output:
            results.append(o)

    # Each StreamResponse should surface as a dict with the same fields.
    assert len(results) == 2
    extra_kwargs = {"is_last_chunk": False} if _DSPY_VERSION.major >= 3 else {}
    assert results[0] == {
        "predict_name": "prog.predict",
        "signature_field_name": "answer",
        "chunk": "2",
        **extra_kwargs,
    }
    extra_kwargs = {"is_last_chunk": True} if _DSPY_VERSION.major >= 3 else {}
    assert results[1] == {
        "predict_name": "prog.predict",
        "signature_field_name": _REASONING_KEYWORD,
        "chunk": "reason",
        **extra_kwargs,
    }
526  
527  
def test_predict_output(dummy_model):
    """predict() returns a plain dict whether forward() returns a dict or a Prediction."""

    class MockModelReturningNonPrediction(dspy.Module):
        def forward(self, question):
            # Return a plain dict instead of dspy.Prediction
            return {"answer": "4", "custom_field": "custom_value"}

    class MockModelReturningPrediction(dspy.Module):
        def forward(self, question):
            # Return a dspy.Prediction
            prediction = dspy.Prediction()
            prediction.answer = "4"
            prediction.custom_field = "custom_value"
            return prediction

    dspy.settings.configure(lm=dummy_model)

    non_prediction_model = MockModelReturningNonPrediction()
    model_info = mlflow.dspy.log_model(non_prediction_model, name="non_prediction_model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = loaded_model.predict("What is 2 + 2?")

    # Dict outputs pass through unchanged.
    assert isinstance(result, dict)
    assert result == {"answer": "4", "custom_field": "custom_value"}

    prediction_model = MockModelReturningPrediction()
    model_info = mlflow.dspy.log_model(prediction_model, name="prediction_model")
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = loaded_model.predict("What is 2 + 2?")

    # Prediction objects are converted to dicts with the same key/value pairs.
    assert isinstance(result, dict)
    assert result == {"answer": "4", "custom_field": "custom_value"}