"""Unit tests for the xAI (Grok) gateway provider (test_xai.py)."""

from unittest import mock

import pytest
from fastapi.encoders import jsonable_encoder

from mlflow.gateway.config import EndpointConfig
from mlflow.gateway.providers.xai import XAIProvider
from mlflow.gateway.schemas import chat

from tests.gateway.tools import MockAsyncResponse, mock_http_client


def _make_provider() -> XAIProvider:
    """Build an ``XAIProvider`` for a ``grok-3`` chat endpoint with a dummy API key."""
    endpoint_config = EndpointConfig(
        name="xai-endpoint",
        endpoint_type="llm/v1/chat",
        model={
            "provider": "xai",
            "name": "grok-3",
            "config": {"api_key": "xai-test-key"},
        },
    )
    return XAIProvider(endpoint_config)


def _chat_response():
    """Return a canned chat-completion body used as the mocked HTTP response."""
    return {
        "id": "chatcmpl-xai-123",
        "object": "chat.completion",
        "created": 1700000000,
        "model": "grok-3",
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        },
        "choices": [
            {
                "message": {"role": "assistant", "content": "Hello from Grok!"},
                "finish_reason": "stop",
                "index": 0,
            }
        ],
        # NOTE(review): "headers" is not a chat-completion field; presumably kept
        # for parity with other gateway provider fixtures — confirm it is needed.
        "headers": {"Content-Type": "application/json"},
    }


def test_default_api_base():
    """Without an explicit base URL, the provider targets the public xAI API."""
    provider = _make_provider()
    assert provider._api_base == "https://api.x.ai/v1"


def test_headers():
    """The configured API key is sent as a bearer token."""
    provider = _make_provider()
    assert provider.headers == {"Authorization": "Bearer xai-test-key"}


def test_name():
    """The provider exposes its human-readable and registry names."""
    provider = _make_provider()
    assert provider.DISPLAY_NAME == "xAI"
    assert provider.get_provider_name() == "xai"


@pytest.mark.asyncio
async def test_chat():
    """A chat call issues one upstream request and maps the reply onto the chat schema."""
    provider = _make_provider()
    mock_client = mock_http_client(MockAsyncResponse(_chat_response()))

    # Replace the real aiohttp session so no network traffic occurs.
    with mock.patch("aiohttp.ClientSession", return_value=mock_client):
        payload = chat.RequestPayload(
            messages=[{"role": "user", "content": "Hello"}],
        )
        response = await provider.chat(payload)

    # Exactly one upstream request should be made per chat call.
    mock_client.post.assert_called_once()

    result = jsonable_encoder(response)
    assert result["id"] == "chatcmpl-xai-123"
    assert result["model"] == "grok-3"

    choice = result["choices"][0]
    assert choice["message"]["role"] == "assistant"
    assert choice["message"]["content"] == "Hello from Grok!"
    assert choice["finish_reason"] == "stop"

    usage = result["usage"]
    assert usage["prompt_tokens"] == 10
    assert usage["completion_tokens"] == 20
    assert usage["total_tokens"] == 30