# chat_workflow/workflows/simple_chat.py
  1  import chainlit as cl
  2  from chainlit.input_widget import Select
  3  from langgraph.graph import StateGraph
  4  from langchain_core.messages import SystemMessage
  5  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  6  from langchain_core.runnables import Runnable, RunnableConfig
  7  from .base import BaseWorkflow, BaseState
  8  from ..llm import llm_factory, ModelCapability
  9  from ..tools import BasicToolNode
 10  from ..tools.search import get_search_tools
 11  from ..tools.time import get_datetime_now
 12  
 13  
 14  class GraphState(BaseState):
 15      # Model name of the chatbot
 16      chat_model: str
 17  
 18  
 19  class SimpleChatWorkflow(BaseWorkflow):
 20      def __init__(self):
 21          super().__init__()
 22  
 23          self.capabilities = {
 24              ModelCapability.TEXT_TO_TEXT, ModelCapability.TOOL_CALLING}
 25          self.tools = [get_datetime_now] + get_search_tools()
 26  
 27      def create_graph(self) -> StateGraph:
 28          graph = StateGraph(GraphState)
 29          graph.add_node("chat", self.chat_node)
 30          graph.add_node("tools", BasicToolNode(self.tools))
 31  
 32          # TODO: create a router for using multiple tools
 33          graph.set_entry_point("chat")
 34          graph.add_conditional_edges("chat", self.tool_routing)
 35          graph.add_edge("tools", "chat")
 36          return graph
 37  
 38      async def chat_node(self, state: GraphState, config: RunnableConfig) -> GraphState:
 39          prompt = ChatPromptTemplate.from_messages([
 40              SystemMessage(content="You're a helpful assistant."),
 41              MessagesPlaceholder(variable_name="messages"),
 42          ])
 43          llm = llm_factory.create_model(
 44              self.output_chat_model, model=state["chat_model"], tools=self.tools)
 45          chain: Runnable = prompt | llm
 46          return {
 47              "messages": [await chain.ainvoke(state, config=config)]
 48          }
 49  
 50      def create_default_state(self) -> GraphState:
 51          return {
 52              "name": self.name(),
 53              "messages": [],
 54              "chat_model": "",
 55          }
 56  
 57      @classmethod
 58      def name(cls) -> str:
 59          return "Simple Chat"
 60  
 61      @property
 62      def output_chat_model(self) -> str:
 63          return "chat_model"
 64  
 65      @classmethod
 66      def chat_profile(cls) -> cl.ChatProfile:
 67          return cl.ChatProfile(
 68              name=cls.name(),
 69              markdown_description="A ChatGPT-like chatbot.",
 70              icon="https://cdn1.iconfinder.com/data/icons/3d-front-color/128/chat-text-front-color.png",
 71              default=True,
 72              starters=[
 73                  cl.Starter(
 74                      label="Write a snake game in Python.",
 75                      message="Write a snake game in Python.",
 76                      icon="https://cdn1.iconfinder.com/data/icons/photography-calendar-speaker-person-thinking-3d-il/128/13.png",
 77                  ),
 78                  cl.Starter(
 79                      label="What is the weather in San Francisco?",
 80                      message="What is the weather in San Francisco?",
 81                      icon="https://cdn0.iconfinder.com/data/icons/3d-dynamic-color/128/sun-dynamic-color.png",
 82                  ),
 83                  cl.Starter(
 84                      label="How do I make a peanut butter and jelly sandwich?",
 85                      message="How do I make a peanut butter and jelly sandwich?",
 86                      icon="https://cdn0.iconfinder.com/data/icons/fast-food-3d/128/Sandwich.png",
 87                  ),
 88              ],
 89          )
 90  
 91      @property
 92      def chat_settings(self) -> cl.ChatSettings:
 93          return cl.ChatSettings([
 94              Select(
 95                  id="chat_model",
 96                  label="Chat Model",
 97                  values=sorted(llm_factory.list_models(
 98                      capabilities=self.capabilities)),
 99                  initial_index=0,
100              ),
101          ])