# multimodal_chat.py
import base64

import chainlit as cl
from chainlit.input_widget import Select
from langgraph.graph import StateGraph
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableConfig

from .base import BaseWorkflow, BaseState
from ..llm import llm_factory, ModelCapability
from ..tools import BasicToolNode
from ..tools.search import get_search_tools
from ..tools.time import get_datetime_now


class GraphState(BaseState):
    # Model name of the chatbot (the Select value chosen in chat settings).
    chat_model: str


class MultimodalChatWorkflow(BaseWorkflow):
    """ChatGPT-like chat workflow with image input and tool calling.

    Builds a two-node LangGraph ("chat" <-> "tools"): the chat node invokes
    the configured LLM, and tool calls are routed to a ``BasicToolNode``
    whose results feed back into the chat node.
    """

    def __init__(self):
        super().__init__()

        # Capabilities used to filter which models appear in chat settings.
        self.capabilities = {
            ModelCapability.TEXT_TO_TEXT,
            ModelCapability.IMAGE_TO_TEXT,
            ModelCapability.TOOL_CALLING,
        }
        self.tools = [get_datetime_now] + get_search_tools()

    def create_graph(self) -> StateGraph:
        """Assemble the workflow graph: chat entry point, conditional tool loop."""
        graph = StateGraph(GraphState)
        graph.add_node("chat", self.chat_node)
        graph.add_node("tools", BasicToolNode(self.tools))

        # TODO: create a router for using multiple tools
        graph.set_entry_point("chat")
        # tool_routing (defined on the base class) decides chat -> tools vs. end.
        graph.add_conditional_edges("chat", self.tool_routing)
        graph.add_edge("tools", "chat")
        return graph

    async def chat_node(self, state: GraphState, config: RunnableConfig) -> GraphState:
        """Run one LLM turn over the accumulated messages.

        Returns a partial state update: the model's reply appended to
        ``messages`` (merged by the graph's state reducer).
        """
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="You're a helpful assistant."),
            MessagesPlaceholder(variable_name="messages"),
        ])
        # Model is created per-invocation so the user's settings choice
        # (state["chat_model"]) takes effect immediately.
        llm = llm_factory.create_model(
            self.output_chat_model, model=state["chat_model"], tools=self.tools)
        chain: Runnable = prompt | llm
        return {
            "messages": [await chain.ainvoke(state, config=config)]
        }

    def create_default_state(self) -> GraphState:
        """Initial graph state before any user interaction."""
        return {
            "name": self.name(),
            "messages": [],
            "chat_model": "",
        }

    @classmethod
    def name(cls) -> str:
        return "Multimodal Chat"

    @property
    def output_chat_model(self) -> str:
        # State key that identifies which model produced the output.
        return "chat_model"

    @classmethod
    def chat_profile(cls) -> cl.ChatProfile:
        """Chainlit profile shown in the UI, including starter prompts."""
        return cl.ChatProfile(
            name=cls.name(),
            markdown_description="A ChatGPT-like chatbot.",
            icon="https://cdn0.iconfinder.com/data/icons/essential-pack-1-3d/64/picture.png",
            starters=[
                cl.Starter(
                    label="Write a snake game in Python.",
                    message="Write a snake game in Python.",
                    icon="https://cdn1.iconfinder.com/data/icons/photography-calendar-speaker-person-thinking-3d-il/128/13.png",
                ),
                cl.Starter(
                    label="What is the weather in San Francisco?",
                    message="What is the weather in San Francisco?",
                    icon="https://cdn0.iconfinder.com/data/icons/3d-dynamic-color/128/sun-dynamic-color.png",
                ),
                cl.Starter(
                    label="How do I make a peanut butter and jelly sandwich?",
                    message="How do I make a peanut butter and jelly sandwich?",
                    icon="https://cdn0.iconfinder.com/data/icons/fast-food-3d/128/Sandwich.png",
                ),
            ],
        )

    @property
    def chat_settings(self) -> cl.ChatSettings:
        """Settings panel: one Select listing models with the required capabilities."""
        return cl.ChatSettings([
            Select(
                id="chat_model",
                label="Chat Model",
                values=sorted(llm_factory.list_models(
                    capabilities=self.capabilities)),
                initial_index=0,
            ),
        ])

    def format_message(self, msg: cl.Message) -> HumanMessage:
        """Format chainlit message to LangChain message with multimodal support"""
        if not msg.elements:
            return HumanMessage(content=msg.content)

        # Initialize the multimodal content list
        formatted_content = [{"type": "text", "text": msg.content}]

        # Process images: embed each attached image as a base64 data URL.
        images = [file for file in msg.elements if "image" in file.mime]
        for image in images:
            with open(image.path, "rb") as img_file:
                image_data = base64.b64encode(img_file.read()).decode("utf-8")
            # Fix: use the element's actual mime type rather than hardcoding
            # image/jpeg — PNG/WebP/GIF uploads were previously mislabeled.
            formatted_content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{image.mime};base64,{image_data}"
                }
            })

        return HumanMessage(content=formatted_content)