testapipipeline.py
"""
Pipeline API module tests
"""

import os
import tempfile
import unittest

# Import the submodule explicitly: a bare "import urllib" does not guarantee urllib.parse is loaded
import urllib.parse

from unittest.mock import patch

from fastapi.testclient import TestClient

from txtai.api import API, application

# pylint: disable=C0411
from utils import Utils

# Path the API configuration is written to and read from (via the CONFIG environment variable)
CONFIG = os.path.join(tempfile.gettempdir(), "testapi.yml")

# Configuration for pipelines
PIPELINES = """
# Image captions
caption:

# Entity extraction
entity:
    path: dslim/bert-base-NER

# Extractor settings
extractor:
    similarity: similarity
    path: llm

# Label settings
labels:
    path: prajjwal1/bert-medium-mnli

# LLM settings
llm:
    path: hf-internal-testing/tiny-random-gpt2
    task: language-generation

# Image objects
objects:

# Text segmentation
segmentation:
    sentences: true

# Enable pipeline similarity backed by zero shot classifier
similarity:

# Summarization
summary:
    path: t5-small

# Tabular
tabular:

# Text extraction
textractor:
    safeopen: /tmp/txtai

# Text to speech
texttospeech:

# Transcription
transcription:

# Translation:
translation:

# Enable file uploads
upload:
"""


