# web_demo_streaming/app.py
"""Streaming web demo for Qwen2-VL.

Serves a Gradio chat UI that supports text, image/video uploads, and
camera/screen frame streaming. Frames are cached per session, folded into
Qwen2-VL "video" content blocks, and generation is streamed token-by-token
via a background thread and ``TextIteratorStreamer``.
"""

import base64
import copy
import io
import os
import pathlib
import shutil
import tempfile
import time
import uuid
from argparse import ArgumentParser
from threading import Thread
from typing import Dict

import gradio as gr
import imagesize
import openai
from PIL import Image, ImageFile
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer

from qwen_vl_utils import process_vision_info, smart_resize

# Be permissive with user-supplied images: allow truncated files and
# disable the decompression-bomb pixel limits.
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

image_transform = None
oss_reader = None


# Hard cap on the estimated prompt length; history is truncated to fit.
MAX_SEQ_LEN = 32000

DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'


def compute_seqlen_estimated(tokenizer, json_input, sample_strategy_func):
    """Estimate the token sequence length of a chat transcript.

    Walks every chat block in ``json_input`` and accumulates:
    * 4 structural tokens per block plus the tokenized role name,
    * vision tokens derived from the (resized) image/frame dimensions,
    * tokenized ``text`` / ``prompt`` content.

    Elements are annotated in place with ``seq_len`` (and vision elements
    with ``resized_height`` / ``resized_width``), so cached values are
    reused on later calls.

    Parameters:
        tokenizer: tokenizer exposing ``tokenize`` for text length counts.
        json_input: list of chat blocks ``{'role': ..., 'content': [...]}``;
            mutated in place.
        sample_strategy_func: callable ``(height, width) -> (height, width)``
            applied before ``smart_resize`` (e.g. identity or downsampling).

    Returns:
        dict with keys ``content`` (the annotated input), ``img_seq_len``,
        ``text_seq_len`` and total ``seq_len`` (includes one trailing token).

    Raises:
        NotImplementedError: for non-list ``video`` content.
        ValueError: for an element of unknown type.
    """
    total_seq_len, img_seq_len, text_seq_len = 0, 0, 0
    for chat_block in json_input:
        # 4 structural (chat-markup) tokens per block, plus the role name.
        chat_block['seq_len'] = 4
        role_length = len(tokenizer.tokenize(chat_block['role']))
        chat_block['seq_len'] += role_length
        text_seq_len += role_length

        for element in chat_block['content']:
            if 'image' in element:
                if 'width' not in element:
                    # Read dimensions from the file behind the file:// URL.
                    element['width'], element['height'] = imagesize.get(
                        element['image'].split('file://')[1])
                height, width = element['height'], element['width']
                height, width = sample_strategy_func(height, width)
                resized_height, resized_width = smart_resize(
                    height, width, max_pixels=14*14*4*5120)  # , min_pixels=14*14*4*512
                # One token per 28x28 patch, plus img_bos & img_eos.
                seq_len = resized_height * resized_width // 28 // 28 + 2
                element.update({
                    'resized_height': resized_height,
                    'resized_width': resized_width,
                    'seq_len': seq_len,
                })
                img_seq_len += element['seq_len']
                chat_block['seq_len'] += element['seq_len']
            elif 'video' in element:
                if isinstance(element['video'], (list, tuple)):
                    if 'width' not in element:
                        # Dimensions are taken from the first frame.
                        element['width'], element['height'] = imagesize.get(
                            element['video'][0].split('file://')[1])
                    height, width = element['height'], element['width']
                    height, width = sample_strategy_func(height, width)
                    resized_height, resized_width = smart_resize(
                        height, width, max_pixels=14*14*4*5120)  # , min_pixels=14*14*4*512
                    # Frames are merged temporally in pairs (hence //2),
                    # plus img_bos & img_eos.
                    seq_len = (resized_height * resized_width // 28 // 28) * \
                        (len(element['video'])//2)+2
                    element.update({
                        'resized_height': resized_height,
                        'resized_width': resized_width,
                        'seq_len': seq_len,
                    })
                    img_seq_len += element['seq_len']
                    chat_block['seq_len'] += element['seq_len']
                else:
                    raise NotImplementedError
            elif 'text' in element:
                if 'seq_len' in element:
                    # Cached from a previous call; don't re-tokenize.
                    text_seq_len += element['seq_len']
                else:
                    element['seq_len'] = len(
                        tokenizer.tokenize(element['text']))
                    text_seq_len += element['seq_len']
                chat_block['seq_len'] += element['seq_len']
            elif 'prompt' in element:
                if 'seq_len' in element:
                    text_seq_len += element['seq_len']
                else:
                    element['seq_len'] = len(
                        tokenizer.tokenize(element['prompt']))
                    text_seq_len += element['seq_len']
                chat_block['seq_len'] += element['seq_len']
            else:
                raise ValueError('Unknown element: ' + str(element))
        total_seq_len += chat_block['seq_len']
    # Sanity check: totals must decompose into image + text + structure.
    assert img_seq_len + text_seq_len + 4 * len(json_input) == total_seq_len
    total_seq_len += 1
    return {
        'content': json_input,
        'img_seq_len': img_seq_len,
        'text_seq_len': text_seq_len,
        'seq_len': total_seq_len,
    }


def _get_args():
    """Parse and return the demo's command-line arguments."""
    parser = ArgumentParser()

    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true',
                        help='Run demo with CPU only')

    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int,
                        default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str,
                        default='127.0.0.1', help='Demo server name.')

    args = parser.parse_args()
    return args


