"""
Unit tests for the optimize_prompts_job wrapper.

These tests focus on the helper functions and job function logic without
requiring a full job execution infrastructure.
"""

import sys
from unittest import mock

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.genai.optimize.job import (
    OptimizerType,
    _build_predict_fn,
    _create_optimizer,
    _load_scorers,
    optimize_prompts_job,
)
from mlflow.genai.optimize.optimizers import GepaPromptOptimizer, MetaPromptOptimizer
from mlflow.genai.scorers import scorer
from mlflow.genai.scorers.builtin_scorers import Correctness, Safety
from mlflow.protos.prompt_optimization_pb2 import (
    OPTIMIZER_TYPE_GEPA,
    OPTIMIZER_TYPE_METAPROMPT,
    OPTIMIZER_TYPE_UNSPECIFIED,
)


def test_create_gepa_optimizer_success():
    config = {"reflection_model": "openai:/gpt-4o", "max_metric_calls": 50}
    optimizer = _create_optimizer("gepa", config)
    assert isinstance(optimizer, GepaPromptOptimizer)
    assert optimizer.reflection_model == "openai:/gpt-4o"
    assert optimizer.max_metric_calls == 50


def test_create_gepa_optimizer_case_insensitive():
    config = {"reflection_model": "openai:/gpt-4o"}
    optimizer = _create_optimizer("GEPA", config)
    assert isinstance(optimizer, GepaPromptOptimizer)


def test_create_gepa_optimizer_missing_reflection_model():
    config = {"max_metric_calls": 50}
    with pytest.raises(MlflowException, match="'reflection_model' must be specified"):
        _create_optimizer("gepa", config)


def test_create_metaprompt_optimizer_success():
    config = {"reflection_model": "openai:/gpt-4o", "guidelines": "Be concise"}
    optimizer = _create_optimizer("metaprompt", config)
    assert isinstance(optimizer, MetaPromptOptimizer)


def test_create_metaprompt_optimizer_missing_reflection_model():
    config = {"guidelines": "Be concise"}
    with pytest.raises(MlflowException, match="'reflection_model' must be specified"):
        _create_optimizer("metaprompt", config)


def test_create_optimizer_unsupported_type():
    with pytest.raises(MlflowException, match="Unsupported optimizer type: 'invalid'"):
        _create_optimizer("invalid", None)


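# For context, the tests above exercise _create_optimizer as a factory that maps a
# case-insensitive optimizer name to an optimizer class and forwards the config dict as
# keyword arguments. A minimal sketch of the behavior these assertions imply (names and
# structure here are assumptions used only for illustration, not the actual implementation):
#
#     def _create_optimizer(optimizer_type, config):
#         config = dict(config or {})
#         if optimizer_type.lower() not in ("gepa", "metaprompt"):
#             raise MlflowException(f"Unsupported optimizer type: '{optimizer_type}'")
#         if "reflection_model" not in config:
#             raise MlflowException("'reflection_model' must be specified in the optimizer config")
#         cls = GepaPromptOptimizer if optimizer_type.lower() == "gepa" else MetaPromptOptimizer
#         return cls(**config)

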
@pytest.mark.parametrize(
    ("proto_value", "expected_type", "expected_str", "error_match"),
    [
        (OPTIMIZER_TYPE_GEPA, OptimizerType.GEPA, "gepa", None),
        (OPTIMIZER_TYPE_METAPROMPT, OptimizerType.METAPROMPT, "metaprompt", None),
        (OPTIMIZER_TYPE_UNSPECIFIED, None, None, "optimizer_type is required"),
        (999, None, None, "Unsupported optimizer_type value"),
    ],
)
def test_optimizer_type_from_proto(proto_value, expected_type, expected_str, error_match):
    if error_match:
        with pytest.raises(MlflowException, match=error_match):
            OptimizerType.from_proto(proto_value)
    else:
        result = OptimizerType.from_proto(proto_value)
        assert result == expected_type
        assert result == expected_str


@pytest.mark.parametrize(
    ("optimizer_type", "expected_proto"),
    [
        (OptimizerType.GEPA, OPTIMIZER_TYPE_GEPA),
        (OptimizerType.METAPROMPT, OPTIMIZER_TYPE_METAPROMPT),
    ],
)
def test_optimizer_type_to_proto(optimizer_type, expected_proto):
    assert optimizer_type.to_proto() == expected_proto


def test_load_builtin_scorers():
    scorers = _load_scorers(["Correctness", "Safety"], "exp-123")

    assert len(scorers) == 2
    assert isinstance(scorers[0], Correctness)
    assert isinstance(scorers[1], Safety)


def test_load_custom_scorers():
    with (
        mock.patch("mlflow.genai.scorers.base.is_databricks_uri", return_value=True),
    ):
        experiment_id = mlflow.create_experiment("test_load_custom_scorers")

        @scorer
        def custom_scorer_1(outputs) -> bool:
            return len(outputs) > 0

        @scorer
        def custom_scorer_2(outputs) -> bool:
            return len(outputs) > 0

        custom_scorer_1.register(experiment_id=experiment_id, name="custom_scorer_1")
        custom_scorer_2.register(experiment_id=experiment_id, name="custom_scorer_2")

        scorers = _load_scorers(["custom_scorer_1", "custom_scorer_2"], experiment_id)

        assert len(scorers) == 2
        assert scorers[0].name == "custom_scorer_1"
        assert scorers[1].name == "custom_scorer_2"

        mlflow.delete_experiment(experiment_id)


def test_load_scorer_not_found_raises_error():
    experiment_id = mlflow.create_experiment("test_load_scorer_not_found")

    with pytest.raises(MlflowException, match="Scorer 'unknown_scorer' not found"):
        _load_scorers(["unknown_scorer"], experiment_id)

    mlflow.delete_experiment(experiment_id)


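# The scorer-loading tests above imply two resolution paths in _load_scorers: built-in names
# such as "Correctness" and "Safety" resolve to the builtin scorer classes, while any other
# name is looked up among scorers registered to the given experiment, and unknown names raise
# MlflowException. A rough sketch of that lookup; `BUILTIN_SCORERS` and
# `list_registered_scorers` are hypothetical helpers used only for illustration:
#
#     def _load_scorers(scorer_names, experiment_id):
#         registered = {s.name: s for s in list_registered_scorers(experiment_id)}
#         loaded = []
#         for name in scorer_names:
#             if name in BUILTIN_SCORERS:  # e.g. {"Correctness": Correctness, "Safety": Safety}
#                 loaded.append(BUILTIN_SCORERS[name]())
#             elif name in registered:
#                 loaded.append(registered[name])
#             else:
#                 raise MlflowException(f"Scorer '{name}' not found")
#         return loaded

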
def test_build_predict_fn_success():
    mock_prompt = mock.MagicMock()
    mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"}
    mock_prompt.format.return_value = "formatted prompt"

    mock_litellm = mock.MagicMock()
    mock_response = mock.MagicMock()
    mock_response.choices = [mock.MagicMock()]
    mock_response.choices[0].message.content = "response text"
    mock_litellm.completion.return_value = mock_response

    with (
        mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt),
        mock.patch.dict("sys.modules", {"litellm": mock_litellm}),
    ):
        predict_fn = _build_predict_fn("prompts:/test/1")
        result = predict_fn(question="What is AI?")

        assert result == "response text"
        mock_litellm.completion.assert_called_once()
        call_args = mock_litellm.completion.call_args
        assert call_args.kwargs["model"] == "openai/gpt-4o"
        mock_prompt.format.assert_called_with(question="What is AI?")


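# The assertions above pin down the shape of the closure returned by _build_predict_fn: the
# prompt's model_config yields a "<provider>/<model_name>" string for litellm, the prompt is
# formatted with the caller's keyword arguments, and the first completion choice's message
# content is returned. A minimal sketch of such a closure (an illustration of what the tests
# expect; the message payload below is an assumption, not the actual implementation):
#
#     def _build_predict_fn(prompt_uri):
#         import litellm
#
#         prompt = load_prompt(prompt_uri)
#         model = f"{prompt.model_config['provider']}/{prompt.model_config['model_name']}"
#
#         def predict_fn(**inputs):
#             response = litellm.completion(
#                 model=model,
#                 messages=[{"role": "user", "content": prompt.format(**inputs)}],
#             )
#             return response.choices[0].message.content
#
#         return predict_fn

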
def test_build_predict_fn_missing_model_config():
    mock_prompt = mock.MagicMock()
    mock_prompt.model_config = None

    mock_litellm = mock.MagicMock()

    with (
        mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt),
        mock.patch.dict("sys.modules", {"litellm": mock_litellm}),
    ):
        with pytest.raises(MlflowException, match="doesn't have a model configuration"):
            _build_predict_fn("prompts:/test/1")


def test_build_predict_fn_missing_provider():
    mock_prompt = mock.MagicMock()
    mock_prompt.model_config = {"model_name": "gpt-4o"}

    mock_litellm = mock.MagicMock()

    with (
        mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt),
        mock.patch.dict("sys.modules", {"litellm": mock_litellm}),
    ):
        with pytest.raises(MlflowException, match="doesn't have a model configuration"):
            _build_predict_fn("prompts:/test/1")


def test_build_predict_fn_missing_litellm():
    # Simulate litellm not being installed
    with mock.patch.dict(sys.modules, {"litellm": None}):
        with pytest.raises(
            MlflowException, match="'litellm' package is required for prompt optimization"
        ):
            _build_predict_fn("prompts:/test/1")


def test_optimize_prompts_job_has_metadata():
    assert hasattr(optimize_prompts_job, "_job_fn_metadata")
    metadata = optimize_prompts_job._job_fn_metadata
    assert metadata.name == "optimize_prompts"
    assert metadata.max_workers == 2


def test_optimize_prompts_job_calls():
    mock_dataset = mock.MagicMock()

    mock_prompt = mock.MagicMock()
    mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"}

    mock_optimizer = mock.MagicMock()
    mock_loaded_scorers = [mock.MagicMock(), mock.MagicMock()]
    mock_predict_fn = mock.MagicMock()

    mock_result = mock.MagicMock()
    mock_result.optimized_prompts = [mock.MagicMock()]
    mock_result.optimized_prompts[0].uri = "prompts:/test/2"
    mock_result.optimizer_name = "GepaPromptOptimizer"
    mock_result.initial_eval_score = 0.5
    mock_result.final_eval_score = 0.9

    optimizer_config = {"reflection_model": "openai:/gpt-4o"}

    with (
        mock.patch("mlflow.genai.optimize.job.get_dataset", return_value=mock_dataset),
        mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt),
        mock.patch(
            "mlflow.genai.optimize.job._create_optimizer", return_value=mock_optimizer
        ) as mock_create_optimizer,
        mock.patch(
            "mlflow.genai.optimize.job._load_scorers", return_value=mock_loaded_scorers
        ) as mock_load_scorers,
        mock.patch(
            "mlflow.genai.optimize.job._build_predict_fn", return_value=mock_predict_fn
        ) as mock_build_predict_fn,
        mock.patch("mlflow.genai.optimize.job.set_experiment"),
        mock.patch("mlflow.genai.optimize.job.start_run"),
        mock.patch("mlflow.genai.optimize.job.MlflowClient"),
        mock.patch(
            "mlflow.genai.optimize.job.optimize_prompts", return_value=mock_result
        ) as mock_optimize_prompts,
    ):
        optimize_prompts_job(
            run_id="run-123",
            experiment_id="exp-123",
            prompt_uri="prompts:/test/1",
            dataset_id="dataset-123",
            optimizer_type="gepa",
            optimizer_config=optimizer_config,
            scorer_names=["Correctness", "Safety"],
        )

        # Verify _create_optimizer was called with correct args
        mock_create_optimizer.assert_called_once_with("gepa", optimizer_config)

        # Verify _load_scorers was called with correct args
        mock_load_scorers.assert_called_once_with(["Correctness", "Safety"], "exp-123")

        # Verify _build_predict_fn was called with correct args
        mock_build_predict_fn.assert_called_once_with("prompts:/test/1")

        # Verify optimize_prompts was called with correct args
        mock_optimize_prompts.assert_called_once_with(
            predict_fn=mock_predict_fn,
            train_data=mock_dataset,
            prompt_uris=["prompts:/test/1"],
            optimizer=mock_optimizer,
            scorers=mock_loaded_scorers,
            enable_tracking=True,
        )


def test_optimize_prompts_job_result_structure():
    mock_dataset = mock.MagicMock()

    mock_prompt = mock.MagicMock()
    mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"}

    mock_optimizer = mock.MagicMock()
    mock_result = mock.MagicMock()
    mock_result.optimized_prompts = [mock.MagicMock()]
    mock_result.optimized_prompts[0].uri = "prompts:/test/2"
    mock_result.optimizer_name = "GepaPromptOptimizer"
    mock_result.initial_eval_score = 0.5
    mock_result.final_eval_score = 0.9

    optimizer_config = {"reflection_model": "openai:/gpt-4o"}

    with (
        mock.patch("mlflow.genai.optimize.job.get_dataset", return_value=mock_dataset),
        mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt),
        mock.patch("mlflow.genai.optimize.job._create_optimizer", return_value=mock_optimizer),
        mock.patch("mlflow.genai.optimize.job._load_scorers", return_value=[mock.MagicMock()]),
        mock.patch("mlflow.genai.optimize.job._build_predict_fn", return_value=lambda **k: "r"),
        mock.patch("mlflow.genai.optimize.job.set_experiment"),
        mock.patch("mlflow.genai.optimize.job.start_run"),
        mock.patch("mlflow.genai.optimize.job.MlflowClient"),
        mock.patch("mlflow.genai.optimize.job.optimize_prompts", return_value=mock_result),
    ):
        result = optimize_prompts_job(
            run_id="run-123",
            experiment_id="exp-123",
            prompt_uri="prompts:/test/1",
            dataset_id="dataset-123",
            optimizer_type="gepa",
            optimizer_config=optimizer_config,
            scorer_names=["Correctness", "Safety"],
        )

        # Verify result structure (returned as dict from asdict())
        assert result["run_id"] == "run-123"
        assert result["source_prompt_uri"] == "prompts:/test/1"
        assert result["optimized_prompt_uri"] == "prompts:/test/2"
        assert result["optimizer_name"] == "GepaPromptOptimizer"
        assert result["initial_eval_score"] == 0.5
        assert result["final_eval_score"] == 0.9
        assert result["dataset_id"] == "dataset-123"
        assert result["scorer_names"] == ["Correctness", "Safety"]