/ test / python / testapi / testapipipeline.py
testapipipeline.py
  1  """
  2  Pipeline API module tests
  3  """
  4  
import os
import tempfile
import unittest
import urllib
import urllib.parse

from unittest.mock import patch

from fastapi.testclient import TestClient

from txtai.api import API, application

# pylint: disable=C0411
from utils import Utils
 18  
# YAML configuration for the pipelines exercised by TestPipeline.
# TestPipeline.start() writes this text verbatim to the temporary API config file
# (testapi.yml in the system temp directory) before creating the application.
# Each top-level key configures one API component; the comment above each key names it.
# Keys with no nested settings presumably use that component's defaults - TODO confirm against txtai.api.
PIPELINES = """
# Image captions
caption:

# Entity extraction
entity:
    path: dslim/bert-base-NER

# Extractor settings
extractor:
    similarity: similarity
    path: llm

# Label settings
labels:
    path: prajjwal1/bert-medium-mnli

# LLM settings
llm:
    path: hf-internal-testing/tiny-random-gpt2
    task: language-generation

# Image objects
objects:

# Text segmentation
segmentation:
    sentences: true

# Enable pipeline similarity backed by zero shot classifier
similarity:

# Summarization
summary:
    path: t5-small

# Tabular
tabular:

# Text extraction
textractor:
    safeopen: /tmp/txtai

# Text to speech
texttospeech:

# Transcription
transcription:

# Translation:
translation:

# Enable file uploads
upload:
"""
 75  
 76  
 77  # pylint: disable=R0904
 78  class TestPipeline(unittest.TestCase):
 79      """
 80      API tests for pipelines.
 81      """
 82  
 83      @staticmethod
 84      @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"), "API_CLASS": "txtai.api.API"})
 85      def start():
 86          """
 87          Starts a mock FastAPI client.
 88          """
 89  
 90          config = os.path.join(tempfile.gettempdir(), "testapi.yml")
 91  
 92          with open(config, "w", encoding="utf-8") as output:
 93              output.write(PIPELINES)
 94  
 95          # Create new application and set on client
 96          application.app = application.create()
 97          client = TestClient(application.app)
 98          application.start()
 99  
