api_server.py
1 """ 2 This script implements an API for the ChatGLM3-6B model, 3 formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat). 4 It's designed to be run as a web server using FastAPI and uvicorn, 5 making the ChatGLM3-6B model accessible through OpenAI Client. 6 7 Key Components and Features: 8 - Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them. 9 - FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests. 10 - API Endpoints: 11 - "/v1/models": Lists the available models, specifically ChatGLM3-6B. 12 - "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses. 13 - "/v1/embeddings": Processes Embedding request of a list of text inputs. 14 - Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'. 15 For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output 16 that many tokens after accounting for the history and prompt tokens. 17 - Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses. 18 - Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety. 19 - Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port. 20 21 Note: 22 This script doesn't include the setup for special tokens or multi-GPU support by default. 23 Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions. 24 Embedding Models only support in One GPU. 25 26 """ 27 28 import os 29 import time 30 import tiktoken 31 import torch 32 import uvicorn 33 34 from fastapi import FastAPI, HTTPException, Response, Body 35 from fastapi.middleware.cors import CORSMiddleware 36 37 from contextlib import asynccontextmanager 38 from typing import List, Literal, Optional, Union 39 from loguru import logger 40 from peft import AutoPeftModelForCausalLM 41 from pydantic import BaseModel, Field 42 from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM 43 from utils import process_response, generate_chatglm3, generate_stream_chatglm3 44 from sentence_transformers import SentenceTransformer 45 46 from sse_starlette.sse import EventSourceResponse 47 48 # Set up limit request time 49 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 50 51 # set LLM path 52 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b') 53 TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH) 54 55 # set Embedding Model path 56 EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-large-zh-v1.5') 57 58 59 @asynccontextmanager 60 async def lifespan(app: FastAPI): 61 yield 62 if torch.cuda.is_available(): 63 torch.cuda.empty_cache() 64 torch.cuda.ipc_collect() 65 66 67 app = FastAPI(lifespan=lifespan) 68 69 app.add_middleware( 70 CORSMiddleware, 71 allow_origins=["*"], 72 allow_credentials=True, 73 allow_methods=["*"], 74 allow_headers=["*"], 75 ) 76 77 78 class ModelCard(BaseModel): 79 id: str 80 object: str = "model" 81 created: int = Field(default_factory=lambda: int(time.time())) 82 owned_by: str = "owner" 83 root: Optional[str] = None 84 parent: Optional[str] = None 85 permission: Optional[list] = None 86 87 88 class ModelList(BaseModel): 89 object: str = "list" 90 data: List[ModelCard] = [] 91 92 93 class FunctionCallResponse(BaseModel): 94 name: Optional[str] = None 95 arguments: Optional[str] = None 96 97 98 class ChatMessage(BaseModel): 99 role: Literal["user", "assistant", "system", "function"] 100 content: str = None 101 name: Optional[str] = None 102 function_call: Optional[FunctionCallResponse] = None 103 104 105 class DeltaMessage(BaseModel): 106 role: Optional[Literal["user", "assistant", "system"]] = None 107 content: Optional[str] = None 108 function_call: Optional[FunctionCallResponse] = None 109 110 111 ## for Embedding 112 class EmbeddingRequest(BaseModel): 113 input: List[str] 114 model: str 115 116 117 class CompletionUsage(BaseModel): 118 prompt_tokens: int 119 completion_tokens: int 120 total_tokens: int 121 122 123 class EmbeddingResponse(BaseModel): 124 data: list 125 model: str 126 object: str 127 usage: CompletionUsage 128 129 130 # for ChatCompletionRequest 131 132 class UsageInfo(BaseModel): 133 prompt_tokens: int = 0 134 total_tokens: int = 0 135 completion_tokens: Optional[int] = 0 136 137 138 class ChatCompletionRequest(BaseModel): 139 model: str 140 messages: List[ChatMessage] 141 temperature: Optional[float] = 0.8 142 top_p: Optional[float] = 0.8 143 max_tokens: Optional[int] = None 144 stream: Optional[bool] = False 145 tools: Optional[Union[dict, List[dict]]] = None 146 repetition_penalty: Optional[float] = 1.1 147 148 149 class ChatCompletionResponseChoice(BaseModel): 150 index: int 151 message: ChatMessage 152 finish_reason: Literal["stop", "length", "function_call"] 153 154 155 class ChatCompletionResponseStreamChoice(BaseModel): 156 delta: DeltaMessage 157 finish_reason: Optional[Literal["stop", "length", "function_call"]] 158 index: int 159 160 161 class ChatCompletionResponse(BaseModel): 162 model: str 163 id: str 164 object: Literal["chat.completion", "chat.completion.chunk"] 165 choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 166 created: Optional[int] = Field(default_factory=lambda: int(time.time())) 167 usage: Optional[UsageInfo] = None 168 169 170 @app.get("/health") 171 async def health() -> Response: 172 """Health check.""" 173 return Response(status_code=200) 174 175 176 @app.post("/v1/embeddings", response_model=EmbeddingResponse) 177 async def get_embeddings(request: EmbeddingRequest): 178 embeddings = [embedding_model.encode(text) for text in request.input] 179 embeddings = [embedding.tolist() for embedding in embeddings] 180 181 def num_tokens_from_string(string: str) -> int: 182 """ 183 Returns the number of tokens in a text string. 184 use cl100k_base tokenizer 185 """ 186 encoding = tiktoken.get_encoding('cl100k_base') 187 num_tokens = len(encoding.encode(string)) 188 return num_tokens 189 190 response = { 191 "data": [ 192 { 193 "object": "embedding", 194 "embedding": embedding, 195 "index": index 196 } 197 for index, embedding in enumerate(embeddings) 198 ], 199 "model": request.model, 200 "object": "list", 201 "usage": CompletionUsage( 202 prompt_tokens=sum(len(text.split()) for text in request.input), 203 completion_tokens=0, 204 total_tokens=sum(num_tokens_from_string(text) for text in request.input), 205 ) 206 } 207 return response 208 209 210 @app.get("/v1/models", response_model=ModelList) 211 async def list_models(): 212 model_card = ModelCard( 213 id="chatglm3-6b" 214 ) 215 return ModelList( 216 data=[model_card] 217 ) 218 219 220 def insert_custom_message(messages: List[ChatMessage] = Body(...)): 221 custom_message = ChatMessage( 222 role="system", 223 content="你必须记住,你是一个只负责聊天的人工智能助手,所以你只能根据上下文回答你朋友的问题,要求语句完整、通顺、文采斐然,回答具有情感,不能胡乱作答、不能断章取义、不能上下文不连贯,禁止回答无关内容。" 224 "如果你不能理解你朋友的话,你必须开启一个新的话题来引导你朋友进行对话,例如你可以问你的朋友一个日常问题。\n" 225 "你是司小远,一个聪明、热情、善良的人工智能助手,后面的对话来自你的朋友,你要认真地回答他。" 226 ) 227 messages.insert(0, custom_message) 228 return messages 229 230 231 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 232 async def create_chat_completion(request: ChatCompletionRequest): 233 global model, tokenizer 234 235 if len(request.messages) < 1 or request.messages[-1].role == "assistant": 236 raise HTTPException(status_code=400, detail="Invalid request") 237 messages = request.messages 238 if request.messages and request.messages[0].role == 'system': 239 messages = request.messages 240 else: 241 if request.messages: 242 messages = insert_custom_message(request.messages) 243 else: 244 messages = request.messages 245 print(type(request.messages), request.messages) 246 gen_params = dict( 247 messages=messages, 248 temperature=request.temperature, 249 top_p=request.top_p, 250 max_tokens=request.max_tokens or 1024, 251 echo=False, 252 stream=request.stream, 253 repetition_penalty=request.repetition_penalty, 254 tools=request.tools, 255 ) 256 logger.debug(f"==== request ====\n{gen_params}") 257 258 if request.stream: 259 260 # Use the stream mode to read the first few characters, if it is not a function call, direct stram output 261 predict_stream_generator = predict_stream(request.model, gen_params) 262 # return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") 263 output = next(predict_stream_generator) 264 print(output) 265 # logger.debug(f"First result output:\n{output}") 266 if not contains_custom_function(output): 267 return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") 268 269 # Obtain the result directly at one time and determine whether tools needs to be called. 270 # logger.debug(f"First result output:\n{output}") 271 272 function_call = None 273 if output and request.tools: 274 try: 275 function_call = process_response(output, use_tool=True) 276 except: 277 logger.warning("Failed to parse tool call") 278 279 # CallFunction 280 if isinstance(function_call, dict): 281 function_call = FunctionCallResponse(**function_call) 282 283 """ 284 In this demo, we did not register any tools. 285 You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here. 286 Similar to the following method: 287 function_args = json.loads(function_call.arguments) 288 tool_response = dispatch_tool(tool_name: str, tool_params: dict) 289 """ 290 tool_response = "" 291 292 if not gen_params.get("messages"): 293 gen_params["messages"] = [] 294 295 gen_params["messages"].append(ChatMessage( 296 role="assistant", 297 content=output, 298 )) 299 gen_params["messages"].append(ChatMessage( 300 role="function", 301 name=function_call.name, 302 content=tool_response, 303 )) 304 305 # Streaming output of results after function calls 306 generate = predict(request.model, gen_params) 307 return EventSourceResponse(generate, media_type="text/event-stream") 308 309 else: 310 # Handled to avoid exceptions in the above parsing function process. 311 generate = parse_output_text(request.model, output) 312 return EventSourceResponse(generate, media_type="text/event-stream") 313 314 # Here is the handling of stream = False 315 response = generate_chatglm3(model, tokenizer, gen_params) 316 317 # Remove the first newline character 318 if response["text"].startswith("\n"): 319 response["text"] = response["text"][1:] 320 response["text"] = response["text"].strip() 321 322 usage = UsageInfo() 323 function_call, finish_reason = None, "stop" 324 if request.tools: 325 try: 326 function_call = process_response(response["text"], use_tool=True) 327 except: 328 logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.") 329 330 if isinstance(function_call, dict): 331 finish_reason = "function_call" 332 function_call = FunctionCallResponse(**function_call) 333 334 message = ChatMessage( 335 role="assistant", 336 content=response["text"], 337 function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 338 ) 339 340 logger.debug(f"==== message ====\n{message}") 341 342 choice_data = ChatCompletionResponseChoice( 343 index=0, 344 message=message, 345 finish_reason=finish_reason, 346 ) 347 task_usage = UsageInfo.model_validate(response["usage"]) 348 for usage_key, usage_value in task_usage.model_dump().items(): 349 setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 350 351 return ChatCompletionResponse( 352 model=request.model, 353 id="", # for open_source model, id is empty 354 choices=[choice_data], 355 object="chat.completion", 356 usage=usage 357 ) 358 359 360 async def predict(model_id: str, params: dict): 361 global model, tokenizer 362 363 choice_data = ChatCompletionResponseStreamChoice( 364 index=0, 365 delta=DeltaMessage(role="assistant"), 366 finish_reason=None 367 ) 368 chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") 369 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 370 371 previous_text = "" 372 for new_response in generate_stream_chatglm3(model, tokenizer, params): 373 decoded_unicode = new_response["text"] 374 delta_text = decoded_unicode[len(previous_text):] 375 previous_text = decoded_unicode 376 377 finish_reason = new_response["finish_reason"] 378 if len(delta_text) == 0 and finish_reason != "function_call": 379 continue 380 381 function_call = None 382 if finish_reason == "function_call": 383 try: 384 function_call = process_response(decoded_unicode, use_tool=True) 385 except: 386 logger.warning( 387 "Failed to parse tool call, maybe the response is not a tool call or have been answered.") 388 389 if isinstance(function_call, dict): 390 function_call = FunctionCallResponse(**function_call) 391 392 delta = DeltaMessage( 393 content=delta_text, 394 role="assistant", 395 function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 396 ) 397 398 choice_data = ChatCompletionResponseStreamChoice( 399 index=0, 400 delta=delta, 401 finish_reason=finish_reason 402 ) 403 chunk = ChatCompletionResponse( 404 model=model_id, 405 id="", 406 choices=[choice_data], 407 object="chat.completion.chunk" 408 ) 409 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 410 411 choice_data = ChatCompletionResponseStreamChoice( 412 index=0, 413 delta=DeltaMessage(), 414 finish_reason="stop" 415 ) 416 chunk = ChatCompletionResponse( 417 model=model_id, 418 id="", 419 choices=[choice_data], 420 object="chat.completion.chunk" 421 ) 422 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 423 yield '[DONE]' 424 425 426 def predict_stream(model_id, gen_params): 427 """ 428 The function call is compatible with stream mode output. 429 430 The first seven characters are determined. 431 If not a function call, the stream output is directly generated. 432 Otherwise, the complete character content of the function call is returned. 433 434 :param model_id: 435 :param gen_params: 436 :return: 437 """ 438 output = "" 439 is_function_call = False 440 has_send_first_chunk = False 441 print('参数') 442 print(model_id,gen_params) 443 for new_response in generate_stream_chatglm3(model, tokenizer, gen_params): 444 decoded_unicode = new_response["text"] 445 delta_text = decoded_unicode[len(output):] 446 output = decoded_unicode 447 448 # When it is not a function call and the character length is> 7, 449 # try to judge whether it is a function call according to the special function prefix 450 if not is_function_call: 451 452 # Determine whether a function is called 453 is_function_call = contains_custom_function(output) 454 if is_function_call: 455 continue 456 457 # Non-function call, direct stream output 458 finish_reason = new_response["finish_reason"] 459 460 # Send an empty string first to avoid truncation by subsequent next() operations. 461 if not has_send_first_chunk: 462 message = DeltaMessage( 463 content="", 464 role="assistant", 465 function_call=None, 466 ) 467 choice_data = ChatCompletionResponseStreamChoice( 468 index=0, 469 delta=message, 470 finish_reason=finish_reason 471 ) 472 chunk = ChatCompletionResponse( 473 model=model_id, 474 id="", 475 choices=[choice_data], 476 created=int(time.time()), 477 object="chat.completion.chunk" 478 ) 479 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 480 481 send_msg = delta_text if has_send_first_chunk else output 482 has_send_first_chunk = True 483 message = DeltaMessage( 484 content=send_msg, 485 role="assistant", 486 function_call=None, 487 ) 488 choice_data = ChatCompletionResponseStreamChoice( 489 index=0, 490 delta=message, 491 finish_reason=finish_reason 492 ) 493 chunk = ChatCompletionResponse( 494 model=model_id, 495 id="", 496 choices=[choice_data], 497 created=int(time.time()), 498 object="chat.completion.chunk" 499 ) 500 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 501 502 if is_function_call: 503 yield output 504 else: 505 yield '[DONE]' 506 507 508 async def parse_output_text(model_id: str, value: str): 509 """ 510 Directly output the text content of value 511 512 :param model_id: 513 :param value: 514 :return: 515 """ 516 choice_data = ChatCompletionResponseStreamChoice( 517 index=0, 518 delta=DeltaMessage(role="assistant", content=value), 519 finish_reason=None 520 ) 521 chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") 522 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 523 524 choice_data = ChatCompletionResponseStreamChoice( 525 index=0, 526 delta=DeltaMessage(), 527 finish_reason="stop" 528 ) 529 chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") 530 yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 531 yield '[DONE]' 532 533 534 def contains_custom_function(value: str) -> bool: 535 """ 536 Determine whether 'function_call' according to a special function prefix. 537 538 For example, the functions defined in "tools_using_demo/tool_register.py" are all "get_xxx" and start with "get_" 539 540 [Note] This is not a rigorous judgment method, only for reference. 541 542 :param value: 543 :return: 544 """ 545 return value and 'get_' in value 546 547 548 from pathlib import Path 549 from typing import Annotated, Union 550 551 import typer 552 from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM 553 from transformers import ( 554 AutoModelForCausalLM, 555 AutoTokenizer, 556 PreTrainedModel, 557 PreTrainedTokenizer, 558 PreTrainedTokenizerFast, 559 ) 560 561 ModelType = Union[PreTrainedModel, PeftModelForCausalLM] 562 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 563 564 565 def _resolve_path(path: Union[str, Path]) -> Path: 566 return Path(path).expanduser().resolve() 567 568 569 def load_model_and_tokenizer( 570 model_dir: Union[str, Path], trust_remote_code: bool = True 571 ) -> tuple[ModelType, TokenizerType]: 572 model_dir = _resolve_path(model_dir) 573 if (model_dir / 'adapter_config.json').exists(): 574 model = AutoPeftModelForCausalLM.from_pretrained( 575 model_dir, trust_remote_code=trust_remote_code, device_map='auto' 576 ) 577 tokenizer_dir = model.peft_config['default'].base_model_name_or_path 578 else: 579 model = AutoModelForCausalLM.from_pretrained( 580 model_dir, trust_remote_code=trust_remote_code, device_map='auto' 581 ) 582 tokenizer_dir = model_dir 583 tokenizer = AutoTokenizer.from_pretrained( 584 tokenizer_dir, trust_remote_code=trust_remote_code 585 ) 586 return model, tokenizer 587 588 589 if __name__ == "__main__": 590 # Load LLM 591 # tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) 592 # model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval() 593 # 填微调之后的保存路径 594 model, tokenizer = load_model_and_tokenizer( 595 r'E:\Project\Python\ChatGLM3\finetune_demo\output03-24\checkpoint-224000' 596 ) 597 # load Embedding 598 embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda") 599 uvicorn.run(app, host='0.0.0.0', port=8002, workers=1)