def _load_model_processor(args):
    """Load the Qwen2-VL model and its processor from ``args.checkpoint_path``.

    Honors ``--cpu-only`` (device_map='cpu') and ``--flash-attn2``
    (attn_implementation='flash_attention_2').

    Returns:
        (model, processor) tuple.
    """
    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    # Check if flash-attn2 flag is enabled and load model accordingly
    if args.flash_attn2:
        model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path,
                                                                torch_dtype='auto',
                                                                attn_implementation='flash_attention_2',
                                                                device_map=device_map)
    else:
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            args.checkpoint_path, device_map=device_map)

    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return model, processor


class ChatSessionState:
    """Per-session mutable state keyed by the Gradio session hash."""

    def __init__(self, session_id: str):
        self.session_id: str = session_id
        self.system_prompt: str = 'You are a helpful assistant.'
        # Model name selected for this session (currently unused).
        self.model_name = ''
        # Pending streamed frames: list of (timestamp, image_path) tuples,
        # flushed into the message history on the next text submission.
        self.image_cache = []


def _transform_messages(original_messages):
    """Normalize messages into the explicit ``{'type': ...}`` content format
    expected by the processor's chat template; unknown elements are dropped.
    """
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_item = {'type': 'image', 'image': item['image']}
            elif 'text' in item:
                new_item = {'type': 'text', 'text': item['text']}
            elif 'video' in item:
                new_item = {'type': 'video', 'video': item['video']}
            else:
                continue
            new_content.append(new_item)

        new_message = {'role': message['role'], 'content': new_content}
        transformed_messages.append(new_message)

    return transformed_messages