# pylint: disable=R0904
class TestPipeline(unittest.TestCase):
    """
    API tests for pipelines.
    """

    @staticmethod
    @patch.dict(os.environ, {"CONFIG": CONFIG, "API_CLASS": "txtai.api.API"})
    def start():
        """
        Starts a mock FastAPI client.

        The CONFIG and API_CLASS environment variables are patched for the duration of this call.

        Returns:
            TestClient bound to a newly created application
        """

        # Write pipeline configuration to the file referenced by the CONFIG environment variable
        with open(CONFIG, "w", encoding="utf-8") as output:
            output.write(PIPELINES)

        # Create new application and set on client
        application.app = application.create()
        client = TestClient(application.app)
        application.start()

        return client

    @classmethod
    def setUpClass(cls):
        """
        Create API client on creation of class.
        """

        cls.client = TestPipeline.start()

        # Short texts shared by the entity and similarity tests
        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Longer passage shared by the summary tests
        cls.text = (
            "Search is the base of many applications. Once data starts to pile up, users want to be able to find it. It's the foundation "
            "of the internet and an ever-growing challenge that is never solved or done. The field of Natural Language Processing (NLP) is "
            "rapidly evolving with a number of new developments. Large-scale general language models are an exciting new capability "
            "allowing us to add amazing functionality quickly with limited compute and people. Innovation continues with new models "
            "and advancements coming in at what seems a weekly basis. This article introduces txtai, an AI-powered search engine "
            "that enables Natural Language Understanding (NLU) based search in any application."
        )

    def testCaption(self):
        """
        Test caption via API
        """

        caption = self.client.get(f"caption?file={Utils.PATH}/books.jpg").json()

        self.assertEqual(caption, "a book shelf filled with books and a stack of books")

    def testCaptionBatch(self):
        """
        Test batch caption via API
        """

        path = Utils.PATH + "/books.jpg"

        captions = self.client.post("batchcaption", json=[path, path]).json()
        self.assertEqual(captions, ["a book shelf filled with books and a stack of books"] * 2)

    def testEntity(self):
        """
        Test entity extraction via API
        """

        # Text contains spaces, commas and an apostrophe, encode it for use in a query string
        entities = self.client.get(f"entity?text={urllib.parse.quote(self.data[1])}").json()
        self.assertEqual([e[0] for e in entities], ["Canada", "Manhattan"])

    def testEntityBatch(self):
        """
        Test batch entity via API
        """

        entities = self.client.post("batchentity", json=[self.data[1]]).json()
        self.assertEqual([e[0] for e in entities[0]], ["Canada", "Manhattan"])

    def testEmpty(self):
        """
        Test empty API configuration
        """

        api = API({})

        # Calls return None when no pipelines are configured
        self.assertIsNone(api.label("test", ["test"]))
        self.assertIsNone(api.pipeline("junk", "test"))

    def testLabel(self):
        """
        Test label via API
        """

        labels = self.client.post("label", json={"text": "this is the best sentence ever", "labels": ["positive", "negative"]}).json()

        self.assertEqual(labels[0]["id"], 0)

    def testLabelBatch(self):
        """
        Test batch label via API
        """

        labels = self.client.post(
            "batchlabel", json={"texts": ["this is the best sentence ever", "This is terrible"], "labels": ["positive", "negative"]}
        ).json()

        # Compare the id of the first result for each input text
        results = [label[0]["id"] for label in labels]
        self.assertEqual(results, [0, 1])

    def testLLM(self):
        """
        Test LLM inference via API
        """

        response = self.client.get("llm?text=test").json()
        self.assertIsInstance(response, str)

    def testLLMBatch(self):
        """
        Test batch LLM inference via API
        """

        response = self.client.post("batchllm", json={"texts": ["test", "test"]}).json()
        self.assertEqual(len(response), 2)

    def testObjects(self):
        """
        Test objects via API
        """

        objects = self.client.get(f"objects?file={Utils.PATH}/books.jpg").json()

        self.assertEqual(objects[0][0], "book")

    def testObjectsBatch(self):
        """
        Test batch objects via API
        """

        path = Utils.PATH + "/books.jpg"

        objects = self.client.post("batchobjects", json=[path, path]).json()
        self.assertEqual([o[0][0] for o in objects], ["book"] * 2)

    def testSegment(self):
        """
        Test segmentation via API
        """

        # Encode the text, raw spaces are not valid in a query string
        query = urllib.parse.quote("This is a test. And another test.")
        segments = self.client.get(f"segment?text={query}").json()

        # Check array length is 2
        self.assertEqual(len(segments), 2)

    def testSegmentBatch(self):
        """
        Test batch segmentation via API
        """

        text = "This is a test. And another test."
        texts = self.client.post("batchsegment", json=[text, text]).json()

        # Check array length is 2 and first element length is 2
        self.assertEqual(len(texts), 2)
        self.assertEqual(len(texts[0]), 2)

    def testSimilarity(self):
        """
        Test similarity via API
        """

        uid = self.client.post("similarity", json={"query": "feel good story", "texts": self.data}).json()[0]["id"]

        self.assertEqual(self.data[uid], self.data[4])

    def testSimilarityBatch(self):
        """
        Test batch similarity via API
        """

        results = self.client.post("batchsimilarity", json={"queries": ["feel good story", "climate change"], "texts": self.data}).json()

        # Compare the id of the first result for each query
        uids = [result[0]["id"] for result in results]
        self.assertEqual(uids, [4, 1])

    def testSummary(self):
        """
        Test summary via API
        """

        summary = self.client.get(f"summary?text={urllib.parse.quote(self.text)}&minlength=15&maxlength=15").json()
        self.assertEqual(summary, "the field of natural language processing (NLP) is rapidly evolving")

    def testSummaryBatch(self):
        """
        Test batch summary via API
        """

        summaries = self.client.post("batchsummary", json={"texts": [self.text, self.text], "minlength": 15, "maxlength": 15}).json()
        self.assertEqual(summaries, ["the field of natural language processing (NLP) is rapidly evolving"] * 2)

    def testTabular(self):
        """
        Test tabular via API
        """

        results = self.client.get(f"tabular?file={Utils.PATH}/tabular.csv").json()

        # Check length of results is as expected
        self.assertEqual(len(results), 6)

    def testTabularBatch(self):
        """
        Test batch tabular via API
        """

        path = Utils.PATH + "/tabular.csv"

        results = self.client.post("batchtabular", json=[path, path]).json()
        self.assertEqual((len(results[0]), len(results[1])), (6, 6))

    def testTextractor(self):
        """
        Test textractor via API
        """

        text = self.client.get(f"textract?file={Utils.PATH}/article.pdf").json()

        # Check length of text is as expected
        self.assertEqual(len(text), 2471)

        # Check invalid URLs and a local path are rejected
        for url in ["http://192.168.1.1/path", "http://127.0.0.1/path", "http://invalid", "/etc/config"]:
            # subTest reports which location failed to raise
            with self.subTest(url=url):
                with self.assertRaises(IOError):
                    self.client.get(f"textract?file={url}").json()

    def testTextractorBatch(self):
        """
        Test batch textractor via API
        """

        path = Utils.PATH + "/article.pdf"

        texts = self.client.post("batchtextract", json=[path, path]).json()
        self.assertEqual((len(texts[0]), len(texts[1])), (2471, 2471))

    def testTextToSpeech(self):
        """
        Test text to speech
        """

        # Generate audio and check for WAV signature
        audio = self.client.get("texttospeech?text=hello&encoding=wav").content
        self.assertEqual(audio[0:4], b"RIFF")

    def testTranscribe(self):
        """
        Test transcribe via API
        """

        text = self.client.get(f"transcribe?file={Utils.PATH}/Make_huge_profits.wav").json()

        # Check transcription matches expected text
        self.assertEqual(text, "Make huge profits without working make up to one hundred thousand dollars a day")

    def testTranscribeBatch(self):
        """
        Test batch transcribe via API
        """

        path = Utils.PATH + "/Make_huge_profits.wav"

        texts = self.client.post("batchtranscribe", json=[path, path]).json()
        self.assertEqual(texts, ["Make huge profits without working make up to one hundred thousand dollars a day"] * 2)

    def testTranslate(self):
        """
        Test translate via API
        """

        translation = self.client.get(f"translate?text={urllib.parse.quote('This is a test translation into Spanish')}&target=es").json()
        self.assertEqual(translation, "Esta es una traducción de prueba al español")

    def testTranslateBatch(self):
        """
        Test batch translate via API
        """

        text = "This is a test translation into Spanish"
        translations = self.client.post("batchtranslate", json={"texts": [text, text], "target": "es"}).json()
        self.assertEqual(translations, ["Esta es una traducción de prueba al español"] * 2)

    def testUpload(self):
        """
        Test file upload
        """

        source = Utils.PATH + "/article.pdf"
        with open(source, "rb") as f:
            # Response is a list, first element is the path of the uploaded file
            uploaded = self.client.post("upload", files={"files": f}).json()[0]

        self.assertTrue(os.path.exists(uploaded))