/ test / python / testapi / testextension.py
testextension.py
  1  """
  2  Extension module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  from fastapi import APIRouter
 12  from fastapi.testclient import TestClient
 13  
 14  from txtai.api import application, Extension
 15  from txtai.pipeline import Pipeline
 16  
 17  # Example pipeline extension
 18  PIPELINES = """
 19  testapi.testextension.SamplePipeline:
 20  """
 21  
 22  
 23  class SampleRouter:
 24      """
 25      Sample API router.
 26      """
 27  
 28      router = APIRouter()
 29  
 30      @staticmethod
 31      @router.get("/sample")
 32      def sample(text: str):
 33          """
 34          Calls sample pipeline.
 35  
 36          Args:
 37              text: input text
 38  
 39          Returns:
 40              formatted text
 41          """
 42  
 43          return application.get().pipeline("testapi.testextension.SamplePipeline", (text,))
 44  
 45  
 46  class SampleExtension(Extension):
 47      """
 48      Sample API extension.
 49      """
 50  
 51      def __call__(self, app):
 52          app.include_router(SampleRouter().router)
 53  
 54  
 55  class SamplePipeline(Pipeline):
 56      """
 57      Sample pipeline.
 58      """
 59  
 60      def __call__(self, text):
 61          return text.lower()
 62  
 63  
 64  class TestExtension(unittest.TestCase):
 65      """
 66      API tests for extensions.
 67      """
 68  
 69      @staticmethod
 70      @patch.dict(
 71          os.environ,
 72          {
 73              "CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"),
 74              "API_CLASS": "txtai.api.API",
 75              "EXTENSIONS": "testapi.testextension.SampleExtension",
 76          },
 77      )
 78      def start():
 79          """
 80          Starts a mock FastAPI client.
 81          """
 82  
 83          config = os.path.join(tempfile.gettempdir(), "testapi.yml")
 84  
 85          with open(config, "w", encoding="utf-8") as output:
 86              output.write(PIPELINES)
 87  
 88          # Create new application and set on client
 89          application.app = application.create()
 90          client = TestClient(application.app)
 91          application.start()
 92  
 93          return client
 94  
 95      @classmethod
 96      def setUpClass(cls):
 97          """
 98          Create API client on creation of class.
 99          """
100  
101          cls.client = TestExtension.start()
102  
103      def testEmpty(self):
104          """
105          Test an empty extension
106          """
107  
108          extension = Extension()
109          self.assertIsNone(extension(None))
110  
111      def testExtension(self):
112          """
113          Test a pipeline extension
114          """
115  
116          text = self.client.get("sample?text=Test%20String").json()
117          self.assertEqual(text, "test string")