class Worker:
    """Holds all per-session chat state and implements the Gradio callbacks
    (text submission, file upload, frame streaming, streamed response).
    """

    def __init__(self):
        self.uids = []

        # All captured frames / uploads are stored under per-session
        # temp directories inside this root.
        capture_image_dir = os.path.join("/tmp/captured_images")
        os.makedirs(capture_image_dir, exist_ok=True)
        self.capture_image_dir = capture_image_dir

        self.save_dir = dict()   # session_id -> temp dir for saved media
        self.messages = dict()   # session_id -> message history
        self.resized_width, self.resized_height = 640, 420
        # session_id -> number of leading content blocks to drop on retry
        self.message_truncate = {}
        self.chat_session_states: Dict[str, ChatSessionState] = {}
        # file path -> {'data_url': base64 data URL} memoization cache
        self.image_cache = {}

    def convert_image_to_base64(self, file_name):
        """Return (and memoize) a ``data:image/png;base64,...`` URL for a file."""
        if file_name not in self.image_cache:
            self.image_cache[file_name] = {}
        if 'data_url' not in self.image_cache[file_name]:
            with open(file_name, 'rb') as f:
                self.image_cache[file_name]['data_url'] = 'data:image/png;base64,' + \
                    base64.b64encode(f.read()).decode('utf-8')
        assert self.image_cache[file_name]['data_url']
        return self.image_cache[file_name]['data_url']

    def get_session_state(self, session_id: str) -> ChatSessionState:
        """
        Retrieves the chat session state object for a given session ID.

        If the session ID does not exist in the currently managed session states,
        a new session state object is created and added to the list of managed sessions.

        Parameters:
            session_id: The unique identifier for the session.

        Returns:
            The session state object corresponding to the session ID.
        """
        if session_id not in self.chat_session_states:
            self.chat_session_states[session_id] = ChatSessionState(session_id)
        return self.chat_session_states[session_id]

    def get_message_truncate(self, session_id):
        """Return (initializing to 0 if needed) the retry-truncation counter."""
        if session_id not in self.message_truncate:
            self.message_truncate[session_id] = 0
        return self.message_truncate[session_id]

    def truncate_messages_adaptive(self, messages):
        """Drop leading history until the estimated length fits MAX_SEQ_LEN.

        Video blocks are shrunk two frames at a time (frames are paired);
        other blocks are dropped whole. A leading assistant message left
        without its user turn is also dropped. Mutates and returns ``messages``.
        """
        while True:
            if not messages:
                break
            seq_len = compute_seqlen_estimated(tokenizer, copy.deepcopy(
                messages), sample_strategy_func=lambda h, w: (h, w))['seq_len']
            if seq_len < MAX_SEQ_LEN:
                break
            # Remove the first block in content history:
            if len(messages[0]['content']) > 0 and 'video' in messages[0]['content'][0]:
                messages[0]['content'][0]['video'] = messages[0]['content'][0]['video'][2:]
                if len(messages[0]['content'][0]['video']) == 0:
                    messages[0]['content'] = messages[0]['content'][1:]
            else:
                messages[0]['content'] = messages[0]['content'][1:]

            # If the first block is empty, remove it:
            if len(messages[0]['content']) == 0:
                messages.pop(0)

            # If role is assistant, remove the first block in content history:
            if messages and messages[0]['role'] == 'assistant':
                messages.pop(0)
        return messages

    def truncate_messages_by_count(self, messages, cnt):
        """Drop ``cnt`` leading content blocks (same policy as adaptive
        truncation: video blocks shrink two frames per step). Mutates
        ``messages`` in place; stops early if the history runs out.
        """
        for i in range(cnt):
            if not messages:
                break
            # Remove the first block in content history:
            if len(messages[0]['content']) > 0 and 'video' in messages[0]['content'][0]:
                messages[0]['content'][0]['video'] = messages[0]['content'][0]['video'][2:]
                if len(messages[0]['content'][0]['video']) == 0:
                    messages[0]['content'] = messages[0]['content'][1:]
            else:
                messages[0]['content'] = messages[0]['content'][1:]

            # If the first block is empty, remove it:
            if len(messages[0]['content']) == 0:
                messages.pop(0)

            # If role is assistant, remove the first block in content history:
            if messages and messages[0]['role'] == 'assistant':
                messages.pop(0)

    def get_save_dir(self, session_id):
        """Return (creating on first use) this session's media temp directory."""
        if self.save_dir.get(session_id) is None:
            temp_dir = tempfile.mkdtemp(dir=self.capture_image_dir)
            self.save_dir[session_id] = temp_dir
        return self.save_dir[session_id]

    def get_messages(self, session_id):
        """Return (creating on first use) this session's message history."""
        if self.messages.get(session_id) is None:
            self.messages[session_id] = []
        return self.messages[session_id]

    def update_messages(self, session_id, role, content):
        """Append ``content`` to the session history, merging where possible.

        Consecutive same-role messages are merged; adjacent video frame
        lists are concatenated; adjacent text contents are concatenated.
        """
        if self.messages.get(session_id) is None:
            self.messages[session_id] = []
        messages = self.messages[session_id]
        if len(messages) == 0 or messages[-1]["role"] != role:
            messages.append({
                "role": role,
                "content": [content]
            })
        elif "video" in content and isinstance(content["video"], (list, tuple)) and "video" in messages[-1]["content"][-1] and isinstance(messages[-1]["content"][-1]["video"], (list, tuple)):
            messages[-1]["content"][-1]['video'].extend(content["video"])
        else:
            # If content and last message are all with type text, merge them
            if 'text' in messages[-1]["content"][-1] and 'text' in content:
                messages[-1]["content"][-1]['text'] += content["text"]
            else:
                messages[-1]["content"].append(content)
        self.messages[session_id] = messages

    def get_timestamp(self):
        """Return the current UNIX timestamp (seconds, float)."""
        return time.time()

    def chat(self, messages, request: gr.Request):
        """Generator yielding incremental response text for ``messages``.

        Runs ``model.generate`` in a background thread and streams the
        decoded tokens through a ``TextIteratorStreamer``.
        """
        messages = _transform_messages(messages)

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs,
                           videos=video_inputs, padding=True, return_tensors='pt')
        inputs = inputs.to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        for new_text in streamer:
            yield new_text

    def add_text(self, history, text, request: gr.Request):
        """Gradio callback for text submission.

        Flushes any cached streamed frames into the history as paired
        "video" frames (Qwen2-VL pairs frames temporally; an unpaired
        trailing frame is duplicated to keep the count even), then appends
        the user's text. Returns the updated chatbot history and clears
        the textbox.
        """
        session_id = request.session_hash
        session_state: ChatSessionState = self.get_session_state(
            request.session_hash)

        if len(session_state.image_cache) > 0:
            for i, (timestamp, image_path) in enumerate(session_state.image_cache):
                if i % 2 == 0:
                    content = {"video": [f"file://{image_path}"]}
                else:
                    content["video"].append(f"file://{image_path}")
                    self.update_messages(
                        session_id, role="user", content=content)
                if i == len(session_state.image_cache)-1 and i % 2 == 0:
                    # Odd total: duplicate the last frame to complete the pair.
                    content["video"].append(content["video"][-1])
                    self.update_messages(
                        session_id, role="user", content=content)

            session_state.image_cache.clear()

        self.update_messages(session_id, role="user", content={
            "type": "text", "text": text})

        history = history + [(text, None)]
        return history, ""

    def add_file(self, history, file, request: gr.Request):
        """Gradio callback for uploads and snapshots.

        Accepts either a base64 data URL (camera/screen snapshot) or an
        uploaded file object. Videos (.mp4/.mov) are decimated to JPEG
        frames with ffmpeg at 0.25 fps and added as "video" content; other
        files are added as "image" content.
        """
        session_id = request.session_hash
        session_state: ChatSessionState = self.get_session_state(session_id)
        if isinstance(file, str) and file.startswith('data:'):
            # get binary bytes
            data = base64.b64decode(file.split('base64,')[1])
            # Create a file name using uuid
            filename = f'{uuid.uuid4()}.jpg'
            save_dir = self.get_save_dir(session_id)
            savename = os.path.join(save_dir, filename)
            # Save the file
            with open(savename, 'wb') as f:
                f.write(data)
            self.update_messages(session_id, role="user", content={
                "image": f"file://{savename}"})
        else:
            filename = os.path.basename(file.name)
            save_dir = self.get_save_dir(session_id)
            savename = os.path.join(save_dir, filename)
            if file.name.endswith('.mp4') or file.name.endswith('.mov'):
                shutil.copy(file.name, savename)
                os.makedirs(file.name + '.frames', exist_ok=True)
                # Quote paths: Gradio temp paths may contain spaces.
                os.system(
                    f'ffmpeg -i "{file.name}" -vf "scale=320:-1" -r 0.25 "{file.name}.frames/%d.jpg"')
                file_index = 1
                frame_list = []
                while True:
                    if os.path.isfile(os.path.join(f'{file.name}.frames/{file_index}.jpg')):
                        frame_list.append(os.path.join(
                            f'file://{file.name}.frames/{file_index}.jpg'))
                        file_index += 1
                    else:
                        break
                # Keep an even frame count (frames are consumed in pairs).
                if len(frame_list) % 2 != 0:
                    frame_list = frame_list[1:]
                self.update_messages(session_id, role="user", content={
                    "video": frame_list})
            else:
                shutil.copy(file.name, savename)
                self.update_messages(session_id, role="user", content={
                    "image": f"file://{savename}"})

        history = history + [((savename,), None)]
        return history

    def add_image_to_streaming_cache(self, file, width, height, request: gr.Request):
        """Gradio callback for streamed frames: save one frame to disk and
        queue it (with a timestamp) in the session's image cache.

        ``width == -1`` keeps the original resolution; ``height == -1``
        preserves the aspect ratio for the given ``width``.
        """
        session_id = request.session_hash
        session_state: ChatSessionState = self.get_session_state(session_id)
        timestamp = self.get_timestamp()
        # If file is an image url starting with data:, save it to the session directory
        if isinstance(file, str) and file.startswith('data:'):
            # get binary bytes
            data = base64.b64decode(file.split('base64,')[1])
            width, height = int(width), int(height)
            # Load the image using PIL
            image = Image.open(io.BytesIO(data))
            # If width == -1, no need to scale the image
            if width == -1:
                pass
            else:
                # If height == -1, keep aspect ratio
                if height == -1:
                    height = round(width * image.height / float(image.width))
                image = image.resize((width, height), Image.LANCZOS)
            # Create a file name using uuid
            filename = f'{uuid.uuid4()}.jpg'
            save_dir = self.get_save_dir(session_id)
            savename = os.path.join(save_dir, filename)
            # Save the file
            image.save(savename, "JPEG")
        else:
            filename = os.path.basename(file.name)
            save_dir = self.get_save_dir(session_id)
            savename = os.path.join(save_dir, filename)
            shutil.copy(file.name, savename)

        session_state.image_cache.append((timestamp, savename))

    def response(self, chatbot_messages, request: gr.Request):
        """Gradio generator callback producing the streamed assistant reply.

        Truncates history to fit MAX_SEQ_LEN, converts file:// media URLs to
        base64 data URLs, streams the model output into both the session
        history and the chatbot widget, and on a context-length error retries
        with progressively more leading blocks truncated.
        """
        session_id = request.session_hash
        messages = self.get_messages(session_id)
        self.truncate_messages_adaptive(messages)
        messages = copy.deepcopy(messages)
        chatbot_messages = copy.deepcopy(chatbot_messages)
        if chatbot_messages is None:
            chatbot_messages = []
        truncate_count = 0
        while True:
            compiled_messages = copy.deepcopy(messages)
            self.truncate_messages_by_count(
                compiled_messages, cnt=truncate_count)
            # Convert file:// image urls to data:base64 urls
            for message in compiled_messages:
                for content in message['content']:
                    if 'image' in content:
                        if content['image'].startswith('file://'):
                            content['image'] = self.convert_image_to_base64(
                                content['image'][7:])
                    elif 'video' in content and isinstance(content['video'], (list, tuple)):
                        for frame_i in range(len(content['video'])):
                            if content['video'][frame_i].startswith('file://'):
                                content['video'][frame_i] = self.convert_image_to_base64(
                                    content['video'][frame_i][7:])
            rep = self.chat(compiled_messages, request=request)
            try:
                for content in rep:
                    if not content:
                        continue
                    self.update_messages(session_id, role="assistant", content={
                        "type": "text", "text": content})
                    if not chatbot_messages[-1][-1]:
                        chatbot_messages[-1][-1] = content
                    else:
                        chatbot_messages[-1][-1] += content
                    yield chatbot_messages
                break
            except openai.BadRequestError as e:
                print(e)
                if 'maximum context length' not in str(e):
                    raise e
                if self.messages[session_id][-1]['role'] == 'assistant':
                    chatbot_messages[-1][-1] = ''
                    self.messages[session_id] = self.messages[session_id][:-1]
                # Bug fix: also bump the local counter actually used for
                # truncation on retry (the original never advanced it, so
                # every retry resent the identical, still-too-long prompt),
                # and go through the initializing getter to avoid a KeyError.
                self.message_truncate[session_id] = self.get_message_truncate(
                    session_id) + 1
                truncate_count += 1


