/ chat_workflow / workflows / workflow_factory.py
workflow_factory.py
 1  import chainlit as cl
 2  import importlib
 3  from typing import Type, Dict
 4  from .base import BaseWorkflow, BaseState
 5  
 6  
 7  class WorkflowFactory:
 8      _workflows: Dict[str, Type[BaseWorkflow]] = {}
 9      _module_map: Dict[str, str] = {}  # Maps chat profile names to module names
10  
11      @classmethod
12      def register(cls, name: str, workflow_class: Type[BaseWorkflow]):
13          module_name = workflow_class.__module__.split(
14              '.')[-1]  # e.g. 'simple_chat'
15          cls._workflows[name] = workflow_class  # e.g. 'Simple Chat'
16          cls._module_map[name] = module_name
17  
18      @classmethod
19      def unregister(cls, name: str):
20          """Dynamically remove workflows"""
21          cls._workflows.pop(name, None)
22  
23      @classmethod
24      def create(cls, name: str, **kwargs) -> BaseWorkflow:
25          if name not in cls._workflows:
26              raise ValueError(f"Workflow {name} not found")
27          return cls._workflows[name](**kwargs)
28  
29      @classmethod
30      def list_workflows(cls) -> list[str]:
31          return list(cls._workflows.keys())
32  
33      @classmethod
34      def get_graph_state(cls, chat_profile: str) -> Type[BaseState]:
35          """Get GraphState using chat profile name"""
36          workflow_class = cls._workflows[chat_profile]
37          module = importlib.import_module(workflow_class.__module__)
38          return getattr(module, "GraphState")
39  
40      @classmethod
41      def get_chat_profile(cls, name: str) -> cl.ChatProfile:
42          """Get chat profile from workflow class"""
43          workflow_class = cls._workflows[name]
44          return workflow_class.chat_profile()