/ test / python / testapi / testopenai.py
testopenai.py
  1  """
  2  OpenAI API module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  from fastapi.testclient import TestClient
 12  
 13  from txtai.api import application
 14  
 15  # pylint: disable=C0411
 16  from utils import Utils
 17  
# API configuration in YAML form. TestOpenAI.start() writes this text to a temporary
# "testopenai.yml" file and exposes that path through the CONFIG environment variable.
# The top-level keys (agent "hello", embeddings, llm, segmentation, texttospeech,
# transcription, workflow "echo") are the names the tests below pass as "model".
CONFIG = """
# Enable OpenAI-compatible API
openai: True

# Allow indexing of documents
writable: True

# Agent configuration
agent:
    hello:
        max_iterations: 1

# Embeddings settings
embeddings:
    path: sentence-transformers/nli-mpnet-base-v2
    content: True        

# LLM configuration
llm:
    path: hf-internal-testing/tiny-random-LlamaForCausalLM

# Text segmentation
segmentation:

# Text to speech
texttospeech:

# Transcription
transcription:

# Workflow
workflow:
    echo:
        tasks:
            - task: console
"""
 55  
 56  
 57  # pylint: disable=R0904
 58  class TestOpenAI(unittest.TestCase):
 59      """
 60      Tests for OpenAI-compatible API endpoint for txtai.
 61      """
 62  
 63      @staticmethod
 64      @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testopenai.yml"), "API_CLASS": "txtai.api.API"})
 65      def start():
 66          """
 67          Starts a mock FastAPI client.
 68          """
 69  
 70          config = os.path.join(tempfile.gettempdir(), "testopenai.yml")
 71  
 72          with open(config, "w", encoding="utf-8") as output:
 73              output.write(CONFIG)
 74  
 75          # Create new application and set on client
 76          application.app = application.create()
 77          client = TestClient(application.app)
 78          application.start()
 79  
 80          # Patch LLM to generate answer
 81          agent = application.get().agents["hello"]
 82          agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
 83  
 84          return client
 85  
 86      @classmethod
 87      def setUpClass(cls):
 88          """
 89          Create API client on creation of class.
 90          """
 91  
 92          cls.client = TestOpenAI.start()
 93  
 94          cls.data = [
 95              "US tops 5 million confirmed virus cases",
 96              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 97              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 98              "The National Park Service warns against sacrificing slower friends in a bear attack",
 99              "Maine man wins $1M from $25 lottery ticket",
100              "Make huge profits without work, earn up to $100,000 a day",
101          ]
102  
103          # Index data
104          cls.client.post("add", json=[{"id": x, "text": row} for x, row in enumerate(cls.data)])
105          cls.client.get("index")
106  
107      def testChatAgent(self):
108          """
109          Test a chat completion with an agent
110          """
111  
112          response = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "hello"}).json()
113  
114          self.assertEqual(response["choices"][0]["message"]["content"], "Hi")
115  
116      def testChatLLM(self):
117          """
118          Test a chat completion with a LLM
119          """
120  
121          response = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "llm"}).json()
122  
123          self.assertIsNotNone(response["choices"][0]["message"]["content"])
124  
125      def testChatPipeline(self):
126          """
127          Test a chat completion with a pipeline
128          """
129  
130          response = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "segmentation"}).json()
131  
132          self.assertEqual(response["choices"][0]["message"]["content"], "Hello")
133  
134      def testChatSearch(self):
135          """
136          Test a chat completion with an embeddings search
137          """
138  
139          response = self.client.post(
140              "/v1/chat/completions", json={"messages": [{"role": "user", "content": "feel good story"}], "model": "embeddings"}
141          ).json()
142  
143          self.assertEqual(response["choices"][0]["message"]["content"], self.data[4])
144  
145      def testChatStream(self):
146          """
147          Test a chat completion with a LLM
148          """
149  
150          response = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "llm", "stream": True})
151  
152          self.assertGreater(len(response.text.split("\n\n")), 0)
153  
154      def testChatWorkflow(self):
155          """
156          Test a chat completion with a workflow
157          """
158  
159          response = self.client.post("/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}], "model": "echo"}).json()
160  
161          self.assertEqual(response["choices"][0]["message"]["content"], "Hello")
162  
163      def testEmbeddings(self):
164          """
165          Test generating embeddings vectors
166          """
167  
168          response = self.client.post("/v1/embeddings", json={"input": "text to embed", "model": "nli-mpnet-base-v2"}).json()
169  
170          self.assertEqual(len(response["data"][0]["embedding"]), 768)
171  
172      def testSpeech(self):
173          """
174          Test generating speech for input text
175          """
176  
177          response = self.client.post(
178              "/v1/audio/speech", json={"model": "tts", "input": "text to speak", "voice": "default", "response_format": "wav"}
179          ).content
180  
181          self.assertTrue(response[0:4] == b"RIFF")
182  
183      def testTranscribe(self):
184          """
185          Test audio to text transcription
186          """
187  
188          path = Utils.PATH + "/Make_huge_profits.wav"
189          with open(path, "rb") as f:
190              text = self.client.post("/v1/audio/transcriptions", files={"file": f}).json()["text"]
191              self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")
192  
193      def testTranslate(self):
194          """
195          Test audio translation
196          """
197  
198          path = Utils.PATH + "/Make_huge_profits.wav"
199          with open(path, "rb") as f:
200              text = self.client.post("/v1/audio/translations", files={"file": f}).json()["text"]
201              self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")