# tests/gateway/test_openai_compatibility.py
  1  from unittest import mock
  2  
  3  import openai
  4  import pytest
  5  
  6  from mlflow.gateway.providers.openai import OpenAIProvider
  7  
  8  from tests.gateway.tools import (
  9      UvicornGateway,
 10      save_yaml,
 11  )
 12  
 13  
 14  @pytest.fixture(scope="module")
 15  def config():
 16      return {
 17          "endpoints": [
 18              {
 19                  "name": "chat",
 20                  "endpoint_type": "llm/v1/chat",
 21                  "model": {
 22                      "name": "gpt-4o-mini",
 23                      "provider": "openai",
 24                      "config": {"openai_api_key": "test"},
 25                  },
 26              },
 27              {
 28                  "name": "completions",
 29                  "endpoint_type": "llm/v1/completions",
 30                  "model": {
 31                      "name": "gpt-4",
 32                      "provider": "openai",
 33                      "config": {"openai_api_key": "test"},
 34                  },
 35              },
 36              {
 37                  "name": "embeddings",
 38                  "endpoint_type": "llm/v1/embeddings",
 39                  "model": {
 40                      "provider": "openai",
 41                      "name": "text-embedding-ada-002",
 42                      "config": {
 43                          "openai_api_key": "test",
 44                      },
 45                  },
 46              },
 47          ]
 48      }
 49  
 50  
 51  @pytest.fixture
 52  def server(config, tmp_path):
 53      conf = tmp_path / "config.yaml"
 54      save_yaml(conf, config)
 55      with UvicornGateway(conf) as g:
 56          yield g
 57  
 58  
 59  @pytest.fixture
 60  def client(server) -> openai.OpenAI:
 61      return openai.OpenAI(base_url=f"{server.url}/v1", api_key="test")
 62  
 63  
 64  def test_chat(client):
 65      async def mock_chat(self, payload):
 66          return {
 67              "id": "chatcmpl-abc123",
 68              "object": "chat.completion",
 69              "created": 1677858242,
 70              "model": "gpt-4o-mini",
 71              "choices": [
 72                  {
 73                      "message": {
 74                          "role": "assistant",
 75                          "content": "test",
 76                      },
 77                      "finish_reason": "stop",
 78                      "index": 0,
 79                  }
 80              ],
 81              "usage": {
 82                  "prompt_tokens": 13,
 83                  "completion_tokens": 7,
 84                  "total_tokens": 20,
 85              },
 86          }
 87  
 88      with mock.patch.object(OpenAIProvider, "chat", mock_chat):
 89          chat = client.chat.completions.create(
 90              model="chat", messages=[{"role": "user", "content": "hello"}]
 91          )
 92          assert chat.choices[0].message.content == "test"
 93  
 94  
 95  def test_chat_invalid_endpoint(client):
 96      with pytest.raises(openai.BadRequestError, match="is not a chat endpoint"):
 97          client.chat.completions.create(
 98              model="completions", messages=[{"role": "user", "content": "hello"}]
 99          )
100  
101  
def test_completions(client):
    """A completions request via the OpenAI client is routed to the completions endpoint."""

    async def mock_completions(self, payload):
        # Provider is patched below, so no real OpenAI call is made.
        return {
            "id": "cmpl-abc123",
            "object": "text_completion",
            "created": 1677858242,
            "model": "gpt-4",
            "choices": [
                {
                    "finish_reason": "length",
                    "index": 0,
                    "logprobs": None,
                    "text": "test",
                }
            ],
            # total_tokens must equal prompt_tokens + completion_tokens (was 11).
            "usage": {"prompt_tokens": 4, "completion_tokens": 4, "total_tokens": 8},
        }

    with mock.patch.object(OpenAIProvider, "completions", mock_completions):
        completions = client.completions.create(
            model="completions",
            prompt="hello",
        )
        assert completions.choices[0].text == "test"
126  
127  
def test_completions_invalid_endpoint(client):
    """A completions request against a non-completions endpoint is rejected with 400."""
    with pytest.raises(openai.BadRequestError, match="is not a completions endpoint"):
        client.completions.create(prompt="hello", model="chat")
131  
132  
def test_embeddings(client):
    """An embeddings request via the OpenAI client is routed to the embeddings endpoint."""
    vector = [0.1, 0.2, 0.3]

    async def mock_embeddings(self, payload):
        # Provider is patched below, so no real OpenAI call is made.
        return {
            "object": "list",
            "data": [{"object": "embedding", "embedding": vector, "index": 0}],
            "model": "text-embedding-ada-002",
            "usage": {"prompt_tokens": 4, "total_tokens": 4},
        }

    with mock.patch.object(OpenAIProvider, "embeddings", mock_embeddings):
        response = client.embeddings.create(model="embeddings", input="hello")
        assert response.data[0].embedding == vector
155  
156  
def test_embeddings_invalid_endpoint(client):
    """An embeddings request against a non-embeddings endpoint is rejected with 400."""
    with pytest.raises(openai.BadRequestError, match="is not an embeddings endpoint"):
        client.embeddings.create(input="hello", model="chat")