tests/genai/simulators/test_utils.py
test_utils.py
  1  from unittest import mock
  2  
  3  import pytest
  4  
  5  from mlflow.genai.simulators.utils import (
  6      format_history,
  7      get_default_simulation_model,
  8      invoke_model_without_tracing,
  9  )
 10  
 11  
 12  @pytest.mark.parametrize(
 13      ("history", "expected"),
 14      [
 15          ([], None),
 16          ([{"role": "user", "content": "Hello"}], "user: Hello"),
 17          (
 18              [
 19                  {"role": "user", "content": "Hello"},
 20                  {"role": "assistant", "content": "Hi there!"},
 21                  {"role": "user", "content": "How are you?"},
 22              ],
 23              "user: Hello\nassistant: Hi there!\nuser: How are you?",
 24          ),
 25          ([{"content": "Hello"}], "unknown: Hello"),
 26          ([{"role": "user"}], "user: "),
 27          ([{"role": None, "content": None}], "unknown: "),
 28      ],
 29  )
 30  def test_format_history(history, expected):
 31      assert format_history(history) == expected
 32  
 33  
 34  @pytest.mark.parametrize(
 35      "model_uri",
 36      [
 37          "openai:/gpt-4o-mini",
 38          "anthropic:/claude-3-haiku",
 39      ],
 40  )
 41  def test_invoke_model_without_tracing_with_provider(model_uri):
 42      from mlflow.types.llm import ChatMessage
 43  
 44      messages = [ChatMessage(role="user", content="Hello")]
 45  
 46      with mock.patch(
 47          "mlflow.genai.scorers.llm_backend.ScorerLLMClient.complete", return_value="Hi there!"
 48      ) as mock_invoke:
 49          result = invoke_model_without_tracing(model_uri=model_uri, messages=messages)
 50  
 51          assert result == "Hi there!"
 52          mock_invoke.assert_called_once()
 53  
 54  
 55  def test_invoke_model_without_tracing_with_inference_params():
 56      from mlflow.types.llm import ChatMessage
 57  
 58      messages = [ChatMessage(role="user", content="Hello")]
 59  
 60      with mock.patch(
 61          "mlflow.genai.scorers.llm_backend.ScorerLLMClient.complete", return_value="Response"
 62      ) as mock_invoke:
 63          invoke_model_without_tracing(
 64              model_uri="openai:/gpt-4o-mini",
 65              messages=messages,
 66              inference_params={"temperature": 0.5},
 67          )
 68  
 69          mock_invoke.assert_called_once_with(
 70              [{"role": "user", "content": "Hello"}],
 71              response_format=None,
 72              num_retries=3,
 73              temperature=0.5,
 74          )
 75  
 76  
 77  @pytest.mark.parametrize("model_uri", ["databricks", "gpt-oss-120b"])
 78  def test_invoke_model_without_tracing_with_databricks(model_uri):
 79      from mlflow.types.llm import ChatMessage
 80  
 81      messages = [ChatMessage(role="user", content="Hello")]
 82  
 83      with (
 84          mock.patch("mlflow.genai.simulators.utils.call_chat_completions") as mock_call,
 85          mock.patch(
 86              "mlflow.genai.simulators.utils._create_message_from_databricks_response"
 87          ) as mock_create,
 88      ):
 89          mock_call.return_value = mock.MagicMock(error_code=None, output_json='{"content": "Hi"}')
 90          mock_create.return_value = mock.MagicMock(content="Hi from Databricks")
 91  
 92          result = invoke_model_without_tracing(model_uri=model_uri, messages=messages)
 93  
 94          assert result == "Hi from Databricks"
 95          mock_call.assert_called_once()
 96  
 97  
 98  def test_get_default_simulation_model_non_databricks():
 99      with mock.patch("mlflow.genai.simulators.utils.is_databricks_uri", return_value=False):
100          model = get_default_simulation_model()
101          assert model == "openai:/gpt-5"
102  
103  
def test_get_default_simulation_model_databricks():
    """On a Databricks tracking URI the default simulation model is gpt-oss-120b."""
    patcher = mock.patch(
        "mlflow.genai.simulators.utils.is_databricks_uri", return_value=True
    )
    with patcher:
        assert get_default_simulation_model() == "gpt-oss-120b"