base.py
  1  import operator
  2  import chainlit as cl
  3  from typing import TypedDict, Annotated, Sequence, Dict, Optional
  4  from langchain_core.messages import AnyMessage, HumanMessage
  5  from abc import ABC, abstractmethod
  6  from typing import Dict, Any
  7  from langgraph.graph import StateGraph, END
  8  
  9  
 10  class BaseState(TypedDict):
 11      # Message history
 12      messages: Annotated[Sequence[AnyMessage], operator.add]
 13  
 14      # Name of the workflow
 15      chat_profile: str
 16  
 17  
 18  class BaseWorkflow(ABC):
 19      @abstractmethod
 20      def create_graph(self) -> StateGraph:
 21          """
 22          Define the state graph of the workflow.
 23          """
 24  
 25      @abstractmethod
 26      def create_default_state(self) -> Dict[str, Any]:
 27          """
 28          Define the default state of the workflow.
 29          """
 30  
 31      @classmethod
 32      @abstractmethod
 33      def name(cls) -> str:
 34          pass
 35  
 36      @property
 37      @abstractmethod
 38      def output_chat_model(self) -> str:
 39          """
 40          The name of the chat model to display in the UI. 
 41          Normally, this is the name of the chat model that is 
 42          used to generate the final output.
 43          """
 44  
 45      @classmethod
 46      @abstractmethod
 47      def chat_profile(cls) -> cl.ChatProfile:
 48          """
 49          Chat profile to display in the UI. This is for providing
 50          an option in the list of available workflows to the user.
 51          """
 52  
 53      @property
 54      @abstractmethod
 55      def chat_settings(self) -> cl.ChatSettings:
 56          """
 57          Chatt settings to display in the UI. This is for providing
 58          customizable settings to the user.
 59          """
 60  
 61      def tool_routing(self, state: BaseState):
 62          """
 63          Use in the conditional_edge to route to the ToolNode if the last message
 64          has tool calls. Otherwise, route to the end.
 65          """
 66          if isinstance(state, list):
 67              ai_message = state[-1]
 68          elif messages := state.get("messages", []):
 69              ai_message = messages[-1]
 70          else:
 71              raise ValueError(
 72                  f"No messages found in input state to tool_edge: {state}")
 73          if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
 74              return "tools"
 75          return END
 76  
 77      async def get_chat_settings(self, state: Optional[BaseState] = None) -> cl.ChatSettings:
 78          """
 79          Get the chat settings for the workflow.
 80  
 81          Args:
 82              state (Optional[BaseState]): The state of the workflow. Used to resume a chat from previous session.
 83          """
 84          settings = self.chat_settings
 85          # Resume settings from previous session
 86          if state is not None:
 87              for widget in settings.inputs:
 88                  if widget.id in state:
 89                      if isinstance(widget, cl.input_widget.Select):
 90                          if widget.items:
 91                              if state[widget.id] in widget.items.values():
 92                                  widget.initial = state[widget.id]
 93                          elif widget.values:
 94                              if state[widget.id] in widget.values:
 95                                  widget.initial = state[widget.id]
 96                      elif isinstance(widget, cl.input_widget.Switch):
 97                          widget.initial = state[widget.id]
 98                      elif isinstance(widget, cl.input_widget.Slider):
 99                          if widget.min > state[widget.id]:
100                              widget.initial = widget.min
101                          elif widget.max < state[widget.id]:
102                              widget.initial = widget.max
103                          else:
104                              widget.initial = state[widget.id]
105                      elif isinstance(widget, cl.input_widget.TextInput):
106                          widget.initial = state[widget.id]
107                      elif isinstance(widget, cl.input_widget.NumberInput):
108                          widget.initial = state[widget.id]
109                      elif isinstance(widget, cl.input_widget.Tags):
110                          if widget.values:
111                              widget.initial = [
112                                  tag for tag in state[widget.id] if tag in widget.values]
113                          else:
114                              widget.initial = state[widget.id]
115          return await settings.send()
116  
117      def format_message(self, message: cl.Message) -> HumanMessage:
118          return HumanMessage(content=message.content)