# chat_workflow/state_serializer.py
import json
from typing import Any, Type

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from chat_workflow.workflows.base import BaseState

# JSON has no set type, so sets are written as single-key tagged objects,
# e.g. {"__set__": [1, 2]}, and rebuilt by StateSerializer._json_deserializer.
_SET_TAG = "__set__"
_FROZENSET_TAG = "__frozenset__"
_COLLECTION_TAGS = {_SET_TAG: set, _FROZENSET_TAG: frozenset}


def _to_hashable(value: Any) -> Any:
    """Recursively convert JSON arrays back into tuples.

    Set members must be hashable and ``json`` writes tuples as arrays, so an
    array found inside an encoded set is turned back into a tuple.
    """
    if isinstance(value, list):
        return tuple(_to_hashable(item) for item in value)
    return value


class StateSerializer:
    """
    A utility class for serializing and deserializing BaseState objects.

    This class provides methods to convert BaseState objects to and from JSON strings,
    making them suitable for storage in databases or transmission over networks.
    Messages under the "messages" key are dumped with ``model_dump()`` and rebuilt
    into the matching BaseMessage subclass. Sets and frozensets are stored as tagged
    JSON objects so they round-trip. Other non-JSON values are written as their
    ``__dict__`` or, failing that, ``str()`` and come back as a dict or string.

    Class Methods:
        serialize(state: BaseState) -> str:
            Converts a BaseState object to a JSON string.

        deserialize(serialized_state: str, state_class: Type[BaseState]) -> BaseState:
            Converts a JSON string back to a BaseState object of the specified class.
    """

    @classmethod
    def serialize(cls, state: BaseState) -> str:
        """Convert *state* to a JSON string.

        Args:
            state: The state mapping to serialize. It is shallow-copied, so the
                caller's object is not modified.

        Returns:
            A JSON string. Message objects under "messages" (if present) are
            stored as their ``model_dump()`` dictionaries.
        """
        serializable_state = state.copy()
        # A state without a "messages" key is still serializable.
        if "messages" in serializable_state:
            serializable_state["messages"] = cls._serialize_messages(state["messages"])
        return json.dumps(serializable_state, default=cls._json_serializer)

    @classmethod
    def deserialize(cls, serialized_state: str, state_class: Type[BaseState]) -> BaseState:
        """Rebuild a state object from a string produced by :meth:`serialize`.

        Args:
            serialized_state: The JSON string to decode.
            state_class: The state type to construct; it is called with the
                decoded top-level keys as keyword arguments.

        Returns:
            An instance of ``state_class``. Tagged sets are restored as
            ``set``/``frozenset``; every other value keeps its JSON type.
            String values are never evaluated.

        Raises:
            json.JSONDecodeError: If ``serialized_state`` is not valid JSON.
        """
        # object_hook decodes tagged sets at every nesting level while parsing.
        state_dict = json.loads(serialized_state, object_hook=cls._json_deserializer)
        if "messages" in state_dict:
            state_dict["messages"] = cls._deserialize_messages(state_dict["messages"])
        return state_class(**state_dict)

    @staticmethod
    def _serialize_messages(messages):
        """Return a list of ``model_dump()`` dictionaries, one per message."""
        return [message.model_dump() for message in messages]

    @staticmethod
    def _deserialize_messages(serialized_messages):
        """
        Deserialize a list of message dictionaries into their respective BaseMessage subclasses.

        This method takes a list of serialized message dictionaries and converts them back
        into instances of the appropriate BaseMessage subclasses (HumanMessage, AIMessage,
        ToolMessage, SystemMessage) based on the 'type' field in each dictionary. Any other
        or missing 'type' falls back to BaseMessage.

        Args:
            serialized_messages (List[Dict]): A list of dictionaries representing serialized messages.
                Each dictionary should contain a 'type' field and other relevant message data.

        Returns:
            List[BaseMessage]: A list of deserialized BaseMessage subclass instances.

        Example:
            serialized_messages = [
                {"type": "human", "content": "Hello"},
                {"type": "ai", "content": "Hi there!"},
                {"type": "tool", "content": "Processing...", "tool_call_id": "123"}
            ]
            deserialized_messages = StateSerializer._deserialize_messages(serialized_messages)
        """
        message_type_mapping = {
            "human": HumanMessage,
            "ai": AIMessage,
            "tool": ToolMessage,
            "system": SystemMessage,
        }

        deserialized_messages = []
        for msg_dict in serialized_messages:
            msg_class = message_type_mapping.get(msg_dict.get("type"), BaseMessage)
            deserialized_messages.append(msg_class.model_validate(msg_dict))
        return deserialized_messages

    @staticmethod
    def _json_serializer(obj):
        """``json.dumps`` fallback for values JSON cannot encode natively.

        Sets and frozensets become tagged objects that ``_json_deserializer``
        restores. Objects with a ``__dict__`` are written as that dict, and
        anything else is written as ``str(obj)``. Those last two cases come
        back as a plain dict or string, not the original type.
        """
        # frozenset is checked first so the tag records which type to rebuild.
        if isinstance(obj, frozenset):
            return {_FROZENSET_TAG: list(obj)}
        if isinstance(obj, set):
            return {_SET_TAG: list(obj)}
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return str(obj)

    @staticmethod
    def _json_deserializer(dct):
        """``json.loads`` object hook that restores tagged sets.

        Args:
            dct: A decoded JSON object.

        Returns:
            A ``set``/``frozenset`` when ``dct`` is a single-key set tag, or
            ``dct`` unchanged otherwise. String values are deliberately left as
            they are: evaluating them would execute code embedded in the payload
            and change the type of ordinary strings such as "1" or "None".
        """
        if len(dct) == 1:
            ((tag, members),) = dct.items()
            factory = _COLLECTION_TAGS.get(tag)
            if factory is not None and isinstance(members, list):
                try:
                    return factory(_to_hashable(member) for member in members)
                except TypeError:
                    # Members that were stored as dicts (objects written via
                    # __dict__) cannot be hashed; keep the data as a list.
                    return members
        return dct