# web_demo_mm.py
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Gradio web demo for multimodal chat (text / image / video) with Qwen2-VL."""

import copy
import re
from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer

DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'

# HTML/markdown escapes applied to lines inside a fenced code block so the
# browser renders the characters literally instead of interpreting them.
# NOTE: the previous version had lost these entities (the replaces were
# no-ops like `.replace('<', '<')`), so fenced code was not escaped at all.
_CODE_ESCAPES = (
    ('`', r'\`'),
    ('<', '&lt;'),
    ('>', '&gt;'),
    (' ', '&nbsp;'),
    ('*', '&ast;'),
    ('_', '&lowbar;'),
    ('-', '&#45;'),
    ('.', '&#46;'),
    ('!', '&#33;'),
    ('(', '&#40;'),
    (')', '&#41;'),
    ('$', '&#36;'),
)


def _get_args():
    """Parse and return the command-line arguments for the demo server."""
    parser = ArgumentParser()

    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')

    return parser.parse_args()


def _load_model_processor(args):
    """Load the Qwen2-VL model and its processor.

    Args:
        args: parsed CLI namespace (uses ``checkpoint_path``, ``cpu_only``,
            ``flash_attn2``).

    Returns:
        Tuple of ``(model, processor)``.
    """
    device_map = 'cpu' if args.cpu_only else 'auto'

    # flash_attention_2 needs a compatible GPU and the flash-attn package,
    # so it is only requested when the user explicitly opted in.
    if args.flash_attn2:
        model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path,
                                                                torch_dtype='auto',
                                                                attn_implementation='flash_attention_2',
                                                                device_map=device_map)
    else:
        model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map)

    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return model, processor


def _parse_text(text):
    """Convert model output into HTML for the Gradio chatbot widget.

    Triple-backtick fences become ``<pre><code>`` blocks; lines inside a
    fence are HTML-escaped (see ``_CODE_ESCAPES``) so the code renders
    verbatim. Remaining lines are joined with ``<br>`` separators.
    """
    lines = [line for line in text.split('\n') if line != '']
    fence_count = 0  # odd while we are inside an open code fence
    for i, line in enumerate(lines):
        if '```' in line:
            fence_count += 1
            items = line.split('`')
            if fence_count % 2 == 1:
                # Opening fence: the trailing token is the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if fence_count % 2 == 1:
                    # Inside a code fence: escape characters that HTML or
                    # markdown would otherwise interpret.
                    for src, dst in _CODE_ESCAPES:
                        line = line.replace(src, dst)
                lines[i] = '<br>' + line
    return ''.join(lines)


def _remove_image_special(text):
    """Strip Qwen-VL grounding markup (<ref>/<box>) from a response string."""
    text = text.replace('<ref>', '').replace('</ref>', '')
    # A <box> may still be streaming (unterminated), hence the `|$` alternative.
    return re.sub(r'<box>.*?(</box>|$)', '', text)


def _is_video_file(filename):
    """Return True if `filename` has a recognized video extension."""
    video_extensions = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')
    return filename.lower().endswith(video_extensions)


def _gc():
    """Force a garbage-collection pass and release cached CUDA memory."""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _transform_messages(original_messages):
    """Normalize chat history items into the typed message format expected by
    ``processor.apply_chat_template`` / ``process_vision_info``.

    Each content item becomes ``{'type': <kind>, <kind>: <value>}``; items
    with no recognized key are dropped.
    """
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_item = {'type': 'image', 'image': item['image']}
            elif 'text' in item:
                new_item = {'type': 'text', 'text': item['text']}
            elif 'video' in item:
                new_item = {'type': 'video', 'video': item['video']}
            else:
                continue
            new_content.append(new_item)

        transformed_messages.append({'role': message['role'], 'content': new_content})

    return transformed_messages


def _launch_demo(args, model, processor):
    """Build and launch the Gradio Blocks UI around `model`/`processor`."""

    def call_local_model(model, processor, messages):
        """Stream a generation for `messages`, yielding the cumulative text."""
        messages = _transform_messages(messages)

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
        inputs = inputs.to(model.device)

        tokenizer = processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}

        # generate() blocks, so it runs in a thread while we consume the streamer.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ''
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    def create_predict_fn():

        def predict(_chatbot, task_history):
            """Generate a streamed answer for the latest user turn."""
            nonlocal model, processor
            chat_query = _chatbot[-1][0]
            query = task_history[-1][0]
            if len(chat_query) == 0:
                # Empty submission: drop the placeholder turn and bail out.
                _chatbot.pop()
                task_history.pop()
                return _chatbot
            print('User: ' + _parse_text(query))
            history_cp = copy.deepcopy(task_history)
            full_response = ''
            messages = []
            content = []
            # Rebuild the full multimodal message list from the task history.
            # A tuple/list entry is an uploaded file path; a str is plain text.
            for q, a in history_cp:
                if isinstance(q, (tuple, list)):
                    if _is_video_file(q[0]):
                        content.append({'video': f'file://{q[0]}'})
                    else:
                        content.append({'image': f'file://{q[0]}'})
                else:
                    content.append({'text': q})
                    messages.append({'role': 'user', 'content': content})
                    messages.append({'role': 'assistant', 'content': [{'text': a}]})
                    content = []
            # The last assistant slot is the pending answer; remove it.
            messages.pop()

            for response in call_local_model(model, processor, messages):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))

                yield _chatbot
                full_response = _parse_text(response)

            task_history[-1] = (query, full_response)
            print('Qwen-VL-Chat: ' + _parse_text(full_response))
            yield _chatbot

        return predict

    def create_regenerate_fn():

        def regenerate(_chatbot, task_history):
            """Discard the last answer and re-run prediction for that turn."""
            nonlocal model, processor
            if not task_history:
                return _chatbot
            item = task_history[-1]
            if item[1] is None:
                # Nothing has been answered yet; nothing to regenerate.
                return _chatbot
            task_history[-1] = (item[0], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                # The popped row was a file-only row; clear the previous answer.
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
            _chatbot_gen = predict(_chatbot, task_history)
            for _chatbot in _chatbot_gen:
                yield _chatbot

        return regenerate

    predict = create_predict_fn()
    regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
        """Append a text turn to both the display history and the task history."""
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ''

    def add_file(history, task_history, file):
        """Append an uploaded image/video (stored as a 1-tuple path) to history."""
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value='')

    def reset_state(_chatbot, task_history):
        """Clear both histories and free GPU memory held by past generations."""
        task_history.clear()
        _chatbot.clear()
        _gc()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>""")
        gr.Markdown("""<center><font size=8>Qwen2-VL</center>""")
        gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-VL, developed by Alibaba Cloud.</center>""")
        gr.Markdown("""<center><font size=3>本WebUI基于Qwen2-VL。</center>""")

        chatbot = gr.Chatbot(label='Qwen2-VL', elem_classes='control-height', height=500)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            addfile_btn = gr.UploadButton('📁 Upload (上传文件)', file_types=['image', 'video'])
            submit_btn = gr.Button('🚀 Submit (发送)')
            regen_btn = gr.Button('🤔️ Regenerate (重试)')
            empty_bin = gr.Button('🧹 Clear History (清除历史)')

        submit_btn.click(add_text, [chatbot, task_history, query],
                         [chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot],
                                                       show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history],
                           show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    """Entry point: parse args, load the model, and launch the web UI."""
    args = _get_args()
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)


if __name__ == '__main__':
    main()