# web_demo_mm.py
  1  # Copyright (c) Alibaba Cloud.
  2  #
  3  # This source code is licensed under the license found in the
  4  # LICENSE file in the root directory of this source tree.
  5  
  6  import copy
  7  import re
  8  from argparse import ArgumentParser
  9  from threading import Thread
 10  
 11  import gradio as gr
 12  import torch
 13  from qwen_vl_utils import process_vision_info
 14  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer
 15  
 16  DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'
 17  
 18  
 19  def _get_args():
 20      parser = ArgumentParser()
 21  
 22      parser.add_argument('-c',
 23                          '--checkpoint-path',
 24                          type=str,
 25                          default=DEFAULT_CKPT_PATH,
 26                          help='Checkpoint name or path, default to %(default)r')
 27      parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
 28  
 29      parser.add_argument('--flash-attn2',
 30                          action='store_true',
 31                          default=False,
 32                          help='Enable flash_attention_2 when loading the model.')
 33      parser.add_argument('--share',
 34                          action='store_true',
 35                          default=False,
 36                          help='Create a publicly shareable link for the interface.')
 37      parser.add_argument('--inbrowser',
 38                          action='store_true',
 39                          default=False,
 40                          help='Automatically launch the interface in a new tab on the default browser.')
 41      parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
 42      parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
 43  
 44      args = parser.parse_args()
 45      return args
 46  
 47  
 48  def _load_model_processor(args):
 49      if args.cpu_only:
 50          device_map = 'cpu'
 51      else:
 52          device_map = 'auto'
 53  
 54      # Check if flash-attn2 flag is enabled and load model accordingly
 55      if args.flash_attn2:
 56          model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path,
 57                                                                  torch_dtype='auto',
 58                                                                  attn_implementation='flash_attention_2',
 59                                                                  device_map=device_map)
 60      else:
 61          model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map)
 62  
 63      processor = AutoProcessor.from_pretrained(args.checkpoint_path)
 64      return model, processor
 65  
 66  
 67  def _parse_text(text):
 68      lines = text.split('\n')
 69      lines = [line for line in lines if line != '']
 70      count = 0
 71      for i, line in enumerate(lines):
 72          if '```' in line:
 73              count += 1
 74              items = line.split('`')
 75              if count % 2 == 1:
 76                  lines[i] = f'<pre><code class="language-{items[-1]}">'
 77              else:
 78                  lines[i] = '<br></code></pre>'
 79          else:
 80              if i > 0:
 81                  if count % 2 == 1:
 82                      line = line.replace('`', r'\`')
 83                      line = line.replace('<', '&lt;')
 84                      line = line.replace('>', '&gt;')
 85                      line = line.replace(' ', '&nbsp;')
 86                      line = line.replace('*', '&ast;')
 87                      line = line.replace('_', '&lowbar;')
 88                      line = line.replace('-', '&#45;')
 89                      line = line.replace('.', '&#46;')
 90                      line = line.replace('!', '&#33;')
 91                      line = line.replace('(', '&#40;')
 92                      line = line.replace(')', '&#41;')
 93                      line = line.replace('$', '&#36;')
 94                  lines[i] = '<br>' + line
 95      text = ''.join(lines)
 96      return text
 97  
 98  
 99  def _remove_image_special(text):
