# workflow_factory.py
import chainlit as cl
import importlib
from typing import Type, Dict
from .base import BaseWorkflow, BaseState


class WorkflowNotFoundError(KeyError, ValueError):
    """Raised when no workflow is registered under the requested name.

    Inherits from both ``KeyError`` and ``ValueError`` so that callers which
    caught ``ValueError`` from ``create`` or ``KeyError`` from the lookup
    helpers keep working unchanged.
    """

    def __str__(self) -> str:
        # KeyError.__str__ would repr() the message (wrapping it in quotes);
        # render the plain text instead, as the original ValueError did.
        return str(self.args[0]) if len(self.args) == 1 else super().__str__()


class WorkflowFactory:
    """Class-level registry mapping chat profile names to workflow classes.

    All state lives in class attributes, so every user of the factory shares
    the same registry.
    """

    _workflows: Dict[str, Type[BaseWorkflow]] = {}  # chat profile name -> workflow class
    _module_map: Dict[str, str] = {}  # Maps chat profile names to module names

    @classmethod
    def _get_workflow_class(cls, name: str) -> Type[BaseWorkflow]:
        """Return the workflow class registered under ``name``.

        Raises:
            WorkflowNotFoundError: if ``name`` has not been registered.
        """
        try:
            return cls._workflows[name]
        except KeyError:
            raise WorkflowNotFoundError(f"Workflow {name} not found") from None

    @classmethod
    def register(cls, name: str, workflow_class: Type[BaseWorkflow]) -> None:
        """Register ``workflow_class`` under the chat profile ``name``.

        Re-registering an existing ``name`` replaces the previous entry.

        Args:
            name: Chat profile name, e.g. ``'Simple Chat'``.
            workflow_class: Workflow class to instantiate for that profile.
        """
        # Last dotted component of the defining module, e.g. 'simple_chat'.
        module_name = workflow_class.__module__.rsplit(".", 1)[-1]
        cls._workflows[name] = workflow_class
        cls._module_map[name] = module_name

    @classmethod
    def unregister(cls, name: str) -> None:
        """Dynamically remove the workflow registered under ``name``, if any.

        Unknown names are ignored. Both the class entry and its module-name
        entry are dropped so the two registries stay in sync.
        """
        cls._workflows.pop(name, None)
        cls._module_map.pop(name, None)

    @classmethod
    def create(cls, name: str, **kwargs) -> BaseWorkflow:
        """Instantiate the workflow registered under ``name``.

        Args:
            name: Registered chat profile name.
            **kwargs: Forwarded unchanged to the workflow class constructor.

        Raises:
            WorkflowNotFoundError: (a ``ValueError``) if ``name`` is unknown.
        """
        return cls._get_workflow_class(name)(**kwargs)

    @classmethod
    def list_workflows(cls) -> list[str]:
        """Return registered chat profile names in first-registration order."""
        return list(cls._workflows)

    @classmethod
    def get_graph_state(cls, chat_profile: str) -> Type[BaseState]:
        """Get the ``GraphState`` defined alongside a workflow, by chat profile name.

        Imports (or fetches the already-imported) module that defines the
        workflow registered under ``chat_profile`` and returns its
        module-level ``GraphState`` attribute.

        Raises:
            WorkflowNotFoundError: (a ``KeyError``) if ``chat_profile`` is unknown.
            AttributeError: if the workflow's module has no ``GraphState``.
        """
        workflow_class = cls._get_workflow_class(chat_profile)
        module = importlib.import_module(workflow_class.__module__)
        try:
            return getattr(module, "GraphState")
        except AttributeError:
            raise AttributeError(
                f"Module {module.__name__!r} used by workflow {chat_profile!r} "
                "does not define GraphState"
            ) from None

    @classmethod
    def get_chat_profile(cls, name: str) -> cl.ChatProfile:
        """Return the result of the registered workflow class's ``chat_profile()``.

        Raises:
            WorkflowNotFoundError: (a ``KeyError``) if ``name`` is unknown.
        """
        return cls._get_workflow_class(name).chat_profile()