# test_utils.py
"""Tests for the mlflow.genai.simulators.utils helpers.

Covers conversation-history formatting, routing of model invocations
(provider-prefixed URIs vs. Databricks endpoints), and selection of the
default simulation model.
"""

from unittest import mock

import pytest

from mlflow.genai.simulators.utils import (
    format_history,
    get_default_simulation_model,
    invoke_model_without_tracing,
)

# Hoisted from three duplicated function-local imports below; mlflow.types
# is a core mlflow module, so importing it at module scope is safe.
from mlflow.types.llm import ChatMessage


@pytest.mark.parametrize(
    ("history", "expected"),
    [
        # Empty history yields None, not an empty string.
        ([], None),
        ([{"role": "user", "content": "Hello"}], "user: Hello"),
        (
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you?"},
            ],
            "user: Hello\nassistant: Hi there!\nuser: How are you?",
        ),
        # Missing/None role falls back to "unknown"; missing/None content -> "".
        ([{"content": "Hello"}], "unknown: Hello"),
        ([{"role": "user"}], "user: "),
        ([{"role": None, "content": None}], "unknown: "),
    ],
)
def test_format_history(history, expected):
    """format_history renders one "role: content" line per message."""
    assert format_history(history) == expected


@pytest.mark.parametrize(
    "model_uri",
    [
        "openai:/gpt-4o-mini",
        "anthropic:/claude-3-haiku",
    ],
)
def test_invoke_model_without_tracing_with_provider(model_uri):
    """Provider-prefixed URIs are routed through ScorerLLMClient.complete."""
    messages = [ChatMessage(role="user", content="Hello")]

    with mock.patch(
        "mlflow.genai.scorers.llm_backend.ScorerLLMClient.complete", return_value="Hi there!"
    ) as mock_invoke:
        result = invoke_model_without_tracing(model_uri=model_uri, messages=messages)

    assert result == "Hi there!"
    mock_invoke.assert_called_once()


def test_invoke_model_without_tracing_with_inference_params():
    """Extra inference params are forwarded as keyword args to complete()."""
    messages = [ChatMessage(role="user", content="Hello")]

    with mock.patch(
        "mlflow.genai.scorers.llm_backend.ScorerLLMClient.complete", return_value="Response"
    ) as mock_invoke:
        invoke_model_without_tracing(
            model_uri="openai:/gpt-4o-mini",
            messages=messages,
            inference_params={"temperature": 0.5},
        )

    # Messages are serialized to plain dicts, and temperature is passed through.
    mock_invoke.assert_called_once_with(
        [{"role": "user", "content": "Hello"}],
        response_format=None,
        num_retries=3,
        temperature=0.5,
    )


@pytest.mark.parametrize("model_uri", ["databricks", "gpt-oss-120b"])
def test_invoke_model_without_tracing_with_databricks(model_uri):
    """Databricks-style URIs bypass ScorerLLMClient and use call_chat_completions."""
    messages = [ChatMessage(role="user", content="Hello")]

    with (
        mock.patch("mlflow.genai.simulators.utils.call_chat_completions") as mock_call,
        mock.patch(
            "mlflow.genai.simulators.utils._create_message_from_databricks_response"
        ) as mock_create,
    ):
        mock_call.return_value = mock.MagicMock(error_code=None, output_json='{"content": "Hi"}')
        mock_create.return_value = mock.MagicMock(content="Hi from Databricks")

        result = invoke_model_without_tracing(model_uri=model_uri, messages=messages)

    assert result == "Hi from Databricks"
    mock_call.assert_called_once()
    # Verify the raw response actually went through the Databricks parsing helper
    # (previously patched but never asserted).
    mock_create.assert_called_once()


def test_get_default_simulation_model_non_databricks():
    """Outside Databricks, the default simulation model is OpenAI's gpt-5."""
    with mock.patch("mlflow.genai.simulators.utils.is_databricks_uri", return_value=False):
        model = get_default_simulation_model()
        assert model == "openai:/gpt-5"


def test_get_default_simulation_model_databricks():
    """On Databricks, the default simulation model is the hosted gpt-oss-120b."""
    with mock.patch("mlflow.genai.simulators.utils.is_databricks_uri", return_value=True):
        model = get_default_simulation_model()
        assert model == "gpt-oss-120b"