100      text = text.replace('<ref>', '').replace('</ref>', '')
101      return re.sub(r'<box>.*?(</box>|$)', '', text)
102  
103  
104  def _is_video_file(filename):
105      video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
106      return any(filename.lower().endswith(ext) for ext in video_extensions)
107  
108  
109  def _gc():
110      import gc
111      gc.collect()
112      if torch.cuda.is_available():
113          torch.cuda.empty_cache()
114  
115  
116  def _transform_messages(original_messages):
117      transformed_messages = []
118      for message in original_messages:
119          new_content = []
120          for item in message['content']:
121              if 'image' in item:
122                  new_item = {'type': 'image', 'image': item['image']}
123              elif 'text' in item:
124                  new_item = {'type': 'text', 'text': item['text']}
125              elif 'video' in item:
126                  new_item = {'type': 'video', 'video': item['video']}
127              else:
128                  continue
129              new_content.append(new_item)
130  
131          new_message = {'role': message['role'], 'content': new_content}
132          transformed_messages.append(new_message)
133  
134      return transformed_messages
135  
136  
def _launch_demo(args, model, processor):
    """Build the Gradio chat UI around the loaded model/processor and launch it.

    Args:
        args: parsed CLI namespace; ``share``, ``inbrowser``, ``server_port``
            and ``server_name`` are forwarded to ``demo.launch``.
        model: loaded Qwen2-VL model (used by the nested closures via
            ``nonlocal``).
        processor: the matching processor/tokenizer bundle.
    """

    def call_local_model(model, processor, messages):
        # Generator: runs model.generate in a background thread and yields
        # the full text accumulated so far after each streamed chunk.

        messages = _transform_messages(messages)

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
        inputs = inputs.to(model.device)

        tokenizer = processor.tokenizer
        # skip_prompt=True so only newly generated tokens are streamed back.
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}

        # Generation runs in a worker thread; this generator consumes the
        # streamer on the caller's thread.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ''
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    def create_predict_fn():
        # Factory so `predict` can close over model/processor via nonlocal.

        def predict(_chatbot, task_history):
            # Stream a reply for the latest user turn, updating both the
            # rendered chatbot list and the raw task_history in place.
            nonlocal model, processor
            chat_query = _chatbot[-1][0]
            query = task_history[-1][0]
            if len(chat_query) == 0:
                # Empty query: drop the placeholder turn and bail out.
                _chatbot.pop()
                task_history.pop()
                return _chatbot
            print('User: ' + _parse_text(query))
            history_cp = copy.deepcopy(task_history)
            full_response = ''
            messages = []
            content = []
            # Rebuild the chat-ML message list: file turns (tuple/list) are
            # buffered into `content` until the next text turn flushes them
            # as one user message plus the recorded assistant reply.
            for q, a in history_cp:
                if isinstance(q, (tuple, list)):
                    if _is_video_file(q[0]):
                        content.append({'video': f'file://{q[0]}'})
                    else:
                        content.append({'image': f'file://{q[0]}'})
                else:
                    content.append({'text': q})
                    messages.append({'role': 'user', 'content': content})
                    messages.append({'role': 'assistant', 'content': [{'text': a}]})
                    content = []
            # The last appended assistant message belongs to the still-pending
            # turn (its answer is None) — remove it before generation.
            messages.pop()

            for response in call_local_model(model, processor, messages):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))

                yield _chatbot
                full_response = _parse_text(response)

            task_history[-1] = (query, full_response)
            print('Qwen-VL-Chat: ' + _parse_text(full_response))
            yield _chatbot

        return predict

    def create_regenerate_fn():
        # Factory so `regenerate` can close over model/processor via nonlocal.

        def regenerate(_chatbot, task_history):
            # Redo the last answered turn by clearing its answer and
            # re-running `predict`.
            nonlocal model, processor
            if not task_history:
                return _chatbot
            item = task_history[-1]
            if item[1] is None:
                # Last turn has no answer yet — nothing to regenerate.
                return _chatbot
            task_history[-1] = (item[0], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                # Last chatbot row was answer-only; clear the answer on the
                # preceding row instead.
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
            # `predict` (bound below at function scope) streams the new reply.
            _chatbot_gen = predict(_chatbot, task_history)
            for _chatbot in _chatbot_gen:
                yield _chatbot

        return regenerate

    predict = create_predict_fn()
    regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
        # Append the raw text to task_history and its HTML form to the
        # chatbot; returns '' to clear the input box.
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ''

    def add_file(history, task_history, file):
        # `file` comes from gr.UploadButton; file.name is the temp-file path.
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        # Clear the query textbox.
        return gr.update(value='')

    def reset_state(_chatbot, task_history):
        # Wipe both histories and reclaim memory.
        task_history.clear()
        _chatbot.clear()
        _gc()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>"""
                   )
        gr.Markdown("""<center><font size=8>Qwen2-VL</center>""")
        gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-VL, developed by Alibaba Cloud.</center>""")
        gr.Markdown("""<center><font size=3>本WebUI基于Qwen2-VL。</center>""")

        chatbot = gr.Chatbot(label='Qwen2-VL', elem_classes='control-height', height=500)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            addfile_btn = gr.UploadButton('📁 Upload (上传文件)', file_types=['image', 'video'])
            submit_btn = gr.Button('🚀 Submit (发送)')
            regen_btn = gr.Button('🤔️ Regenerate (重试)')
            empty_bin = gr.Button('🧹 Clear History (清除历史)')

        # Submit: append the turn, then stream the model reply into the chatbot.
        submit_btn.click(add_text, [chatbot, task_history, query],
                         [chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
        # Second click handler clears the textbox in parallel.
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
288  
289  
def main():
    """CLI entry point: parse arguments, load the model stack, serve the demo."""
    cli_args = _get_args()
    loaded_model, loaded_processor = _load_model_processor(cli_args)
    _launch_demo(cli_args, loaded_model, loaded_processor)
294  
295  
# Run the demo only when executed as a script (not when imported).
if __name__ == '__main__':
    main()