100          return client
101  
102      @classmethod
103      def setUpClass(cls):
104          """
105          Create API client on creation of class.
106          """
107  
108          cls.client = TestPipeline.start()
109  
110          cls.data = [
111              "US tops 5 million confirmed virus cases",
112              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
113              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
114              "The National Park Service warns against sacrificing slower friends in a bear attack",
115              "Maine man wins $1M from $25 lottery ticket",
116              "Make huge profits without work, earn up to $100,000 a day",
117          ]
118  
119          cls.text = (
120              "Search is the base of many applications. Once data starts to pile up, users want to be able to find it. It's the foundation "
121              "of the internet and an ever-growing challenge that is never solved or done. The field of Natural Language Processing (NLP) is "
122              "rapidly evolving with a number of new developments. Large-scale general language models are an exciting new capability "
123              "allowing us to add amazing functionality quickly with limited compute and people. Innovation continues with new models "
124              "and advancements coming in at what seems a weekly basis. This article introduces txtai, an AI-powered search engine "
125              "that enables Natural Language Understanding (NLU) based search in any application."
126          )
127  
128      def testCaption(self):
129          """
130          Test caption via API
131          """
132  
133          caption = self.client.get(f"caption?file={Utils.PATH}/books.jpg").json()
134  
135          self.assertEqual(caption, "a book shelf filled with books and a stack of books")
136  
137      def testCaptionBatch(self):
138          """
139          Test batch caption via API
140          """
141  
142          path = Utils.PATH + "/books.jpg"
143  
144          captions = self.client.post("batchcaption", json=[path, path]).json()
145          self.assertEqual(captions, ["a book shelf filled with books and a stack of books"] * 2)
146  
147      def testEntity(self):
148          """
149          Test entity extraction via API
150          """
151  
152          entities = self.client.get(f"entity?text={self.data[1]}").json()
153          self.assertEqual([e[0] for e in entities], ["Canada", "Manhattan"])
154  
155      def testEntityBatch(self):
156          """
157          Test batch entity via API
158          """
159  
160          entities = self.client.post("batchentity", json=[self.data[1]]).json()
161          self.assertEqual([e[0] for e in entities[0]], ["Canada", "Manhattan"])
162  
163      def testEmpty(self):
164          """
165          Test empty API configuration
166          """
167  
168          api = API({})
169  
170          self.assertIsNone(api.label("test", ["test"]))
171          self.assertIsNone(api.pipeline("junk", "test"))
172  
173      def testLabel(self):
174          """
175          Test label via API
176          """
177  
178          labels = self.client.post("label", json={"text": "this is the best sentence ever", "labels": ["positive", "negative"]}).json()
179  
180          self.assertEqual(labels[0]["id"], 0)
181  
182      def testLabelBatch(self):
183          """
184          Test batch label via API
185          """
186  
187          labels = self.client.post(
188              "batchlabel", json={"texts": ["this is the best sentence ever", "This is terrible"], "labels": ["positive", "negative"]}
189          ).json()
190  
191          results = [l[0]["id"] for l in labels]
192          self.assertEqual(results, [0, 1])
193  
194      def testLLM(self):
195          """
196          Test LLM inference via API
197          """
198  
199          response = self.client.get("llm?text=test").json()
200          self.assertIsInstance(response, str)
201  
202      def testLLMBatch(self):
203          """
204          Test batch LLM inference via API
205          """
206  
207          response = self.client.post("batchllm", json={"texts": ["test", "test"]}).json()
208          self.assertEqual(len(response), 2)
209  
210      def testObjects(self):
211          """
212          Test objects via API
213          """
214  
215          objects = self.client.get(f"objects?file={Utils.PATH}/books.jpg").json()
216  
217          self.assertEqual(objects[0][0], "book")
218  
219      def testObjectsBatch(self):
220          """
221          Test batch objects via API
222          """
223  
224          path = Utils.PATH + "/books.jpg"
225  
226          objects = self.client.post("batchobjects", json=[path, path]).json()
227          self.assertEqual([o[0][0] for o in objects], ["book"] * 2)
228  
229      def testSegment(self):
230          """
231          Test segmentation via API
232          """
233  
234          text = self.client.get("segment?text=This is a test. And another test.").json()
235  
236          # Check array length is 2
237          self.assertEqual(len(text), 2)
238  
239      def testSegmentBatch(self):
240          """
241          Test batch segmentation via API
242          """
243  
244          text = "This is a test. And another test."
245          texts = self.client.post("batchsegment", json=[text, text]).json()
246  
247          # Check array length is 2 and first element length is 2
248          self.assertEqual(len(texts), 2)
249          self.assertEqual(len(texts[0]), 2)
250  
251      def testSimilarity(self):
252          """
253          Test similarity via API
254          """
255  
256          uid = self.client.post("similarity", json={"query": "feel good story", "texts": self.data}).json()[0]["id"]
257  
258          self.assertEqual(self.data[uid], self.data[4])
259  
260      def testSimilarityBatch(self):
261          """
262          Test batch similarity via API
263          """
264  
265          results = self.client.post("batchsimilarity", json={"queries": ["feel good story", "climate change"], "texts": self.data}).json()
266  
267          uids = [result[0]["id"] for result in results]
268          self.assertEqual(uids, [4, 1])
269  
270      def testSummary(self):
271          """
272          Test summary via API
273          """
274  
275          summary = self.client.get(f"summary?text={urllib.parse.quote(self.text)}&minlength=15&maxlength=15").json()
276          self.assertEqual(summary, "the field of natural language processing (NLP) is rapidly evolving")
277  
278      def testSummaryBatch(self):
279          """
280          Test batch summary via API
281          """
282  
283          summaries = self.client.post("batchsummary", json={"texts": [self.text, self.text], "minlength": 15, "maxlength": 15}).json()
284          self.assertEqual(summaries, ["the field of natural language processing (NLP) is rapidly evolving"] * 2)
285  
286      def testTabular(self):
287          """
288          Test tabular via API
289          """
290  
291          results = self.client.get(f"tabular?file={Utils.PATH}/tabular.csv").json()
292  
293          # Check length of results is as expected
294          self.assertEqual(len(results), 6)
295  
296      def testTabularBatch(self):
297          """
298          Test batch tabular via API
299          """
300  
301          path = Utils.PATH + "/tabular.csv"
302  
303          results = self.client.post("batchtabular", json=[path, path]).json()
304          self.assertEqual((len(results[0]), len(results[1])), (6, 6))
305  
306      def testTextractor(self):
307          """
308          Test textractor via API
309          """
310  
311          text = self.client.get(f"textract?file={Utils.PATH}/article.pdf").json()
312  
313          # Check length of text is as expected
314          self.assertEqual(len(text), 2471)
315  
316          # Check invalid URLs
317          for url in ["http://192.168.1.1/path", "http://127.0.0.1/path", "http://invalid", "/etc/config"]:
318              with self.assertRaises(IOError):
319                  self.client.get(f"textract?file={url}").json()
320  
321      def testTextractorBatch(self):
322          """
323          Test batch textractor via API
324          """
325  
326          path = Utils.PATH + "/article.pdf"
327  
328          texts = self.client.post("batchtextract", json=[path, path]).json()
329          self.assertEqual((len(texts[0]), len(texts[1])), (2471, 2471))
330  
331      def testTextToSpeech(self):
332          """
333          Test text to speech
334          """
335  
336          # Generate audio and check for WAV signature
337          audio = self.client.get("texttospeech?text=hello&encoding=wav").content
338          self.assertTrue(audio[0:4] == b"RIFF")
339  
340      def testTranscribe(self):
341          """
342          Test transcribe via API
343          """
344  
345          text = self.client.get(f"transcribe?file={Utils.PATH}/Make_huge_profits.wav").json()
346  
347          # Check length of text is as expected
348          self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")
349  
350      def testTranscribeBatch(self):
351          """
352          Test batch transcribe via API
353          """
354  
355          path = Utils.PATH + "/Make_huge_profits.wav"
356  
357          texts = self.client.post("batchtranscribe", json=[path, path]).json()
358          self.assertEqual(texts, ["Make huge profits without working make up to one hundred thousand dollars a day"] * 2)
359  
360      def testTranslate(self):
361          """
362          Test translate via API
363          """
364  
365          translation = self.client.get(f"translate?text={urllib.parse.quote('This is a test translation into Spanish')}&target=es").json()
366          self.assertEqual(translation, "Esta es una traducción de prueba al español")
367  
368      def testTranslateBatch(self):
369          """
370          Test batch translate via API
371          """
372  
373          text = "This is a test translation into Spanish"
374          translations = self.client.post("batchtranslate", json={"texts": [text, text], "target": "es"}).json()
375          self.assertEqual(translations, ["Esta es una traducción de prueba al español"] * 2)
376  
377      def testUpload(self):
378          """
379          Test file upload
380          """
381  
382          path = Utils.PATH + "/article.pdf"
383          with open(path, "rb") as f:
384              path = self.client.post("upload", files={"files": f}).json()[0]
385              self.assertTrue(os.path.exists(path))