# src/revolve/main.py
  1  from langgraph.graph import StateGraph, START, END
  2  
  3  from revolve.data_types import State
  4  from revolve.db import get_adapter
  5  
  6  from langgraph.constants import Send
  7  from revolve.utils_git import *
  8  
  9  import os
 10  
 11  
 12  from revolve.nodes import (
 13      router_node,
 14      generate_prompt_for_code_generation,
 15      process_table,
 16      generate_api,
 17      test_node,
 18      report_node,
 19      check_user_request,
 20      tool_handler,
 21      BasicToolNode,
 22      should_continue_tool_call
 23  )
 24  
 25  from revolve.external import get_db_type
 26  from revolve.tools import get_tools
 27  
 28  
 29  def send_message(message):
 30      print(f"{message}")
 31  
 32  def run_workflow(task=None, db_config=None, send=None):
 33      if send is None:
 34          send = send_message
 35      test_mode = True if db_config and db_config.get("USE_CLONE_DB", False) else False
 36      if db_config:
 37          os.environ["DB_NAME"] = db_config["DB_NAME"]
 38          os.environ["DB_USER"] = db_config["DB_USER"]
 39          os.environ["DB_PASSWORD"] = db_config["DB_PASSWORD"]
 40          os.environ["DB_HOST"] = db_config["DB_HOST"]
 41          os.environ["DB_PORT"] = db_config["DB_PORT"]
 42          os.environ["DB_TYPE"] = db_config["DB_TYPE"]
 43  
 44      adapter = get_adapter(get_db_type())
 45  
 46      db_test_result = adapter.check_db(db_user=os.environ["DB_USER"],
 47                                db_password=os.environ["DB_PASSWORD"],
 48                                db_host=os.environ["DB_HOST"],
 49                                db_port=os.environ["DB_PORT"],
 50                                db_name=os.environ["DB_NAME"])
 51      if not db_test_result:
 52          send({
 53              "status":"error",
 54              "text": "Database connection failed. Please check your database configuration.",
 55              "name": "Database Connection Error"
 56          })
 57          return
 58           
 59           
 60      
 61      graph = StateGraph(State)
 62  
 63      graph.add_node("router_node", router_node)
 64      graph.add_node("check_user_request", check_user_request)
 65      graph.add_node("generate_prompt_for_code_generation", generate_prompt_for_code_generation)
 66      graph.add_node("process_table", process_table)
 67      graph.add_node("generate_api", generate_api)
 68      graph.add_node("test_node", test_node)
 69      graph.add_node("report_node", report_node)
 70      graph.add_node("tool_handler", tool_handler)
 71  
 72      tool_executor = BasicToolNode(tools=get_tools())
 73      graph.add_node("tool_executor", tool_executor)
 74  
 75  
 76  
 77      graph.add_edge(START, "check_user_request")
 78      graph.add_conditional_edges("check_user_request", lambda state: state["classification"], {"create_crud_task" : "router_node",  "__end__":END, "respond_back": END, "other_tasks":"tool_handler"})
 79      graph.add_conditional_edges(
 80          "router_node", lambda state: state["next_node"], {"generate_prompt_for_code_generation":"generate_prompt_for_code_generation", "test_node": "test_node", "report_node": "report_node", "__end__":END}
 81      )
 82  
 83      graph.add_conditional_edges(
 84          "tool_handler", should_continue_tool_call, {"tool_executor": "tool_executor", "__end__": END}
 85      )
 86      graph.add_edge("tool_executor", "tool_handler")
 87  
 88  
 89      graph.add_conditional_edges(
 90          "generate_prompt_for_code_generation", lambda state: [Send("process_table", s) for s in state["DBSchema"]["tables"]], ["process_table"]
 91      )
 92  
 93      graph.add_edge("process_table", "generate_api")
 94      graph.add_edge("generate_api", "router_node")
 95      graph.add_edge("test_node", "router_node")
 96      graph.add_edge("report_node", "router_node")
 97  
 98      workflow = graph.compile()
 99  
100      if not task:
101          #task = "Created crud operations for passes, satellites, ground stations and orbits"
102          task = [
103              {
104                  "role": "user",
105                  "content": "Create CRUD operations for passes, satellites, ground stations and orbits."
106              }
107          ]
108  
109      for event in workflow.stream({"messages": task, "send":send,"test_mode": test_mode}):
110          name = ""
111          text = ""
112          key = list(event.keys())[0]
113          if event[key]:
114              if "trace" in event[key]:
115                  if "description" in event[key]["trace"][-1]:
116                      name = event[key]["trace"][-1]["node_name"]
117                      text = event[key]["trace"][-1]["description"]
118                      level = "workflow" if name in ["report_node","run_tests"] else "system"
119                      send({
120                          "status":"processing",
121                          "text":text,
122                          "name":name,
123                          "level":level
124                      })
125      # send({
126      #     "status":"done",
127      #     "text":"Task completed.",
128      #     "name":"Workflow",
129      #     "level":"workflow"
130      # })
131      
132  if __name__ == "__main__":
133      run_workflow()