openai.py
1 """ 2 Defines an OpenAI-compatible API endpoint for txtai. 3 4 See the following specification for more information: 5 https://github.com/openai/openai-openapi 6 """ 7 8 import uuid 9 import json 10 import time 11 12 from typing import List, Optional, Union 13 14 from fastapi import APIRouter, Body, Form, UploadFile 15 from fastapi.responses import Response, StreamingResponse 16 17 from .. import application 18 from ..route import EncodingAPIRoute 19 20 router = APIRouter(route_class=EncodingAPIRoute) 21 22 23 # pylint: disable=W0622 24 @router.post("/v1/chat/completions") 25 def chat( 26 messages: List[dict] = Body(...), 27 model: str = Body(...), 28 max_completion_tokens: Optional[int] = Body(default=None), 29 stream: Optional[bool] = Body(default=False), 30 ): 31 """ 32 Runs a chat completion request. 33 34 Args: 35 messages: list of messages [{"role": role, "content": content}] 36 model: agent name, workflow name, pipeline name or embeddings 37 max_completion_tokens: sets the max length to generate 38 stream: streams response if True 39 40 Returns: 41 chat completion 42 """ 43 44 # Build keyword arguments 45 kwargs = {key: value for key, value in [("stream", stream), ("maxlength", max_completion_tokens)] if value} 46 47 # Get first message 48 message = messages[0]["content"] 49 50 # Agent 51 if model in application.get().agents: 52 result = application.get().agent(model, message, **kwargs) 53 54 # Embeddings search 55 elif model == "embeddings": 56 result = application.get().search(message, 1, **kwargs)[0]["text"] 57 58 # Pipeline 59 elif model in application.get().pipelines and model != "llm": 60 result = application.get().pipeline(model, message, **kwargs) 61 62 # Workflow 63 elif model in application.get().workflows: 64 result = list(application.get().workflow(model, [message], **kwargs))[0] 65 66 # Default to running all messages through default LLM 67 else: 68 result = application.get().pipeline("llm", messages, **kwargs) 69 70 # Write response 71 return StreamingResponse(StreamingChatResponse()(model, result)) if stream else ChatResponse()(model, result) 72 73 74 @router.post("/v1/embeddings") 75 def embeddings(input: Union[str, List[str]] = Body(...), model: str = Body(...)): 76 """ 77 Creates an embeddings vector for the input text. 78 79 Args: 80 input: text|list 81 model: model name 82 83 Returns: 84 list of embeddings vectors 85 """ 86 87 # Convert to embeddings 88 result = application.get().batchtransform([input] if isinstance(input, str) else input) 89 90 # Build and return response 91 data = [] 92 for index, embedding in enumerate(result): 93 data.append({"object": "embedding", "embedding": embedding, "index": index}) 94 95 return {"object": "list", "data": data, "model": model} 96 97 98 @router.post("/v1/audio/speech") 99 def speech(input: str = Body(...), voice: str = Body(...), response_format: Optional[str] = Body(default="mp3")): 100 """ 101 Generates speech for the input text. 102 103 Args: 104 input: input text 105 voice: speaker name 106 response_format: audio encoding format, defaults to mp3 107 108 Returns: 109 audio data 110 """ 111 112 # Convert to audio 113 audio = application.get().pipeline("texttospeech", input, speaker=voice, encoding=response_format) 114 115 # Write audio 116 return Response(audio) 117 118 119 @router.post("/v1/audio/transcriptions") 120 def transcribe(file: UploadFile, language: Optional[str] = Form(default=None), response_format: Optional[str] = Form(default="json")): 121 """ 122 Transcribes audio to text. 
123 124 Args: 125 file: audio input file 126 language: language of input audio 127 response_format: output format (json or text) 128 129 Returns: 130 transcribed text 131 """ 132 133 # Transcribe 134 text = application.get().pipeline("transcription", file.file, language=language, task="transcribe") 135 return text if response_format == "text" else {"text": text} 136 137 138 @router.post("/v1/audio/translations") 139 def translate( 140 file: UploadFile, 141 response_format: Optional[str] = Form(default="json"), 142 ): 143 """ 144 Translates audio to English. 145 146 Args: 147 file: audio input file 148 response_format: output format (json or text) 149 150 Returns: 151 translated text 152 """ 153 154 # Transcribe and translate to English 155 text = application.get().pipeline("transcription", file.file, language="English", task="translate") 156 return text if response_format == "text" else {"text": text} 157 158 159 class ChatResponse: 160 """ 161 Returns a chat response object. 162 """ 163 164 def __call__(self, model, result): 165 return { 166 "id": str(uuid.uuid4()), 167 "object": "chat.completion", 168 "created": int(time.time() * 1000), 169 "model": model, 170 "choices": [{"id": 0, "message": {"role": "assistant", "content": result}, "finish_reason": "stop"}], 171 } 172 173 174 class StreamingChatResponse: 175 """ 176 Returns a streaming chat response object. 177 """ 178 179 def __call__(self, model, result): 180 for chunk in result: 181 yield "data: " + json.dumps( 182 { 183 "id": str(uuid.uuid4()), 184 "object": "chat.completion.chunk", 185 "created": int(time.time() * 1000), 186 "model": model, 187 "choices": [{"id": 0, "delta": {"content": chunk}}], 188 } 189 ) + "\n\n" 190 191 yield "data: [DONE]\n\n"
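
A minimal client-side sketch of how these routes can be exercised. It assumes a txtai API instance is serving this router at http://localhost:8000; the host, port and model names below are illustrative assumptions, and plain HTTP requests are used rather than any particular SDK.

import requests

# Assumed base URL for a running txtai API instance; adjust as needed
BASE = "http://localhost:8000"

# Chat completion request; an unrecognized model name falls through to the default "llm" pipeline
response = requests.post(
    f"{BASE}/v1/chat/completions",
    json={"model": "llm", "messages": [{"role": "user", "content": "Hello"}]},
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])

# Embeddings request; returns one vector per input string
response = requests.post(
    f"{BASE}/v1/embeddings",
    json={"model": "embeddings", "input": ["first text", "second text"]},
    timeout=60,
)
print(len(response.json()["data"]))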