# test_openai_compatibility.py
"""Tests that the MLflow AI Gateway exposes an OpenAI-compatible REST API.

Each test patches the relevant ``OpenAIProvider`` method with a canned
asynchronous response, so no real OpenAI traffic occurs. The tests then talk
to a locally served gateway through the official ``openai`` client to verify
that request routing and response translation are wire-compatible.
"""

from unittest import mock

import openai
import pytest

from mlflow.gateway.providers.openai import OpenAIProvider

from tests.gateway.tools import (
    UvicornGateway,
    save_yaml,
)


@pytest.fixture(scope="module")
def config():
    """Gateway config declaring one endpoint per supported endpoint type.

    The API key is a dummy value ("test") — provider calls are mocked, so it
    is never sent to OpenAI.
    """
    return {
        "endpoints": [
            {
                "name": "chat",
                "endpoint_type": "llm/v1/chat",
                "model": {
                    "name": "gpt-4o-mini",
                    "provider": "openai",
                    "config": {"openai_api_key": "test"},
                },
            },
            {
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "config": {"openai_api_key": "test"},
                },
            },
            {
                "name": "embeddings",
                "endpoint_type": "llm/v1/embeddings",
                "model": {
                    "provider": "openai",
                    "name": "text-embedding-ada-002",
                    "config": {
                        "openai_api_key": "test",
                    },
                },
            },
        ]
    }


@pytest.fixture
def server(config, tmp_path):
    """Serve the gateway config with a real Uvicorn process for each test."""
    conf = tmp_path / "config.yaml"
    save_yaml(conf, config)
    with UvicornGateway(conf) as g:
        yield g


@pytest.fixture
def client(server) -> openai.OpenAI:
    """An official OpenAI client pointed at the local gateway's /v1 routes."""
    return openai.OpenAI(base_url=f"{server.url}/v1", api_key="test")


def test_chat(client):
    """The chat endpoint round-trips an OpenAI chat-completion response."""

    # Provider methods are async; the mock must be a coroutine function.
    async def mock_chat(self, payload):
        return {
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1677858242,
            "model": "gpt-4o-mini",
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "test",
                    },
                    "finish_reason": "stop",
                    "index": 0,
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 7,
                "total_tokens": 20,
            },
        }

    with mock.patch.object(OpenAIProvider, "chat", mock_chat):
        chat = client.chat.completions.create(
            model="chat", messages=[{"role": "user", "content": "hello"}]
        )
    assert chat.choices[0].message.content == "test"


def test_chat_invalid_endpoint(client):
    """Calling a non-chat endpoint via the chat API yields a 400 error."""
    with pytest.raises(openai.BadRequestError, match="is not a chat endpoint"):
        client.chat.completions.create(
            model="completions", messages=[{"role": "user", "content": "hello"}]
        )


def test_completions(client):
    """The completions endpoint round-trips an OpenAI text-completion response."""

    async def mock_completions(self, payload):
        return {
            "id": "cmpl-abc123",
            "object": "text_completion",
            "created": 1677858242,
            "model": "gpt-4",
            "choices": [
                {
                    "finish_reason": "length",
                    "index": 0,
                    "logprobs": None,
                    "text": "test",
                }
            ],
            "usage": {"prompt_tokens": 4, "completion_tokens": 4, "total_tokens": 11},
        }

    with mock.patch.object(OpenAIProvider, "completions", mock_completions):
        completions = client.completions.create(
            model="completions",
            prompt="hello",
        )
    assert completions.choices[0].text == "test"


def test_completions_invalid_endpoint(client):
    """Calling a non-completions endpoint via the completions API yields a 400."""
    with pytest.raises(openai.BadRequestError, match="is not a completions endpoint"):
        client.completions.create(model="chat", prompt="hello")


def test_embeddings(client):
    """The embeddings endpoint round-trips an OpenAI embeddings response."""

    async def mock_embeddings(self, payload):
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [
                        0.1,
                        0.2,
                        0.3,
                    ],
                    "index": 0,
                }
            ],
            "model": "text-embedding-ada-002",
            "usage": {"prompt_tokens": 4, "total_tokens": 4},
        }

    with mock.patch.object(OpenAIProvider, "embeddings", mock_embeddings):
        embeddings = client.embeddings.create(model="embeddings", input="hello")
    assert embeddings.data[0].embedding == [0.1, 0.2, 0.3]


def test_embeddings_invalid_endpoint(client):
    """Calling a non-embeddings endpoint via the embeddings API yields a 400."""
    with pytest.raises(openai.BadRequestError, match="is not an embeddings endpoint"):
        client.embeddings.create(model="chat", input="hello")