test_job.py
1 """ 2 Unit tests for the optimize_prompts_job wrapper. 3 4 These tests focus on the helper functions and job function logic without 5 requiring a full job execution infrastructure. 6 """ 7 8 import sys 9 from unittest import mock 10 11 import pytest 12 13 import mlflow 14 from mlflow.exceptions import MlflowException 15 from mlflow.genai.optimize.job import ( 16 OptimizerType, 17 _build_predict_fn, 18 _create_optimizer, 19 _load_scorers, 20 optimize_prompts_job, 21 ) 22 from mlflow.genai.optimize.optimizers import GepaPromptOptimizer, MetaPromptOptimizer 23 from mlflow.genai.scorers import scorer 24 from mlflow.genai.scorers.builtin_scorers import Correctness, Safety 25 from mlflow.protos.prompt_optimization_pb2 import ( 26 OPTIMIZER_TYPE_GEPA, 27 OPTIMIZER_TYPE_METAPROMPT, 28 OPTIMIZER_TYPE_UNSPECIFIED, 29 ) 30 31 32 def test_create_gepa_optimizer_success(): 33 config = {"reflection_model": "openai:/gpt-4o", "max_metric_calls": 50} 34 optimizer = _create_optimizer("gepa", config) 35 assert isinstance(optimizer, GepaPromptOptimizer) 36 assert optimizer.reflection_model == "openai:/gpt-4o" 37 assert optimizer.max_metric_calls == 50 38 39 40 def test_create_gepa_optimizer_case_insensitive(): 41 config = {"reflection_model": "openai:/gpt-4o"} 42 optimizer = _create_optimizer("GEPA", config) 43 assert isinstance(optimizer, GepaPromptOptimizer) 44 45 46 def test_create_gepa_optimizer_missing_reflection_model(): 47 config = {"max_metric_calls": 50} 48 with pytest.raises(MlflowException, match="'reflection_model' must be specified"): 49 _create_optimizer("gepa", config) 50 51 52 def test_create_metaprompt_optimizer_success(): 53 config = {"reflection_model": "openai:/gpt-4o", "guidelines": "Be concise"} 54 optimizer = _create_optimizer("metaprompt", config) 55 assert isinstance(optimizer, MetaPromptOptimizer) 56 57 58 def test_create_metaprompt_optimizer_missing_reflection_model(): 59 config = {"guidelines": "Be concise"} 60 with pytest.raises(MlflowException, match="'reflection_model' must be specified"): 61 _create_optimizer("metaprompt", config) 62 63 64 def test_create_optimizer_unsupported_type(): 65 with pytest.raises(MlflowException, match="Unsupported optimizer type: 'invalid'"): 66 _create_optimizer("invalid", None) 67 68 69 @pytest.mark.parametrize( 70 ("proto_value", "expected_type", "expected_str", "error_match"), 71 [ 72 (OPTIMIZER_TYPE_GEPA, OptimizerType.GEPA, "gepa", None), 73 (OPTIMIZER_TYPE_METAPROMPT, OptimizerType.METAPROMPT, "metaprompt", None), 74 (OPTIMIZER_TYPE_UNSPECIFIED, None, None, "optimizer_type is required"), 75 (999, None, None, "Unsupported optimizer_type value"), 76 ], 77 ) 78 def test_optimizer_type_from_proto(proto_value, expected_type, expected_str, error_match): 79 if error_match: 80 with pytest.raises(MlflowException, match=error_match): 81 OptimizerType.from_proto(proto_value) 82 else: 83 result = OptimizerType.from_proto(proto_value) 84 assert result == expected_type 85 assert result == expected_str 86 87 88 @pytest.mark.parametrize( 89 ("optimizer_type", "expected_proto"), 90 [ 91 (OptimizerType.GEPA, OPTIMIZER_TYPE_GEPA), 92 (OptimizerType.METAPROMPT, OPTIMIZER_TYPE_METAPROMPT), 93 ], 94 ) 95 def test_optimizer_type_to_proto(optimizer_type, expected_proto): 96 assert optimizer_type.to_proto() == expected_proto 97 98 99 def test_load_builtin_scorers(): 100 scorers = _load_scorers(["Correctness", "Safety"], "exp-123") 101 102 assert len(scorers) == 2 103 assert isinstance(scorers[0], Correctness) 104 assert isinstance(scorers[1], Safety) 105 106 107 def 
test_load_custom_scorers(): 108 with ( 109 mock.patch("mlflow.genai.scorers.base.is_databricks_uri", return_value=True), 110 ): 111 experiment_id = mlflow.create_experiment("test_load_custom_scorers") 112 113 @scorer 114 def custom_scorer_1(outputs) -> bool: 115 return len(outputs) > 0 116 117 @scorer 118 def custom_scorer_2(outputs) -> bool: 119 return len(outputs) > 0 120 121 custom_scorer_1.register(experiment_id=experiment_id, name="custom_scorer_1") 122 custom_scorer_2.register(experiment_id=experiment_id, name="custom_scorer_2") 123 124 scorers = _load_scorers(["custom_scorer_1", "custom_scorer_2"], experiment_id) 125 126 assert len(scorers) == 2 127 assert scorers[0].name == "custom_scorer_1" 128 assert scorers[1].name == "custom_scorer_2" 129 130 mlflow.delete_experiment(experiment_id) 131 132 133 def test_load_scorer_not_found_raises_error(): 134 experiment_id = mlflow.create_experiment("test_load_scorer_not_found") 135 136 with pytest.raises(MlflowException, match="Scorer 'unknown_scorer' not found"): 137 _load_scorers(["unknown_scorer"], experiment_id) 138 139 mlflow.delete_experiment(experiment_id) 140 141 142 def test_build_predict_fn_success(): 143 mock_prompt = mock.MagicMock() 144 mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"} 145 mock_prompt.format.return_value = "formatted prompt" 146 147 mock_litellm = mock.MagicMock() 148 mock_response = mock.MagicMock() 149 mock_response.choices = [mock.MagicMock()] 150 mock_response.choices[0].message.content = "response text" 151 mock_litellm.completion.return_value = mock_response 152 153 with ( 154 mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt), 155 mock.patch.dict("sys.modules", {"litellm": mock_litellm}), 156 ): 157 predict_fn = _build_predict_fn("prompts:/test/1") 158 result = predict_fn(question="What is AI?") 159 160 assert result == "response text" 161 mock_litellm.completion.assert_called_once() 162 call_args = mock_litellm.completion.call_args 163 assert call_args.kwargs["model"] == "openai/gpt-4o" 164 mock_prompt.format.assert_called_with(question="What is AI?") 165 166 167 def test_build_predict_fn_missing_model_config(): 168 mock_prompt = mock.MagicMock() 169 mock_prompt.model_config = None 170 171 mock_litellm = mock.MagicMock() 172 173 with ( 174 mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt), 175 mock.patch.dict("sys.modules", {"litellm": mock_litellm}), 176 ): 177 with pytest.raises(MlflowException, match="doesn't have a model configuration"): 178 _build_predict_fn("prompts:/test/1") 179 180 181 def test_build_predict_fn_missing_provider(): 182 mock_prompt = mock.MagicMock() 183 mock_prompt.model_config = {"model_name": "gpt-4o"} 184 185 mock_litellm = mock.MagicMock() 186 187 with ( 188 mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt), 189 mock.patch.dict("sys.modules", {"litellm": mock_litellm}), 190 ): 191 with pytest.raises(MlflowException, match="doesn't have a model configuration"): 192 _build_predict_fn("prompts:/test/1") 193 194 195 def test_build_predict_fn_missing_litellm(): 196 # Simulate litellm not being installed 197 with mock.patch.dict(sys.modules, {"litellm": None}): 198 with pytest.raises( 199 MlflowException, match="'litellm' package is required for prompt optimization" 200 ): 201 _build_predict_fn("prompts:/test/1") 202 203 204 def test_optimize_prompts_job_has_metadata(): 205 assert hasattr(optimize_prompts_job, "_job_fn_metadata") 206 metadata = 
optimize_prompts_job._job_fn_metadata 207 assert metadata.name == "optimize_prompts" 208 assert metadata.max_workers == 2 209 210 211 def test_optimize_prompts_job_calls(): 212 mock_dataset = mock.MagicMock() 213 214 mock_prompt = mock.MagicMock() 215 mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"} 216 217 mock_optimizer = mock.MagicMock() 218 mock_loaded_scorers = [mock.MagicMock(), mock.MagicMock()] 219 mock_predict_fn = mock.MagicMock() 220 221 mock_result = mock.MagicMock() 222 mock_result.optimized_prompts = [mock.MagicMock()] 223 mock_result.optimized_prompts[0].uri = "prompts:/test/2" 224 mock_result.optimizer_name = "GepaPromptOptimizer" 225 mock_result.initial_eval_score = 0.5 226 mock_result.final_eval_score = 0.9 227 228 optimizer_config = {"reflection_model": "openai:/gpt-4o"} 229 230 with ( 231 mock.patch("mlflow.genai.optimize.job.get_dataset", return_value=mock_dataset), 232 mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt), 233 mock.patch( 234 "mlflow.genai.optimize.job._create_optimizer", return_value=mock_optimizer 235 ) as mock_create_optimizer, 236 mock.patch( 237 "mlflow.genai.optimize.job._load_scorers", return_value=mock_loaded_scorers 238 ) as mock_load_scorers, 239 mock.patch( 240 "mlflow.genai.optimize.job._build_predict_fn", return_value=mock_predict_fn 241 ) as mock_build_predict_fn, 242 mock.patch("mlflow.genai.optimize.job.set_experiment"), 243 mock.patch("mlflow.genai.optimize.job.start_run"), 244 mock.patch("mlflow.genai.optimize.job.MlflowClient"), 245 mock.patch( 246 "mlflow.genai.optimize.job.optimize_prompts", return_value=mock_result 247 ) as mock_optimize_prompts, 248 ): 249 optimize_prompts_job( 250 run_id="run-123", 251 experiment_id="exp-123", 252 prompt_uri="prompts:/test/1", 253 dataset_id="dataset-123", 254 optimizer_type="gepa", 255 optimizer_config=optimizer_config, 256 scorer_names=["Correctness", "Safety"], 257 ) 258 259 # Verify _create_optimizer was called with correct args 260 mock_create_optimizer.assert_called_once_with("gepa", optimizer_config) 261 262 # Verify _load_scorers was called with correct args 263 mock_load_scorers.assert_called_once_with(["Correctness", "Safety"], "exp-123") 264 265 # Verify _build_predict_fn was called with correct args 266 mock_build_predict_fn.assert_called_once_with("prompts:/test/1") 267 268 # Verify optimize_prompts was called with correct args 269 mock_optimize_prompts.assert_called_once_with( 270 predict_fn=mock_predict_fn, 271 train_data=mock_dataset, 272 prompt_uris=["prompts:/test/1"], 273 optimizer=mock_optimizer, 274 scorers=mock_loaded_scorers, 275 enable_tracking=True, 276 ) 277 278 279 def test_optimize_prompts_job_result_structure(): 280 mock_dataset = mock.MagicMock() 281 282 mock_prompt = mock.MagicMock() 283 mock_prompt.model_config = {"provider": "openai", "model_name": "gpt-4o"} 284 285 mock_optimizer = mock.MagicMock() 286 mock_result = mock.MagicMock() 287 mock_result.optimized_prompts = [mock.MagicMock()] 288 mock_result.optimized_prompts[0].uri = "prompts:/test/2" 289 mock_result.optimizer_name = "GepaPromptOptimizer" 290 mock_result.initial_eval_score = 0.5 291 mock_result.final_eval_score = 0.9 292 293 optimizer_config = {"reflection_model": "openai:/gpt-4o"} 294 295 with ( 296 mock.patch("mlflow.genai.optimize.job.get_dataset", return_value=mock_dataset), 297 mock.patch("mlflow.genai.optimize.job.load_prompt", return_value=mock_prompt), 298 mock.patch("mlflow.genai.optimize.job._create_optimizer", return_value=mock_optimizer), 
299 mock.patch("mlflow.genai.optimize.job._load_scorers", return_value=[mock.MagicMock()]), 300 mock.patch("mlflow.genai.optimize.job._build_predict_fn", return_value=lambda **k: "r"), 301 mock.patch("mlflow.genai.optimize.job.set_experiment"), 302 mock.patch("mlflow.genai.optimize.job.start_run"), 303 mock.patch("mlflow.genai.optimize.job.MlflowClient"), 304 mock.patch("mlflow.genai.optimize.job.optimize_prompts", return_value=mock_result), 305 ): 306 result = optimize_prompts_job( 307 run_id="run-123", 308 experiment_id="exp-123", 309 prompt_uri="prompts:/test/1", 310 dataset_id="dataset-123", 311 optimizer_type="gepa", 312 optimizer_config=optimizer_config, 313 scorer_names=["Correctness", "Safety"], 314 ) 315 316 # Verify result structure (returned as dict from asdict()) 317 assert result["run_id"] == "run-123" 318 assert result["source_prompt_uri"] == "prompts:/test/1" 319 assert result["optimized_prompt_uri"] == "prompts:/test/2" 320 assert result["optimizer_name"] == "GepaPromptOptimizer" 321 assert result["initial_eval_score"] == 0.5 322 assert result["final_eval_score"] == 0.9 323 assert result["dataset_id"] == "dataset-123" 324 assert result["scorer_names"] == ["Correctness", "Safety"]