"""
OpenAI API module tests
"""

import os
import tempfile
import unittest

from unittest.mock import patch

from fastapi.testclient import TestClient

from txtai.api import application

# pylint: disable=C0411
from utils import Utils

# API Configuration
# NOTE(review): YAML indentation reconstructed from a whitespace-mangled source — verify
# nesting (agent/embeddings/llm/workflow) against the original file if available.
CONFIG = """
# Enable OpenAI-compatible API
openai: True

# Allow indexing of documents
writable: True

# Agent configuration
agent:
    hello:
        max_iterations: 1

# Embeddings settings
embeddings:
    path: sentence-transformers/nli-mpnet-base-v2
    content: True

# LLM configuration
llm:
    path: hf-internal-testing/tiny-random-LlamaForCausalLM

# Text segmentation
segmentation:

# Text to speech
texttospeech:

# Transcription
transcription:

# Workflow
workflow:
    echo:
        tasks:
            - task: console
"""


# pylint: disable=R0904
class TestOpenAI(unittest.TestCase):
    """
    Tests for the OpenAI-compatible API endpoint for txtai.
    """

    @staticmethod
    @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testopenai.yml"), "API_CLASS": "txtai.api.API"})
    def start():
        """
        Starts a mock FastAPI client.

        Returns:
            TestClient bound to a freshly created application
        """

        config = os.path.join(tempfile.gettempdir(), "testopenai.yml")

        # Write API configuration read by the application via the CONFIG env variable
        with open(config, "w", encoding="utf-8") as handle:
            handle.write(CONFIG)

        # Create new application and set on client
        application.app = application.create()
        client = TestClient(application.app)
        application.start()

        # Patch LLM to generate a deterministic final answer for the agent
        agent = application.get().agents["hello"]
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'

        return client

    @classmethod
    def setUpClass(cls):
        """
        Create API client on creation of class.
        """

        cls.client = TestOpenAI.start()

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Index data
        cls.client.post("add", json=[{"id": x, "text": row} for x, row in enumerate(cls.data)])
        cls.client.get("index")

    def testChatAgent(self):
        """
        Test a chat completion with an agent
        """

        result = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "hello"}).json()

        self.assertEqual(result["choices"][0]["message"]["content"], "Hi")

    def testChatLLM(self):
        """
        Test a chat completion with a LLM
        """

        result = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "llm"}).json()

        self.assertIsNotNone(result["choices"][0]["message"]["content"])

    def testChatPipeline(self):
        """
        Test a chat completion with a pipeline
        """

        result = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "segmentation"}).json()

        self.assertEqual(result["choices"][0]["message"]["content"], "Hello")

    def testChatSearch(self):
        """
        Test a chat completion with an embeddings search
        """

        result = self.client.post(
            "/v1/chat/completions", json={"messages": [{"role": "user", "content": "feel good story"}], "model": "embeddings"}
        ).json()

        # Best match for "feel good story" is the lottery ticket entry
        self.assertEqual(result["choices"][0]["message"]["content"], self.data[4])

    def testChatStream(self):
        """
        Test a chat completion with a LLM
        """

        result = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "llm", "stream": True})

        # Streaming responses are server-sent events separated by blank lines
        self.assertGreater(len(result.text.split("\n\n")), 0)

    def testChatWorkflow(self):
        """
        Test a chat completion with a workflow
        """

        result = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "echo"}).json()

        self.assertEqual(result["choices"][0]["message"]["content"], "Hello")

    def testEmbeddings(self):
        """
        Test generating embeddings vectors
        """

        result = self.client.post("/v1/embeddings", json={"input": "text to embed", "model": "nli-mpnet-base-v2"}).json()

        # nli-mpnet-base-v2 produces 768-dimension vectors
        self.assertEqual(len(result["data"][0]["embedding"]), 768)

    def testSpeech(self):
        """
        Test generating speech for input text
        """

        audio = self.client.post(
            "/v1/audio/speech", json={"model": "tts", "input": "text to speak", "voice": "default", "response_format": "wav"}
        ).content

        # WAV files start with a RIFF header
        self.assertTrue(audio[0:4] == b"RIFF")

    def testTranscribe(self):
        """
        Test audio to text transcription
        """

        path = Utils.PATH + "/Make_huge_profits.wav"
        with open(path, "rb") as audio:
            text = self.client.post("/v1/audio/transcriptions", files={"file": audio}).json()["text"]
            self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")

    def testTranslate(self):
        """
        Test audio translation
        """

        path = Utils.PATH + "/Make_huge_profits.wav"
        with open(path, "rb") as audio:
            text = self.client.post("/v1/audio/translations", files={"file": audio}).json()["text"]
            self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")