tests/genai/optimize/test_optimize.py
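"""Tests for mlflow.genai.optimize.optimize_prompts.

Covers single- and multi-prompt optimization, the eval_fn behavior exposed to
optimizers, custom scorers, list / DataFrame / managed evaluation datasets,
validation errors, chat-prompt rejection, and preservation of prompt model configs.
"""
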
from typing import Any

import pandas as pd
import pytest

import mlflow
from mlflow.entities.model_registry import PromptModelConfig
from mlflow.exceptions import MlflowException
from mlflow.genai.datasets import create_dataset
from mlflow.genai.optimize.optimize import optimize_prompts
from mlflow.genai.optimize.optimizers.base import BasePromptOptimizer
from mlflow.genai.optimize.types import EvaluationResultRecord, PromptOptimizerOutput
from mlflow.genai.prompts import register_prompt
from mlflow.genai.scorers import scorer
from mlflow.models.model import PromptVersion
from mlflow.utils.import_hooks import _post_import_hooks


class MockPromptOptimizer(BasePromptOptimizer):
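    """Minimal stand-in optimizer: prefixes each prompt and reports fixed eval scores."""
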
    def __init__(self, reflection_model="openai:/gpt-4o-mini"):
        self.model_name = reflection_model

    def optimize(
        self,
        eval_fn: Any,
        train_data: list[dict[str, Any]],
        target_prompts: dict[str, str],
        enable_tracking: bool = True,
    ) -> PromptOptimizerOutput:
        optimized_prompts = {}
        for prompt_name, template in target_prompts.items():
            # Simple optimization: add "Be precise and accurate. " prefix
            optimized_prompts[prompt_name] = f"Be precise and accurate. {template}"

        # Verify the optimization by calling eval_fn (only if provided)
        if eval_fn is not None:
            eval_fn(optimized_prompts, train_data)

        return PromptOptimizerOutput(
            optimized_prompts=optimized_prompts,
            initial_eval_score=0.5,
            final_eval_score=0.9,
        )


@pytest.fixture
def sample_translation_prompt() -> PromptVersion:
    return register_prompt(
        name="test_translation_prompt",
        template="Translate the following text to {{language}}: {{input_text}}",
    )


@pytest.fixture
def sample_summarization_prompt() -> PromptVersion:
    return register_prompt(
        name="test_summarization_prompt",
        template="Summarize this text: {{text}}",
    )


@pytest.fixture
def sample_dataset() -> pd.DataFrame:
    return pd.DataFrame({
        "inputs": [
            {"input_text": "Hello", "language": "Spanish"},
            {"input_text": "World", "language": "French"},
            {"input_text": "Goodbye", "language": "Spanish"},
        ],
        "outputs": [
            "Hola",
            "Monde",
            "Adiós",
        ],
    })


@pytest.fixture
def sample_summarization_dataset() -> list[dict[str, Any]]:
    return [
        {
            "inputs": {
                "text": "This is a long document that needs to be summarized into key points."
            },
            "outputs": "Key points summary",
        },
        {
            "inputs": {"text": "Another document with important information for summarization."},
            "outputs": "Important info summary",
        },
    ]


def sample_predict_fn(input_text: str, language: str) -> str:
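    """Load the registered prompt to record its use and return a canned translation."""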
    mlflow.genai.load_prompt("prompts:/test_translation_prompt/1")
    translations = {
        ("Hello", "Spanish"): "Hola",
        ("World", "French"): "Monde",
        ("Goodbye", "Spanish"): "Adiós",
    }

    # Verify that autologging is enabled during the evaluation.
    assert len(_post_import_hooks) > 0
    return translations.get((input_text, language), f"translated_{input_text}")


def sample_summarization_fn(text: str) -> str:
    return f"Summary of: {text[:20]}..."


@scorer(name="equivalence")
def equivalence(outputs, expectations):
    return 1.0 if outputs == expectations["expected_response"] else 0.0


def test_optimize_prompts_single_prompt(
    sample_translation_prompt: PromptVersion, sample_dataset: pd.DataFrame
):
    mock_optimizer = MockPromptOptimizer()

    result = optimize_prompts(
        predict_fn=sample_predict_fn,
        train_data=sample_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        optimizer=mock_optimizer,
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 1
    optimized_prompt = result.optimized_prompts[0]
    assert optimized_prompt.name == sample_translation_prompt.name
    assert optimized_prompt.version == sample_translation_prompt.version + 1
    assert "Be precise and accurate." in optimized_prompt.template
    expected_template = "Translate the following text to {{language}}: {{input_text}}"
    assert expected_template in optimized_prompt.template
    assert result.initial_eval_score == 0.5
    assert result.final_eval_score == 0.9


def test_optimize_prompts_multiple_prompts(
    sample_translation_prompt: PromptVersion,
    sample_summarization_prompt: PromptVersion,
    sample_dataset: pd.DataFrame,
):
    mock_optimizer = MockPromptOptimizer()

    result = optimize_prompts(
        predict_fn=sample_predict_fn,
        train_data=sample_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}",
            f"prompts:/{sample_summarization_prompt.name}/{sample_summarization_prompt.version}",
        ],
        optimizer=mock_optimizer,
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 2
    prompt_names = {prompt.name for prompt in result.optimized_prompts}
    assert sample_translation_prompt.name in prompt_names
    assert sample_summarization_prompt.name in prompt_names
    assert result.initial_eval_score == 0.5
    assert result.final_eval_score == 0.9

    for prompt in result.optimized_prompts:
        assert "Be precise and accurate." in prompt.template


def test_optimize_prompts_eval_function_behavior(
    sample_translation_prompt: PromptVersion, sample_dataset: pd.DataFrame
):
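    """The eval_fn passed to the optimizer should apply candidate prompts and return
    one scored EvaluationResultRecord per training record."""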
    class TestingOptimizer(BasePromptOptimizer):
        def __init__(self):
            self.model_name = "openai:/gpt-4o-mini"
            self.eval_fn_calls = []

        def optimize(self, eval_fn, dataset, target_prompts, enable_tracking=True):
            # Test that eval_fn works correctly with a candidate prompt
            test_prompts = {
                "test_translation_prompt": "Prompt Candidate: "
                "Translate {{input_text}} to {{language}}"
            }
            results = eval_fn(test_prompts, dataset)
            self.eval_fn_calls.append((test_prompts, results))

            # Verify results structure
            assert isinstance(results, list)
            assert len(results) == len(dataset)
            for i, result in enumerate(results):
                assert isinstance(result, EvaluationResultRecord)
                assert result.inputs == dataset[i]["inputs"]
                assert result.outputs == dataset[i]["outputs"]
                assert result.score == 1
                assert result.trace is not None

            return PromptOptimizerOutput(optimized_prompts=target_prompts)

    predict_called_count = 0

    def predict_fn(input_text, language):
        prompt = mlflow.genai.load_prompt("prompts:/test_translation_prompt/1").format(
            input_text=input_text, language=language
        )
        nonlocal predict_called_count
        # The first call to predict_fn is the model check
        if predict_called_count > 0:
            # Validate that the prompt was replaced with the candidate prompt
            assert "Prompt Candidate" in prompt
        predict_called_count += 1

        return sample_predict_fn(input_text=input_text, language=language)

    testing_optimizer = TestingOptimizer()

    optimize_prompts(
        predict_fn=predict_fn,
        train_data=sample_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        optimizer=testing_optimizer,
        scorers=[equivalence],
    )

    assert len(testing_optimizer.eval_fn_calls) == 1
    _, eval_results = testing_optimizer.eval_fn_calls[0]
    assert len(eval_results) == 3  # Number of records in sample_dataset
    assert predict_called_count == 4  # 3 records in sample_dataset + 1 for the prediction check


def test_optimize_prompts_with_list_dataset(
    sample_translation_prompt: PromptVersion, sample_summarization_dataset: list[dict[str, Any]]
):
    mock_optimizer = MockPromptOptimizer()

    def summarization_predict_fn(text):
        return f"Summary: {text[:10]}..."

    result = optimize_prompts(
        predict_fn=summarization_predict_fn,
        train_data=sample_summarization_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        optimizer=mock_optimizer,
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 1
    assert result.initial_eval_score == 0.5
    assert result.final_eval_score == 0.9


def test_optimize_prompts_with_model_name(
    sample_translation_prompt: PromptVersion, sample_dataset: pd.DataFrame
):
    class TestOptimizer(BasePromptOptimizer):
        def __init__(self):
            self.model_name = "test/custom-model"

        def optimize(self, eval_fn, dataset, target_prompts, enable_tracking=True):
            return PromptOptimizerOutput(optimized_prompts=target_prompts)

    testing_optimizer = TestOptimizer()

    result = optimize_prompts(
        predict_fn=sample_predict_fn,
        train_data=sample_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        optimizer=testing_optimizer,
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 1


def test_optimize_prompts_warns_on_unused_prompt(
    sample_translation_prompt: PromptVersion,
    sample_summarization_prompt: PromptVersion,
    sample_dataset: pd.DataFrame,
    capsys,
):
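    """A warning should be emitted for registered prompts that predict_fn never loads."""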
    mock_optimizer = MockPromptOptimizer()

    # Create a predict_fn that uses only the translation prompt, not the summarization prompt
    def predict_fn_single_prompt(input_text, language):
        prompt = mlflow.genai.load_prompt("prompts:/test_translation_prompt/1")
        prompt.format(input_text=input_text, language=language)
        return sample_predict_fn(input_text=input_text, language=language)

    result = optimize_prompts(
        predict_fn=predict_fn_single_prompt,
        train_data=sample_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}",
            f"prompts:/{sample_summarization_prompt.name}/{sample_summarization_prompt.version}",
        ],
        optimizer=mock_optimizer,
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 2

    captured = capsys.readouterr()
    assert "prompts were not used during evaluation" in captured.err
    assert "test_summarization_prompt" in captured.err


def test_optimize_prompts_with_custom_scorers(
    sample_translation_prompt: PromptVersion, sample_dataset: pd.DataFrame
):
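    """Scores from a user-provided scorer should be surfaced to the optimizer's eval_fn."""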
    # Create a custom scorer for case-insensitive matching
    @scorer(name="case_insensitive_match")
    def case_insensitive_match(outputs, expectations):
        # Extract expected_response if expectations is a dict
        if isinstance(expectations, dict) and "expected_response" in expectations:
            expected_value = expectations["expected_response"]
        else:
            expected_value = expectations
        return 1.0 if str(outputs).lower() == str(expected_value).lower() else 0.5

    class MetricTestOptimizer(BasePromptOptimizer):
        def __init__(self):
            self.model_name = "openai:/gpt-4o-mini"
            self.captured_scores = []

        def optimize(self, eval_fn, dataset, target_prompts, enable_tracking=True):
            # Run eval_fn and capture the scores
            results = eval_fn(target_prompts, dataset)
            self.captured_scores = [r.score for r in results]
            return PromptOptimizerOutput(optimized_prompts=target_prompts)

    testing_optimizer = MetricTestOptimizer()

    # Create a dataset with outputs that exercise the custom scorer
    test_dataset = pd.DataFrame({
        "inputs": [
            {"input_text": "Hello", "language": "Spanish"},
            {"input_text": "World", "language": "French"},
        ],
        "outputs": ["HOLA", "monde"],  # Different cases to test the custom scorer
    })

    def predict_fn(input_text, language):
        mlflow.genai.load_prompt("prompts:/test_translation_prompt/1")
        # Return lowercase outputs
        return {"Hello": "hola", "World": "monde"}.get(input_text, "unknown")

    result = optimize_prompts(
        predict_fn=predict_fn,
        train_data=test_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        scorers=[case_insensitive_match],
        optimizer=testing_optimizer,
    )

    # Verify the custom scorer was used:
    # "hola" vs "HOLA" (case-insensitive match) -> 1.0
    # "monde" vs "monde" (exact match) -> 1.0
    assert testing_optimizer.captured_scores == [1.0, 1.0]
    assert len(result.optimized_prompts) == 1


@pytest.mark.parametrize(
    ("train_data", "error_match"),
    [
        # Missing inputs validation (handled by _convert_eval_set_to_df)
        ([{"outputs": "Hola"}], "Either `inputs` or `trace` column is required"),
        # Empty inputs validation
        (
            [{"inputs": {}, "outputs": "Hola"}],
            "Record 0 is missing required 'inputs' field or it is empty",
        ),
    ],
)
def test_optimize_prompts_validation_errors(
    sample_translation_prompt: PromptVersion,
    train_data: list[dict[str, Any]],
    error_match: str,
):
    with pytest.raises(MlflowException, match=error_match):
        optimize_prompts(
            predict_fn=sample_predict_fn,
            train_data=train_data,
            prompt_uris=[
                f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
            ],
            optimizer=MockPromptOptimizer(),
            scorers=[equivalence],
        )


def test_optimize_prompts_with_chat_prompt(
    sample_translation_prompt: PromptVersion, sample_dataset: pd.DataFrame
):
    chat_prompt = register_prompt(
        name="test_chat_prompt",
        template=[{"role": "user", "content": "{{input_text}}"}],
    )
    with pytest.raises(MlflowException, match="Only text prompts can be optimized"):
        optimize_prompts(
            predict_fn=sample_predict_fn,
            train_data=sample_dataset,
            prompt_uris=[f"prompts:/{chat_prompt.name}/{chat_prompt.version}"],
            optimizer=MockPromptOptimizer(),
            scorers=[equivalence],
        )


def test_optimize_prompts_with_managed_evaluation_dataset(
    sample_translation_prompt: PromptVersion,
    sample_dataset: pd.DataFrame,
):
    # Create a `ManagedEvaluationDataset` and populate it with records from sample_dataset
    managed_dataset = create_dataset(name="test_optimize_managed_dataset")
    managed_dataset.merge_records(sample_dataset)

    result = optimize_prompts(
        predict_fn=sample_predict_fn,
        train_data=managed_dataset,
        prompt_uris=[
            f"prompts:/{sample_translation_prompt.name}/{sample_translation_prompt.version}"
        ],
        optimizer=MockPromptOptimizer(),
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 1
    assert result.initial_eval_score == 0.5
    assert result.final_eval_score == 0.9


def test_optimize_prompts_preserves_model_config(sample_dataset: pd.DataFrame):
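    """The optimized prompt version should retain the source prompt's model config."""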
    source_model_config = PromptModelConfig(
        provider="openai",
        model_name="gpt-4o",
        temperature=0.7,
        max_tokens=1000,
    )

    prompt_with_config = register_prompt(
        name="test_prompt_with_model_config",
        template="Translate the following text to {{language}}: {{input_text}}",
        model_config=source_model_config,
    )

    assert prompt_with_config.model_config is not None

    def predict_fn(input_text: str, language: str) -> str:
        mlflow.genai.load_prompt(f"prompts:/{prompt_with_config.name}/1")
        translations = {
            ("Hello", "Spanish"): "Hola",
            ("World", "French"): "Monde",
            ("Goodbye", "Spanish"): "Adiós",
        }
        return translations.get((input_text, language), f"translated_{input_text}")

    result = optimize_prompts(
        predict_fn=predict_fn,
        train_data=sample_dataset,
        prompt_uris=[f"prompts:/{prompt_with_config.name}/{prompt_with_config.version}"],
        optimizer=MockPromptOptimizer(),
        scorers=[equivalence],
    )

    assert len(result.optimized_prompts) == 1
    optimized_prompt = result.optimized_prompts[0]

    assert optimized_prompt.model_config["provider"] == "openai"
    assert optimized_prompt.model_config["model_name"] == "gpt-4o"
    assert optimized_prompt.model_config["temperature"] == 0.7
    assert optimized_prompt.model_config["max_tokens"] == 1000