testextension.py
"""
Tests for API extension modules.
"""

import os
import tempfile
import unittest

from unittest.mock import patch

from fastapi import APIRouter
from fastapi.testclient import TestClient

from txtai.api import application, Extension
from txtai.pipeline import Pipeline

# Configuration written to the test config file: registers the sample pipeline below
PIPELINES = """
testapi.testextension.SamplePipeline:
"""


class SampleRouter:
    """
    Router exposing a single GET route backed by the sample pipeline.
    """

    router = APIRouter()

    # Decorators apply bottom-up: the route is registered on the router first,
    # then the function is wrapped as a static method.
    @staticmethod
    @router.get("/sample")
    def sample(text: str):
        """
        Runs the sample pipeline on the supplied text.

        Args:
            text: input text

        Returns:
            result of the sample pipeline for text
        """

        api = application.get()
        return api.pipeline("testapi.testextension.SamplePipeline", (text,))


class SampleExtension(Extension):
    """
    Extension that adds the sample router to an application.
    """

    def __call__(self, app):
        routes = SampleRouter().router
        app.include_router(routes)


class SamplePipeline(Pipeline):
    """
    Pipeline that lowercases its input.
    """

    def __call__(self, text):
        return text.lower()


class TestExtension(unittest.TestCase):
    """
    Extension tests run through the API.
    """

    # CONFIG, API_CLASS and EXTENSIONS are patched into os.environ only for the duration of start()
    @staticmethod
    @patch.dict(
        os.environ,
        {
            "CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"),
            "API_CLASS": "txtai.api.API",
            "EXTENSIONS": "testapi.testextension.SampleExtension",
        },
    )
    def start():
        """
        Writes the test configuration, builds a new application and starts it.

        Returns:
            TestClient bound to the new application
        """

        # Same path as the patched CONFIG environment variable
        path = os.path.join(tempfile.gettempdir(), "testapi.yml")

        with open(path, "w", encoding="utf-8") as handle:
            handle.write(PIPELINES)

        # Build a new application, attach a client to it, then run startup
        application.app = application.create()
        client = TestClient(application.app)
        application.start()

        return client

    @classmethod
    def setUpClass(cls):
        """
        Creates the API client shared by all tests in this class.
        """

        cls.client = TestExtension.start()

    def testEmpty(self):
        """
        Tests that calling a base Extension returns None.
        """

        result = Extension()(None)
        self.assertIsNone(result)

    def testExtension(self):
        """
        Tests the sample route added by the extension.
        """

        response = self.client.get("sample?text=Test%20String")
        self.assertEqual(response.json(), "test string")