# main.py
from langgraph.graph import StateGraph, START, END

from revolve.data_types import State
from revolve.db import get_adapter

from langgraph.constants import Send
# NOTE(review): wildcard import kept in its original position because the
# imports after it may intentionally shadow names it exports; replace it with
# explicit names once the ones actually used are known.
from revolve.utils_git import *  # noqa: F401,F403

import os
from typing import Any, Callable, Mapping, Optional

from revolve.nodes import (
    router_node,
    generate_prompt_for_code_generation,
    process_table,
    generate_api,
    test_node,
    report_node,
    check_user_request,
    tool_handler,
    BasicToolNode,
    should_continue_tool_call,
)

from revolve.external import get_db_type
from revolve.tools import get_tools


# Connection settings copied from ``db_config`` into ``os.environ``, in the
# same order the original per-key assignments ran.
_DB_ENV_KEYS = ("DB_NAME", "DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT", "DB_TYPE")

# Trace ``node_name`` values whose progress messages are tagged
# ``level="workflow"``; every other node is tagged ``"system"``.
# NOTE(review): "run_tests" is not a graph node name here (the node is
# "test_node"); presumably it is the name nodes write into their trace.
_WORKFLOW_LEVEL_NODES = ("report_node", "run_tests")


def send_message(message):
    """Default ``send`` callback: print *message* to stdout."""
    print(message)


def _default_task():
    """Return a fresh copy of the demo request used when no task is supplied."""
    return [
        {
            "role": "user",
            "content": "Create CRUD operations for passes, satellites, ground stations and orbits.",
        }
    ]


def _apply_db_config(db_config):
    """Export the connection settings in *db_config* as environment variables.

    ``os.environ`` only accepts strings, so non-string values (for example an
    integer port) are converted with ``str()``.

    Raises:
        KeyError: if one of ``_DB_ENV_KEYS`` is missing from *db_config*.
        ValueError: if one of those values is ``None``.
    """
    for key in _DB_ENV_KEYS:
        value = db_config[key]
        if value is None:
            # str(None) would silently export the literal string "None".
            raise ValueError(f"db_config[{key!r}] must not be None")
        os.environ[key] = str(value)


def _fan_out_tables(state):
    """Dispatch one ``process_table`` run per entry of ``state["DBSchema"]["tables"]``."""
    # NOTE(review): an empty table list yields no Send, leaving this branch
    # with no successor node - confirm that is the intended behaviour.
    return [Send("process_table", table) for table in state["DBSchema"]["tables"]]


def _build_workflow():
    """Build and compile the agent graph.

    START goes to ``check_user_request``, which branches on
    ``state["classification"]``:

    - to the CRUD pipeline via ``router_node``;
    - to the ``tool_handler`` <-> ``tool_executor`` loop;
    - or to END.

    ``router_node`` branches on ``state["next_node"]``. Every pipeline stage
    returns to ``router_node``.
    """
    graph = StateGraph(State)

    graph.add_node("router_node", router_node)
    graph.add_node("check_user_request", check_user_request)
    graph.add_node("generate_prompt_for_code_generation", generate_prompt_for_code_generation)
    graph.add_node("process_table", process_table)
    graph.add_node("generate_api", generate_api)
    graph.add_node("test_node", test_node)
    graph.add_node("report_node", report_node)
    graph.add_node("tool_handler", tool_handler)
    graph.add_node("tool_executor", BasicToolNode(tools=get_tools()))

    graph.add_edge(START, "check_user_request")
    graph.add_conditional_edges(
        "check_user_request",
        lambda state: state["classification"],
        {
            "create_crud_task": "router_node",
            "__end__": END,
            "respond_back": END,
            "other_tasks": "tool_handler",
        },
    )
    graph.add_conditional_edges(
        "router_node",
        lambda state: state["next_node"],
        {
            "generate_prompt_for_code_generation": "generate_prompt_for_code_generation",
            "test_node": "test_node",
            "report_node": "report_node",
            "__end__": END,
        },
    )

    # Tool-calling loop: tool_handler decides whether another tool call is needed.
    graph.add_conditional_edges(
        "tool_handler",
        should_continue_tool_call,
        {"tool_executor": "tool_executor", "__end__": END},
    )
    graph.add_edge("tool_executor", "tool_handler")

    # Fan out one process_table run per table, then continue to generate_api.
    graph.add_conditional_edges(
        "generate_prompt_for_code_generation", _fan_out_tables, ["process_table"]
    )
    graph.add_edge("process_table", "generate_api")
    graph.add_edge("generate_api", "router_node")
    graph.add_edge("test_node", "router_node")
    graph.add_edge("report_node", "router_node")

    return graph.compile()


def _forward_progress(event, send):
    """Relay the latest trace description from each node update in *event* to *send*.

    *event* is treated as a mapping of node name -> that node's state update.
    An update may be empty. An update produces a message only when it has a
    non-empty ``trace`` whose last entry contains a ``description``.
    """
    for update in event.values():
        if not update or "trace" not in update:
            continue
        trace = update["trace"]
        if not trace:
            continue  # trace[-1] would raise IndexError on an empty list
        latest = trace[-1]
        if "description" not in latest:
            continue
        name = latest.get("node_name", "")
        send({
            "status": "processing",
            "text": latest["description"],
            "name": name,
            "level": "workflow" if name in _WORKFLOW_LEVEL_NODES else "system",
        })


def run_workflow(
    task: Any = None,
    db_config: Optional[Mapping[str, Any]] = None,
    send: Optional[Callable[[dict], Any]] = None,
) -> None:
    """Run the agent workflow for *task*, streaming progress dicts to *send*.

    Args:
        task: Initial ``messages`` input for the graph. Any falsy value
            (``None``, ``[]``, ``""``) is replaced by a built-in demo request.
        db_config: Optional connection settings.
            - When given, the ``_DB_ENV_KEYS`` entries are exported to
              ``os.environ``, and a truthy ``USE_CLONE_DB`` enables
              ``test_mode``.
            - When omitted, ``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``,
              ``DB_PORT`` and ``DB_NAME`` must already be set in the
              environment.
        send: Callback receiving status dicts; defaults to ``send_message``.

    Returns:
        None. If the database connectivity check fails, an error status is
        sent and the graph is not run.

    Raises:
        KeyError: if a required ``DB_*`` setting is missing.
        ValueError: if a ``db_config`` connection value is ``None``.
    """
    if send is None:
        send = send_message
    test_mode = bool(db_config and db_config.get("USE_CLONE_DB", False))
    if db_config:
        _apply_db_config(db_config)

    adapter = get_adapter(get_db_type())
    db_test_result = adapter.check_db(
        db_user=os.environ["DB_USER"],
        db_password=os.environ["DB_PASSWORD"],
        db_host=os.environ["DB_HOST"],
        db_port=os.environ["DB_PORT"],
        db_name=os.environ["DB_NAME"],
    )
    if not db_test_result:
        send({
            "status": "error",
            "text": "Database connection failed. Please check your database configuration.",
            "name": "Database Connection Error",
        })
        return

    workflow = _build_workflow()

    if not task:
        task = _default_task()

    for event in workflow.stream({"messages": task, "send": send, "test_mode": test_mode}):
        _forward_progress(event, send)
    # NOTE: no final "done" status is sent after streaming finishes.


if __name__ == "__main__":
    run_workflow()