/ tests / gateway / providers / test_xai.py
test_xai.py
 1  from unittest import mock
 2  
 3  import pytest
 4  from fastapi.encoders import jsonable_encoder
 5  
 6  from mlflow.gateway.config import EndpointConfig
 7  from mlflow.gateway.providers.xai import XAIProvider
 8  from mlflow.gateway.schemas import chat
 9  
10  from tests.gateway.tools import MockAsyncResponse, mock_http_client
11  
12  
def _make_provider() -> XAIProvider:
    """Build an ``XAIProvider`` for a chat endpoint serving ``grok-3``.

    The endpoint config carries a dummy API key (``xai-test-key``) and no
    explicit API base, so tests can inspect the provider's defaults and drive
    it against mocked HTTP without real credentials.
    """
    model_spec = {
        "provider": "xai",
        "name": "grok-3",
        "config": {"api_key": "xai-test-key"},
    }
    config = EndpointConfig(
        name="xai-endpoint",
        endpoint_type="llm/v1/chat",
        model=model_spec,
    )
    return XAIProvider(config)
24  
25  
26  def _chat_response():
27      return {
28          "id": "chatcmpl-xai-123",
29          "object": "chat.completion",
30          "created": 1700000000,
31          "model": "grok-3",
32          "usage": {
33              "prompt_tokens": 10,
34              "completion_tokens": 20,
35              "total_tokens": 30,
36          },
37          "choices": [
38              {
39                  "message": {"role": "assistant", "content": "Hello from Grok!"},
40                  "finish_reason": "stop",
41                  "index": 0,
42              }
43          ],
44          "headers": {"Content-Type": "application/json"},
45      }
46  
47  
def test_default_api_base():
    """Without an explicit api_base in the config, the provider targets xAI's v1 API."""
    expected_base = "https://api.x.ai/v1"
    assert _make_provider()._api_base == expected_base
51  
52  
def test_headers():
    """The configured API key is sent as a bearer token and nothing else."""
    expected = {"Authorization": "Bearer xai-test-key"}
    assert _make_provider().headers == expected
56  
57  
def test_name():
    """Check both the display name and the provider identifier string."""
    xai = _make_provider()
    assert xai.DISPLAY_NAME == "xAI"
    assert xai.get_provider_name() == "xai"
62  
63  
@pytest.mark.asyncio
async def test_chat():
    """A chat request is translated into the gateway's chat response schema.

    ``aiohttp.ClientSession`` is patched to return a mock client that replays
    ``_chat_response()``. The assertions cover every field the canned response
    supplies (identity, model, message, finish reason, token usage), so a
    regression in any part of the response mapping is caught.
    """
    provider = _make_provider()
    mock_client = mock_http_client(MockAsyncResponse(_chat_response()))
    payload = chat.RequestPayload(
        messages=[{"role": "user", "content": "Hello"}],
    )

    # Keep the patch scoped to the call that actually performs HTTP I/O.
    with mock.patch("aiohttp.ClientSession", return_value=mock_client):
        response = await provider.chat(payload)

    result = jsonable_encoder(response)
    assert result["id"] == "chatcmpl-xai-123"
    assert result["object"] == "chat.completion"
    assert result["created"] == 1700000000
    assert result["model"] == "grok-3"

    choice = result["choices"][0]
    assert choice["message"]["role"] == "assistant"
    assert choice["message"]["content"] == "Hello from Grok!"
    assert choice["finish_reason"] == "stop"

    # Token counts must be passed through unchanged from the provider response.
    usage = result["usage"]
    assert usage["prompt_tokens"] == 10
    assert usage["completion_tokens"] == 20
    assert usage["total_tokens"] == 30