# chat_workflow/workflows/multimodal_chat.py
  1  import chainlit as cl
  2  import base64
  3  from chainlit.input_widget import Select
  4  from langgraph.graph import StateGraph
  5  from langchain_core.messages import SystemMessage, HumanMessage
  6  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  7  from langchain_core.runnables import Runnable, RunnableConfig
  8  from .base import BaseWorkflow, BaseState
  9  from ..llm import llm_factory, ModelCapability
 10  from ..tools import BasicToolNode
 11  from ..tools.search import get_search_tools
 12  from ..tools.time import get_datetime_now
 13  
 14  
class GraphState(BaseState):
    """Graph state for the multimodal chat workflow.

    Extends the shared BaseState (which carries the message history) with
    the user-selected chat model.
    """
    # Model name of the chatbot
    chat_model: str
 18  
 19  
 20  class MultimodalChatWorkflow(BaseWorkflow):
 21      def __init__(self):
 22          super().__init__()
 23  
 24          self.capabilities = {
 25              ModelCapability.TEXT_TO_TEXT, ModelCapability.IMAGE_TO_TEXT, ModelCapability.TOOL_CALLING}
 26          self.tools = [get_datetime_now] + get_search_tools()
 27  
 28      def create_graph(self) -> StateGraph:
 29          graph = StateGraph(GraphState)
 30          graph.add_node("chat", self.chat_node)
 31          graph.add_node("tools", BasicToolNode(self.tools))
 32  
 33          # TODO: create a router for using multiple tools
 34          graph.set_entry_point("chat")
 35          graph.add_conditional_edges("chat", self.tool_routing)
 36          graph.add_edge("tools", "chat")
 37          return graph
 38  
 39      async def chat_node(self, state: GraphState, config: RunnableConfig) -> GraphState:
 40          prompt = ChatPromptTemplate.from_messages([
 41              SystemMessage(content="You're a helpful assistant."),
 42              MessagesPlaceholder(variable_name="messages"),
 43          ])
 44          llm = llm_factory.create_model(
 45              self.output_chat_model, model=state["chat_model"], tools=self.tools)
 46          chain: Runnable = prompt | llm
 47          return {
 48              "messages": [await chain.ainvoke(state, config=config)]
 49          }
 50  
 51      def create_default_state(self) -> GraphState:
 52          return {
 53              "name": self.name(),
 54              "messages": [],
 55              "chat_model": "",
 56          }
 57  
 58      @classmethod
 59      def name(cls) -> str:
 60          return "Multimodal Chat"
 61  
 62      @property
 63      def output_chat_model(self) -> str:
 64          return "chat_model"
 65  
 66      @classmethod
 67      def chat_profile(cls) -> cl.ChatProfile:
 68          return cl.ChatProfile(
 69              name=cls.name(),
 70              markdown_description="A ChatGPT-like chatbot.",
 71              icon="https://cdn0.iconfinder.com/data/icons/essential-pack-1-3d/64/picture.png",
 72              starters=[
 73                  cl.Starter(
 74                      label="Write a snake game in Python.",
 75                      message="Write a snake game in Python.",
 76                      icon="https://cdn1.iconfinder.com/data/icons/photography-calendar-speaker-person-thinking-3d-il/128/13.png",
 77                  ),
 78                  cl.Starter(
 79                      label="What is the weather in San Francisco?",
 80                      message="What is the weather in San Francisco?",
 81                      icon="https://cdn0.iconfinder.com/data/icons/3d-dynamic-color/128/sun-dynamic-color.png",
 82                  ),
 83                  cl.Starter(
 84                      label="How do I make a peanut butter and jelly sandwich?",
 85                      message="How do I make a peanut butter and jelly sandwich?",
 86                      icon="https://cdn0.iconfinder.com/data/icons/fast-food-3d/128/Sandwich.png",
 87                  ),
 88              ],
 89          )
 90  
 91      @property
 92      def chat_settings(self) -> cl.ChatSettings:
 93          return cl.ChatSettings([
 94              Select(
 95                  id="chat_model",
 96                  label="Chat Model",
 97                  values=sorted(llm_factory.list_models(
 98                      capabilities=self.capabilities)),
 99                  initial_index=0,
100              ),
101          ])
102  
103      def format_message(self, msg: cl.Message) -> HumanMessage:
104          """Format chainlit message to LangChain message with multimodal support"""
105          if not msg.elements:
106              return HumanMessage(content=msg.content)
107  
108          # Initialize the multimodal content list
109          formatted_content = [{"type": "text", "text": msg.content}]
110  
111          # Process images
112          images = [file for file in msg.elements if "image" in file.mime]
113          for image in images:
114              with open(image.path, "rb") as img_file:
115                  image_data = base64.b64encode(img_file.read()).decode("utf-8")
116                  formatted_content.append({
117                      "type": "image_url",
118                      "image_url": {
119                          "url": f"data:image/jpeg;base64,{image_data}"
120                      }
121                  })
122  
123          return HumanMessage(content=formatted_content)