/ chat_workflow / state_serializer.py
state_serializer.py
 1  import json
 2  from typing import Type
 3  from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
 4  from chat_workflow.workflows.base import BaseState
 5  
 6  
 7  class StateSerializer:
 8      """
 9      A utility class for serializing and deserializing BaseState objects.
10  
11      This class provides methods to convert BaseState objects to and from JSON strings,
12      making them suitable for storage in databases or transmission over networks.
13      It handles special cases for message objects and provides a generic approach
14      for other state attributes.
15  
16      Class Methods:
17          serialize(state: BaseState) -> str:
18              Converts a BaseState object to a JSON string.
19  
20          deserialize(serialized_state: str, state_class: Type[BaseState]) -> BaseState:
21              Converts a JSON string back to a BaseState object of the specified class.
22      """
23  
24      @classmethod
25      def serialize(cls, state: BaseState) -> str:
26          serializable_state = state.copy()
27          serializable_state["messages"] = cls._serialize_messages(
28              state["messages"])
29          return json.dumps(serializable_state, default=cls._json_serializer)
30  
31      @classmethod
32      def deserialize(cls, serialized_state: str, state_class: Type[BaseState]) -> BaseState:
33          state_dict = json.loads(serialized_state)
34          state_dict['messages'] = cls._deserialize_messages(
35              state_dict['messages'])
36          return state_class(**cls._json_deserializer(state_dict))
37  
38      @staticmethod
39      def _serialize_messages(messages):
40          return [message.model_dump() for message in messages]
41  
42      @staticmethod
43      def _deserialize_messages(serialized_messages):
44          """
45          Deserialize a list of message dictionaries into their respective BaseMessage subclasses.
46  
47          This method takes a list of serialized message dictionaries and converts them back
48          into instances of the appropriate BaseMessage subclasses (HumanMessage, AIMessage,
49          ToolMessage, SystemMessage) based on the 'type' field in each dictionary.
50  
51          Args:
52              serialized_messages (List[Dict]): A list of dictionaries representing serialized messages.
53                  Each dictionary should contain a 'type' field and other relevant message data.
54  
55          Returns:
56              List[BaseMessage]: A list of deserialized BaseMessage subclass instances.
57  
58          Example:
59              serialized_messages = [
60                  {"type": "human", "content": "Hello"},
61                  {"type": "ai", "content": "Hi there!"},
62                  {"type": "tool", "content": "Processing...", "tool_call_id": "123"}
63              ]
64              deserialized_messages = StateSerializer._deserialize_messages(serialized_messages)
65          """
66          message_type_mapping = {
67              "human": HumanMessage,
68              "ai": AIMessage,
69              "tool": ToolMessage,
70              "system": SystemMessage,
71          }
72  
73          deserialized_messages = []
74          for msg_dict in serialized_messages:
75              msg_type = msg_dict.get("type")
76              msg_class = message_type_mapping.get(msg_type, BaseMessage)
77              deserialized_messages.append(msg_class.model_validate(msg_dict))
78          return deserialized_messages
79  
80      @staticmethod
81      def _json_serializer(obj):
82          if hasattr(obj, '__dict__'):
83              return obj.__dict__
84          return str(obj)
85  
86      @staticmethod
87      def _json_deserializer(dct):
88          for key, value in dct.items():
89              if isinstance(value, str):
90                  try:
91                      dct[key] = eval(value)
92                  except:
93                      pass
94          return dct