# Inline the recorder/main JS into the global bootstrap script served to
# the browser (Gradio's Blocks `js` parameter takes a single script).
recorder_js = pathlib.Path('recorder.js').read_text()
main_js = pathlib.Path('main.js').read_text()
GLOBAL_JS = pathlib.Path('global.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
    'let main_js = null;', main_js)


def main():
    """Build the Gradio Blocks UI, wire callbacks, and launch the server."""
    with gr.Blocks(js=GLOBAL_JS) as demo:
        gr.Markdown("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>"""
                    )
        gr.Markdown("""<center><font size=8>Qwen2-VL</center>""")
        gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-VL, developed by Alibaba Cloud.</center>""")
        gr.Markdown("""<center><font size=3>本WebUI基于Qwen2-VL。</center>""")
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Accordion("System Prompt", open=False):
                textbox_system_prompt = gr.Textbox(
                    value="You are a helpful assistant.", label="System Prompt")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Camera"):
                    image_camera = gr.Image(sources='webcam', label="Camera Preview",
                                            mirror_webcam=False, elem_id="gradio_image_camera_preview")
                    with gr.Accordion("Camera Settings", open=False):
                        with gr.Row():
                            camera_frame_interval = gr.Textbox(
                                "1", label="Frame interval or (1 / FPS)", elem_id="gradio_camera_frame_interval", interactive=True)
                        with gr.Row():
                            camera_width = gr.Textbox(
                                "640", label="Width (-1 = original resolution)")
                            camera_height = gr.Textbox(
                                "-1", label="Height (-1 = keep aspect ratio)")
                    with gr.Row():
                        button_camera_stream = gr.Button(
                            "Stream", elem_id="gradio_button_camera_stream")
                        button_camera_snapshot = gr.Button(
                            "Snapshot", elem_id="gradio_button_camera_snapshot")
                        # Hidden button clicked from JS to submit stream frames.
                        button_camera_stream_submit = gr.Button(
                            "Snapshot", elem_id="gradio_button_camera_stream_submit", visible=False)
                with gr.Tab("Screen"):
                    image_screen = gr.Image(
                        sources='webcam', label="Screen Preview", elem_id="gradio_image_screen_preview")
                    with gr.Accordion("Screen Settings", open=False):
                        with gr.Row():
                            screen_frame_interval = gr.Textbox(
                                "5", label="Frame interval or (1 / FPS)", elem_id="gradio_screen_frame_interval", interactive=True)
                        with gr.Row():
                            screen_width = gr.Textbox(
                                "-1", label="Width (-1 = original resolution)")
                            screen_height = gr.Textbox(
                                "-1", label="Height (-1 = keep aspect ratio)")
                    with gr.Row():
                        button_screen_stream = gr.Button(
                            "Stream", elem_id="gradio_button_screen_stream")
                        button_screen_snapshot = gr.Button(
                            "Snapshot", elem_id="gradio_button_screen_snapshot")
                        # Hidden button clicked from JS to submit stream frames.
                        button_screen_stream_submit = gr.Button(
                            "Snapshot", elem_id="gradio_button_screen_stream_submit", visible=False)

            with gr.Column(scale=2):
                chatbot = gr.Chatbot([], elem_id="chatofa", height=500)
                with gr.Row():
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press enter, or upload an image",
                        container=False,
                        scale=5,
                    )
                    btn = gr.UploadButton(
                        "📁", file_types=["image", "video", "audio"], scale=1)

                txt.submit(
                    fn=worker.add_text,
                    inputs=[chatbot, txt],
                    outputs=[chatbot, txt]
                ).then(
                    fn=worker.response,
                    inputs=[chatbot],
                    outputs=chatbot
                )

                btn.upload(
                    worker.add_file,
                    inputs=[chatbot, btn],
                    outputs=[chatbot]
                )

                # Camera: the js hooks replace the button payload with the
                # current frame grabbed in the browser.
                button_camera_snapshot.click(
                    worker.add_file,
                    inputs=[chatbot, button_camera_snapshot],
                    outputs=[chatbot],
                    js="(p1, p2) => [p1, window.getCameraFrame()]",
                )
                button_camera_stream_submit.click(
                    worker.add_image_to_streaming_cache,
                    inputs=[button_camera_stream_submit,
                            camera_width, camera_height],
                    outputs=[],
                    js="(p1, p2, p3) => [window.getCameraFrame(), p2, p3]",
                )
                button_camera_stream.click(
                    lambda x: None,
                    inputs=[button_camera_stream],
                    outputs=[],
                    js="(p1, p2) => (window.startCameraStreaming())"
                )

                # Screen
                button_screen_snapshot.click(
                    worker.add_file,
                    inputs=[chatbot, button_screen_snapshot],
                    outputs=[chatbot],
                    js="(p1, p2) => [p1, window.getScreenshotFrame()]",
                )
                button_screen_stream_submit.click(
                    worker.add_image_to_streaming_cache,
                    inputs=[button_screen_stream_submit,
                            screen_width, screen_height],
                    outputs=[],
                    js="(p1, p2, p3) => [window.getScreenshotFrame(), p2, p3]",
                )
                button_screen_stream.click(
                    lambda x: None,
                    inputs=[button_screen_stream],
                    outputs=[],
                    js="(p1, p2) => (window.startScreenStreaming())"
                )
        with gr.Row():
            gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
        demo.launch(
            share=args.share,
            inbrowser=args.inbrowser,
            server_port=args.server_port,
            server_name=args.server_name,
        )


if __name__ == '__main__':
    # NOTE: these are module-level globals read by Worker.chat/response
    # and main(); the script only works when run as __main__.
    worker = Worker()
    args = _get_args()
    model, processor = _load_model_processor(args)
    tokenizer = processor.tokenizer
    main()