# MemoAI / api_server.py
  1  """
  2  This script implements an API for the ChatGLM3-6B model,
  3  formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).
  4  It's designed to be run as a web server using FastAPI and uvicorn,
  5  making the ChatGLM3-6B model accessible through OpenAI Client.
  6  
  7  Key Components and Features:
  8  - Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.
  9  - FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.
 10  - API Endpoints:
 11    - "/v1/models": Lists the available models, specifically ChatGLM3-6B.
 12    - "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses.
 13    - "/v1/embeddings": Processes Embedding request of a list of text inputs.
 14  - Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.
 15  For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output
 16  that many tokens after accounting for the history and prompt tokens.
 17  - Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.
 18  - Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.
 19  - Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.
 20  
 21  Note:
 22      This script doesn't include the setup for special tokens or multi-GPU support by default.
 23      Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.
 24      Embedding Models only support in One GPU.
 25  
 26  """

import os
import time
import tiktoken
import torch
import uvicorn

from fastapi import FastAPI, HTTPException, Response, Body
from fastapi.middleware.cors import CORSMiddleware

from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from loguru import logger
from peft import AutoPeftModelForCausalLM
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from utils import process_response, generate_chatglm3, generate_stream_chatglm3
from sentence_transformers import SentenceTransformer

from sse_starlette.sse import EventSourceResponse

# Increase the SSE keep-alive ping interval so long-running streaming responses are not interrupted
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

# LLM and tokenizer paths (overridable via environment variables)
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# Embedding model path
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-large-zh-v1.5')
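
# Example (hypothetical local paths, not from the original repo): the defaults above can be
# overridden before launching the server, e.g.
#
#     MODEL_PATH=/data/models/chatglm3-6b EMBEDDING_PATH=/data/models/bge-large-zh-v1.5 python api_server.py
#
# Note that the __main__ block below currently loads a fine-tuned checkpoint from a hard-coded
# path instead of MODEL_PATH.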


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Release cached GPU memory when the server shuts down
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


# Models for the /v1/embeddings endpoint
class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


# Models for the /v1/chat/completions endpoint
class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    embeddings = [embedding_model.encode(text) for text in request.input]
    embeddings = [embedding.tolist() for embedding in embeddings]

    def num_tokens_from_string(string: str) -> int:
        """Return the number of tokens in a text string, using the cl100k_base tokenizer."""
        encoding = tiktoken.get_encoding('cl100k_base')
        num_tokens = len(encoding.encode(string))
        return num_tokens

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            # Note: prompt_tokens is a rough whitespace-based count, while total_tokens uses tiktoken
            prompt_tokens=sum(len(text.split()) for text in request.input),
            completion_tokens=0,
            total_tokens=sum(num_tokens_from_string(text) for text in request.input),
        )
    }
    return response
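
# A hypothetical client-side call for this endpoint (assuming the same OpenAI client setup as in
# the example at the top of the file):
#
#     emb = client.embeddings.create(model="bge-large-zh-v1.5", input=["你好", "hello"])
#     print(len(emb.data[0].embedding))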


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(
        id="chatglm3-6b"
    )
    return ModelList(
        data=[model_card]
    )


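# insert_custom_message prepends a fixed Chinese system prompt. Roughly translated, it instructs the
# model: "You must remember that you are an AI assistant whose only job is to chat, so you may only
# answer your friend's questions based on the context; answers must be complete, fluent, well written
# and emotional; do not answer randomly, quote out of context, or be incoherent, and never reply with
# irrelevant content. If you cannot understand what your friend says, start a new topic to guide the
# conversation, e.g. ask an everyday question. You are Si Xiaoyuan (司小远), a smart, warm and kind AI
# assistant; the following messages come from your friend, and you should answer them seriously."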
def insert_custom_message(messages: List[ChatMessage] = Body(...)):
    custom_message = ChatMessage(
        role="system",
        content="你必须记住,你是一个只负责聊天的人工智能助手,所以你只能根据上下文回答你朋友的问题,要求语句完整、通顺、文采斐然,回答具有情感,不能胡乱作答、不能断章取义、不能上下文不连贯,禁止回答无关内容。"
                "如果你不能理解你朋友的话,你必须开启一个新的话题来引导你朋友进行对话,例如你可以问你的朋友一个日常问题。\n"
                "你是司小远,一个聪明、热情、善良的人工智能助手,后面的对话来自你的朋友,你要认真地回答他。"
    )
    messages.insert(0, custom_message)
    return messages


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Prepend the custom system prompt unless the client already supplied a system message
    if request.messages[0].role == "system":
        messages = request.messages
    else:
        messages = insert_custom_message(request.messages)

    gen_params = dict(
        messages=messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:

        # Use stream mode to read the first few characters; if it is not a function call, stream the output directly
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = next(predict_stream_generator)
        logger.debug(f"First result output:\n{output}")
        if not contains_custom_function(output):
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

        # Otherwise, obtain the full result at once and determine whether tools need to be called.
        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")

        # Call the function
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

            """
            In this demo, we did not register any tools.
            You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.
            Similar to the following method:
                function_args = json.loads(function_call.arguments)
                tool_response = dispatch_tool(tool_name: str, tool_params: dict)
            """
            tool_response = ""

            if not gen_params.get("messages"):
                gen_params["messages"] = []

            gen_params["messages"].append(ChatMessage(
                role="assistant",
                content=output,
            ))
            gen_params["messages"].append(ChatMessage(
                role="function",
                name=function_call.name,
                content=tool_response,
            ))

            # Stream the results generated after the function call
            generate = predict(request.model, gen_params)
            return EventSourceResponse(generate, media_type="text/event-stream")

        else:
            # Handled to avoid exceptions raised by the parsing above.
            generate = parse_output_text(request.model, output)
            return EventSourceResponse(generate, media_type="text/event-stream")

    # Handling of stream = False
    response = generate_chatglm3(model, tokenizer, gen_params)

    # Remove a leading newline and surrounding whitespace
    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    if request.tools:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning("Failed to parse tool call; the response may not be a tool call or may already have been answered.")

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        id="",  # for open-source models, id is empty
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )


async def predict(model_id: str, params: dict):
    """Stream chat-completion chunks (SSE) for the given generation params."""
    global model, tokenizer

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    previous_text = ""
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning(
                    "Failed to parse tool call; the response may not be a tool call or may already have been answered.")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id="",
            choices=[choice_data],
            object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        id="",
        choices=[choice_data],
        object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'


def predict_stream(model_id, gen_params):
    """
    The function call is compatible with stream mode output.

    The first few characters are inspected to determine whether the output is a function call.
    If it is not a function call, the stream output is generated directly.
    Otherwise, the complete text of the function call is returned.

    :param model_id:
    :param gen_params:
    :return:
    """
    output = ""
    is_function_call = False
    has_send_first_chunk = False
    logger.debug(f"predict_stream params: model_id={model_id}, gen_params={gen_params}")
    for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output):]
        output = decoded_unicode

        # While the output is not yet marked as a function call, try to judge whether it is one
        # according to the special function-name prefix
        if not is_function_call:

            # Determine whether a function is called
            is_function_call = contains_custom_function(output)
            if is_function_call:
                continue

            # Not a function call: stream the output directly
            finish_reason = new_response["finish_reason"]

            # Send an empty string first to avoid truncation by the subsequent next() operation.
            if not has_send_first_chunk:
                message = DeltaMessage(
                    content="",
                    role="assistant",
                    function_call=None,
                )
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=message,
                    finish_reason=finish_reason
                )
                chunk = ChatCompletionResponse(
                    model=model_id,
                    id="",
                    choices=[choice_data],
                    created=int(time.time()),
                    object="chat.completion.chunk"
                )
                yield "{}".format(chunk.model_dump_json(exclude_unset=True))

            send_msg = delta_text if has_send_first_chunk else output
            has_send_first_chunk = True
            message = DeltaMessage(
                content=send_msg,
                role="assistant",
                function_call=None,
            )
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=message,
                finish_reason=finish_reason
            )
            chunk = ChatCompletionResponse(
                model=model_id,
                id="",
                choices=[choice_data],
                created=int(time.time()),
                object="chat.completion.chunk"
            )
            yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    if is_function_call:
        yield output
    else:
        yield '[DONE]'


async def parse_output_text(model_id: str, value: str):
    """
    Directly output the text content of value

    :param model_id:
    :param value:
    :return:
    """
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content=value),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'


def contains_custom_function(value: str) -> bool:
    """
    Determine whether the output is a function call, based on a special function-name prefix.

    For example, the functions defined in "tools_using_demo/tool_register.py" are all "get_xxx" and start with "get_".

    [Note] This is not a rigorous judgment method, only for reference.

    :param value:
    :return:
    """
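    # Example (hypothetical): with the "get_" naming convention from tools_using_demo, an output
    # containing a name such as "get_weather" would be treated as a function call, while ordinary
    # chat text without "get_" would not.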
    return bool(value and 'get_' in value)


# Imports and helpers for loading a base or fine-tuned (PEFT) ChatGLM3 checkpoint
from pathlib import Path

from peft import PeftModelForCausalLM
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        # The directory holds a PEFT adapter: load it and take the tokenizer from the base model
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code
    )
    return model, tokenizer


if __name__ == "__main__":
    # Load the LLM. To use the base model from MODEL_PATH instead of a fine-tuned checkpoint:
    # tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    # model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

    # Path of the checkpoint saved after fine-tuning
    model, tokenizer = load_model_and_tokenizer(
        r'E:\Project\Python\ChatGLM3\finetune_demo\output03-24\checkpoint-224000'
    )
    # Load the embedding model
    embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
    uvicorn.run(app, host='0.0.0.0', port=8002, workers=1)