base.py
import operator
from abc import ABC, abstractmethod
from typing import Annotated, Any, Dict, Optional, Sequence, TypedDict

import chainlit as cl
from langchain_core.messages import AnyMessage, HumanMessage
from langgraph.graph import StateGraph, END


class BaseState(TypedDict):
    """Minimal graph state shared by every workflow."""

    # Message history. The operator.add annotation is the reducer, so
    # message updates are concatenated onto the existing sequence.
    messages: Annotated[Sequence[AnyMessage], operator.add]

    # Name of the workflow
    chat_profile: str


def _restore_widget_initial(widget: Any, saved: Any) -> None:
    """
    Set ``widget.initial`` from a value saved in a previous session.

    The saved value is only applied when it is still acceptable for the
    widget: selections must still be offered, slider values are clamped to
    the widget range, and tags are filtered against the allowed values.
    Widgets of an unrecognised type are left untouched.

    Args:
        widget: One of the input widgets of a ``cl.ChatSettings``.
        saved: The value stored in the state under ``widget.id``.
    """
    if isinstance(widget, cl.input_widget.Select):
        # Keep the widget's own default if the saved option no longer exists.
        if widget.items:
            if saved in widget.items.values():
                widget.initial = saved
        elif widget.values:
            if saved in widget.values:
                widget.initial = saved
    elif isinstance(widget, cl.input_widget.Switch):
        widget.initial = saved
    elif isinstance(widget, cl.input_widget.Slider):
        # Clamp into [min, max] in case the range changed since last session.
        if widget.min > saved:
            widget.initial = widget.min
        elif widget.max < saved:
            widget.initial = widget.max
        else:
            widget.initial = saved
    elif isinstance(
        widget, (cl.input_widget.TextInput, cl.input_widget.NumberInput)
    ):
        widget.initial = saved
    elif isinstance(widget, cl.input_widget.Tags):
        if widget.values:
            # Drop saved tags that are no longer among the allowed values.
            widget.initial = [tag for tag in saved if tag in widget.values]
        else:
            widget.initial = saved


class BaseWorkflow(ABC):
    """
    Abstract base class for workflows.

    Subclasses provide the state graph, the default state, and the Chainlit
    chat profile / chat settings used to present the workflow in the UI.
    """

    @abstractmethod
    def create_graph(self) -> StateGraph:
        """
        Define the state graph of the workflow.
        """

    @abstractmethod
    def create_default_state(self) -> Dict[str, Any]:
        """
        Define the default state of the workflow.
        """

    @classmethod
    @abstractmethod
    def name(cls) -> str:
        """
        The name of the workflow.
        """

    @property
    @abstractmethod
    def output_chat_model(self) -> str:
        """
        The name of the chat model to display in the UI.
        Normally, this is the name of the chat model that is
        used to generate the final output.
        """

    @classmethod
    @abstractmethod
    def chat_profile(cls) -> cl.ChatProfile:
        """
        Chat profile to display in the UI. This is for providing
        an option in the list of available workflows to the user.
        """

    @property
    @abstractmethod
    def chat_settings(self) -> cl.ChatSettings:
        """
        Chat settings to display in the UI. This is for providing
        customizable settings to the user.
        """

    def tool_routing(self, state: BaseState) -> str:
        """
        Use in the conditional_edge to route to the ToolNode if the last message
        has tool calls. Otherwise, route to the end.

        Args:
            state: Either the graph state (a mapping with a "messages" entry)
                or a plain list of messages.

        Returns:
            "tools" when the last message carries at least one tool call,
            otherwise END.

        Raises:
            ValueError: If the state mapping contains no messages.
        """
        if isinstance(state, list):
            ai_message = state[-1]
        elif messages := state.get("messages", []):
            ai_message = messages[-1]
        else:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}")
        # getattr + truthiness also covers a message whose tool_calls
        # attribute exists but is None, which len() would choke on.
        if getattr(ai_message, "tool_calls", None):
            return "tools"
        return END

    async def get_chat_settings(self, state: Optional[BaseState] = None) -> cl.ChatSettings:
        """
        Get the chat settings for the workflow.

        Args:
            state (Optional[BaseState]): The state of the workflow. Used to resume a chat from previous session.

        Returns:
            The result of sending the settings to the UI.
        """
        settings = self.chat_settings
        # Resume settings from previous session
        if state is not None:
            for widget in settings.inputs:
                if widget.id in state:
                    _restore_widget_initial(widget, state[widget.id])
        return await settings.send()

    def format_message(self, message: cl.Message) -> HumanMessage:
        """
        Convert a Chainlit message into a HumanMessage with the same content.
        """
        return HumanMessage(content=message.content)