# src/python/txtai/api/routers/openai.py
  1  """
  2  Defines an OpenAI-compatible API endpoint for txtai.
  3  
  4  See the following specification for more information:
  5  https://github.com/openai/openai-openapi
  6  """
  7  
  8  import uuid
  9  import json
 10  import time
 11  
 12  from typing import List, Optional, Union
 13  
 14  from fastapi import APIRouter, Body, Form, UploadFile
 15  from fastapi.responses import Response, StreamingResponse
 16  
 17  from .. import application
 18  from ..route import EncodingAPIRoute
 19  
 20  router = APIRouter(route_class=EncodingAPIRoute)
 21  
 22  
 23  # pylint: disable=W0622
 24  @router.post("/v1/chat/completions")
 25  def chat(
 26      messages: List[dict] = Body(...),
 27      model: str = Body(...),
 28      max_completion_tokens: Optional[int] = Body(default=None),
 29      stream: Optional[bool] = Body(default=False),
 30  ):
 31      """
 32      Runs a chat completion request.
 33  
 34      Args:
 35          messages: list of messages [{"role": role, "content": content}]
 36          model: agent name, workflow name, pipeline name or embeddings
 37          max_completion_tokens: sets the max length to generate
 38          stream: streams response if True
 39  
 40      Returns:
 41          chat completion
 42      """
 43  
 44      # Build keyword arguments
 45      kwargs = {key: value for key, value in [("stream", stream), ("maxlength", max_completion_tokens)] if value}
 46  
 47      # Get first message
 48      message = messages[0]["content"]
 49  
 50      # Agent
 51      if model in application.get().agents:
 52          result = application.get().agent(model, message, **kwargs)
 53  
 54      # Embeddings search
 55      elif model == "embeddings":
 56          result = application.get().search(message, 1, **kwargs)[0]["text"]
 57  
 58      # Pipeline
 59      elif model in application.get().pipelines and model != "llm":
 60          result = application.get().pipeline(model, message, **kwargs)
 61  
 62      # Workflow
 63      elif model in application.get().workflows:
 64          result = list(application.get().workflow(model, [message], **kwargs))[0]
 65  
 66      # Default to running all messages through default LLM
 67      else:
 68          result = application.get().pipeline("llm", messages, **kwargs)
 69  
 70      # Write response
 71      return StreamingResponse(StreamingChatResponse()(model, result)) if stream else ChatResponse()(model, result)
 72  
 73  
 74  @router.post("/v1/embeddings")
 75  def embeddings(input: Union[str, List[str]] = Body(...), model: str = Body(...)):
 76      """
 77      Creates an embeddings vector for the input text.
 78  
 79      Args:
 80          input: text|list
 81          model: model name
 82  
 83      Returns:
 84          list of embeddings vectors
 85      """
 86  
 87      # Convert to embeddings
 88      result = application.get().batchtransform([input] if isinstance(input, str) else input)
 89  
 90      # Build and return response
 91      data = []
 92      for index, embedding in enumerate(result):
 93          data.append({"object": "embedding", "embedding": embedding, "index": index})
 94  
 95      return {"object": "list", "data": data, "model": model}
 96  
 97  
 98  @router.post("/v1/audio/speech")
 99  def speech(input: str = Body(...), voice: str = Body(...), response_format: Optional[str] = Body(default="mp3")):
100      """
101      Generates speech for the input text.
102  
103      Args:
104          input: input text
105          voice: speaker name
106          response_format: audio encoding format, defaults to mp3
107  
108      Returns:
109          audio data
110      """
111  
112      # Convert to audio
113      audio = application.get().pipeline("texttospeech", input, speaker=voice, encoding=response_format)
114  
115      # Write audio
116      return Response(audio)
117  
118  
@router.post("/v1/audio/transcriptions")
def transcribe(file: UploadFile, language: Optional[str] = Form(default=None), response_format: Optional[str] = Form(default="json")):
    """
    Transcribes audio to text.

    Args:
        file: audio input file
        language: language of input audio
        response_format: output format (json or text)

    Returns:
        transcribed text
    """

    # Run speech-to-text over the uploaded file
    text = application.get().pipeline("transcription", file.file, language=language, task="transcribe")

    # Plain text or JSON wrapper, depending on the requested format
    if response_format == "text":
        return text

    return {"text": text}
136  
137  
@router.post("/v1/audio/translations")
def translate(
    file: UploadFile,
    response_format: Optional[str] = Form(default="json"),
):
    """
    Translates audio to English.

    Args:
        file: audio input file
        response_format: output format (json or text)

    Returns:
        translated text
    """

    # Run speech-to-text with the translate task, which targets English output
    text = application.get().pipeline("transcription", file.file, language="English", task="translate")

    # Plain text or JSON wrapper, depending on the requested format
    if response_format == "text":
        return text

    return {"text": text}
157  
158  
class ChatResponse:
    """
    Builds an OpenAI-compatible chat completion response object.
    """

    def __call__(self, model, result):
        """
        Creates a chat completion object for a generated result.

        Args:
            model: model name used for the request
            result: generated content

        Returns:
            chat completion object
        """

        return {
            "id": str(uuid.uuid4()),
            "object": "chat.completion",
            # OpenAI spec: created is a Unix timestamp in seconds, not milliseconds
            "created": int(time.time()),
            "model": model,
            # OpenAI spec: choice entries are keyed by "index", not "id"
            "choices": [{"index": 0, "message": {"role": "assistant", "content": result}, "finish_reason": "stop"}],
        }
172  
173  
class StreamingChatResponse:
    """
    Builds an OpenAI-compatible streaming chat completion response.
    """

    def __call__(self, model, result):
        """
        Generates server-sent events for each chunk of a streaming result.

        Args:
            model: model name used for the request
            result: iterable of generated content chunks

        Yields:
            SSE-formatted chat completion chunks, terminated by a [DONE] event
        """

        # OpenAI spec: all chunks of a single streamed completion share the same id
        uid = str(uuid.uuid4())

        for chunk in result:
            yield "data: " + json.dumps(
                {
                    "id": uid,
                    "object": "chat.completion.chunk",
                    # OpenAI spec: created is a Unix timestamp in seconds, not milliseconds
                    "created": int(time.time()),
                    "model": model,
                    # OpenAI spec: choice entries are keyed by "index", not "id"
                    "choices": [{"index": 0, "delta": {"content": chunk}}],
                }
            ) + "\n\n"

        yield "data: [DONE]\n\n"