/ src / revolve / api.py
api.py
  1  import random
  2  import time
  3  import falcon
  4  import logging
  5  import json
  6  import sys
  7  import os
  8  from wsgiref.simple_server import make_server, WSGIServer
  9  from socketserver import ThreadingMixIn
 10  
 11  from revolve.db import get_adapter
 12  from revolve.workflow_generator import run_workflow_generator
 13  from revolve.utils import start_process, stop_process
 14  from revolve.utils import read_python_code
 15  from revolve.functions import get_file_list
 16  from wsgiref.simple_server import WSGIRequestHandler
 17  
 18  
 19  logging.basicConfig(level=logging.INFO, filename="api.log")
 20  logger = logging.getLogger(__name__)
 21  
 22  
 23  class LoggingWSGIRequestHandler(WSGIRequestHandler):
 24      #     daemon_threads = True
 25      def log_message(self, format, *args):
 26          logger.info("%s - - [%s] %s\n" % (
 27              self.client_address[0],
 28              self.log_date_time_string(),
 29              format % args
 30          ))
 31  
 32  class WorkflowResource:
 33      def on_post(self, req, resp):
 34          try:
 35              data = req.media 
 36              messages = data.get("messages", None)
 37              db_config = data.get("dbConfig", {})
 38              settings = data.get("settings", {})
 39  
 40              if not settings.get("sourceFolder"):
 41                  resp.status = falcon.HTTP_400
 42                  resp.media = {"error": "Missing settings parameters - Source folder is required."}
 43                  return
 44              
 45              if settings.get("provider") == "openai":
 46                  if not settings.get("openaiKey") or not settings.get("modelName"):
 47                      resp.status = falcon.HTTP_400
 48                      resp.media = {"error": "Missing settings parameters - OpenAI API key and Model name are required for OpenAI provider."}
 49                      return
 50              elif settings.get("provider") == "opensource":
 51                  if not settings.get("modelName") or not settings.get("baseUrl"):
 52                      resp.status = falcon.HTTP_400
 53                      resp.media = {"error": "Missing settings parameters - Model name and Base URL are required for open source provider."}
 54                      return
 55              
 56              source_folder = settings.get("sourceFolder")
 57              if not os.path.exists(source_folder):
 58                  try:
 59                      os.makedirs(source_folder)
 60                  except Exception:
 61                      resp.status = falcon.HTTP_400
 62                      resp.media = {"error": f"Source folder {source_folder} does not exist."}
 63                      return
 64              
 65              #set env vars 
 66              os.environ["SOURCE_FOLDER"] = source_folder
 67              os.environ["OPENAI_API_KEY"] = settings.get("openaiKey")
 68              os.environ["LLM_PROVIDER"] = settings.get("provider")
 69              os.environ["BASE_URL"] = settings.get("baseUrl", "")
 70              os.environ["MODEL_NAME"] = settings.get("modelName", "")
 71              
 72  
 73              logger.info("Received task: %s", messages[-1]["content"] if messages else "No messages provided")
 74          except Exception:
 75              resp.status = falcon.HTTP_400
 76              resp.media = {"error": "Invalid JSON"}
 77              return
 78  
 79          resp.status = falcon.HTTP_200
 80          resp.content_type = 'application/x-ndjson'
 81  
 82          def generate():
 83              for item in run_workflow_generator(task=messages, db_config=db_config):
 84                  line = json.dumps(item) + "\n"
 85                  yield line.encode("utf-8")
 86  
 87          resp.stream = generate()
 88  
 89  class _MockWorkflowResource:
 90      def on_post(self, req, resp):
 91          try:
 92  
 93              data = req.media
 94              task = data.get("message", None)
 95              logger.info("Received task: %s", task)
 96          except Exception:
 97              resp.status = falcon.HTTP_400
 98              resp.media = {"error": "Invalid JSON"}
 99              return
100  
101          resp.status = falcon.HTTP_200
102          resp.content_type = 'application/x-ndjson'
103  
104          def generate():
105              levels = ["system", "workflow", "notification"]
106              for i in range(10):
107                  random_level = random.choice(levels)
108                  message = {
109                      "status": "processing",
110                      "name": "node",
111                      "level":random_level,
112                      "text": f"Step {i+1} test completed puya..."
113                  }
114                  yield (json.dumps(message) + "\n").encode("utf-8")
115                  time.sleep(1)  # Delay of 1 second
116  
117              final_message = {
118                  "status": "done",
119                  "level":"workflow",
120                  "name": "workflow",
121                  "text": "Task completed successfully",
122              }
123              yield (json.dumps(final_message) + "\n").encode("utf-8")
124  
125          resp.stream = generate()
126  
class EnvResource:
    """Handle ``GET /api/env/settings`` and ``GET /api/env/db``.

    Returns the current values of the related environment variables as JSON.
    """

    def on_get(self, req, resp):
        """Respond with the env vars for the requested group, or 404 for an
        unrecognised path (previously ``env_vars`` was left unbound there)."""
        if req.path.endswith('/settings'):
            # NOTE(review): returns OPENAI_API_KEY in clear text to any caller.
            env_vars = {
                "SOURCE_FOLDER": os.environ.get("SOURCE_FOLDER", ""),
                "OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY", ""),
                # WorkflowResource stores the provider under LLM_PROVIDER;
                # otherwise fall back to PROVIDER, then "openai".
                "PROVIDER": os.environ.get("LLM_PROVIDER") or os.environ.get("PROVIDER", "openai"),
                "BASE_URL": os.environ.get("BASE_URL", ""),
                "MODEL_NAME": os.environ.get("MODEL_NAME", ""),
            }
        elif req.path.endswith('/db'):
            # NOTE(review): returns DB_PASSWORD in clear text to any caller.
            env_vars = {
                "DB_NAME": os.environ.get("DB_NAME", ""),
                "DB_USER": os.environ.get("DB_USER", ""),
                "DB_PASSWORD": os.environ.get("DB_PASSWORD", ""),
                "DB_HOST": os.environ.get("DB_HOST", ""),
                "DB_PORT": os.environ.get("DB_PORT", ""),
                "DB_TYPE": os.environ.get("DB_TYPE", ""),
            }
        else:
            resp.status = falcon.HTTP_404
            resp.media = {"error": "Unknown env endpoint"}
            return
        resp.status = falcon.HTTP_200
        resp.media = env_vars
148  
149  
class FileResource:
    """Serve generated files.

    ``GET /api/get-file-list`` lists the files, and ``GET /api/get-file``
    with ``?name=...`` returns one file's content.
    """

    # Only these extensions appear in the file list.
    _LISTED_EXTENSIONS = ('.py', '.json', '.md')
    # Extensions whose content is wrapped in a fenced code block, and the
    # language tag used for the fence.
    _FENCE_LANGUAGES = {".py": "python", ".json": "json"}

    def on_get(self, req, resp):
        """Dispatch on the request path suffix; unknown suffixes get a 404."""
        path = req.path

        if path.endswith('/get-file-list'):
            self.get_file_list(req, resp)
        elif path.endswith('/get-file'):
            self.get_file(req, resp)
        else:
            resp.status = falcon.HTTP_404
            resp.media = {"error": "Unknown file endpoint"}

    def get_file_list(self, req, resp):
        """Respond with ``{"files": [...]}``: listed names, sorted; 500 with
        ``{"error": ...}`` if listing fails."""
        try:
            # get_file_list here is the module-level helper, not this method.
            file_list = sorted(
                f for f in get_file_list() if f.endswith(self._LISTED_EXTENSIONS)
            )
            resp.status = falcon.HTTP_200
            resp.media = {"files": file_list}
        except Exception as e:
            resp.status = falcon.HTTP_500
            resp.media = {"error": str(e)}

    def get_file(self, req, resp):
        """Respond with ``{"content": ...}`` for the file named by the ``name``
        query parameter.

        ``.py`` and ``.json`` content is wrapped in a fenced code block. A
        missing ``name`` gives 400; a read failure gives 500 with
        ``{"error": ...}``, matching ``get_file_list``.
        """
        file_name = req.get_param("name")
        if not file_name:
            resp.status = falcon.HTTP_400
            resp.media = {"error": "Missing required query parameter 'name'."}
            return
        # NOTE(review): file_name comes straight from the client; confirm that
        # read_python_code() confines reads to the source folder (path traversal).
        try:
            content = read_python_code(file_name)
        except Exception as e:
            resp.status = falcon.HTTP_500
            resp.media = {"error": str(e)}
            return
        for extension, language in self._FENCE_LANGUAGES.items():
            if file_name.endswith(extension):
                content = f"```{language}\n{content}\n```"
                break
        resp.status = falcon.HTTP_200
        resp.media = {"content": content}
193  
class TestDBResource:
    """Handle ``POST /api/test_db``.

    Stores the submitted DB connection settings in the environment, then
    checks connectivity and permissions and returns the table names found.
    """

    def on_post(self, req, resp):
        """Test a DB connection described by the JSON body.

        Expected keys: DB_NAME, DB_USER, DB_PASSWORD, DB_HOST, DB_PORT and
        DB_TYPE.

        Responses:
        - 400 if any key is missing or empty.
        - 500 if the connection fails or an unexpected error occurs.
        - 403 (with the adapter's permission payload) if the permission check
          reports an error.
        - 200 with ``{"message", "tables"}`` otherwise.
        """
        try:
            data = req.media
            db_name = data.get("DB_NAME", None)
            db_user = data.get("DB_USER", None)
            db_password = data.get("DB_PASSWORD", None)
            db_host = data.get("DB_HOST", None)
            db_port = data.get("DB_PORT", None)
            db_type = data.get("DB_TYPE")

            # Validate before touching os.environ: it rejects None, which
            # would otherwise surface as a generic 500 instead of this 400.
            if not all([db_name, db_user, db_password, db_host, db_port, db_type]):
                resp.status = falcon.HTTP_400
                resp.media = {"error": "Missing database connection parameters."}
                return

            # os.environ only accepts str; JSON clients may send e.g. the
            # port as a number.
            os.environ["DB_NAME"] = str(db_name)
            os.environ["DB_USER"] = str(db_user)
            os.environ["DB_PASSWORD"] = str(db_password)
            os.environ["DB_HOST"] = str(db_host)
            os.environ["DB_PORT"] = str(db_port)
            os.environ["DB_TYPE"] = str(db_type)

            adapter = get_adapter(db_type)

            connected = adapter.check_db(
                db_name=db_name,
                db_user=db_user,
                db_password=db_password,
                db_host=db_host,
                db_port=db_port
            )
            # Check the connection first: permission and schema queries are
            # meaningless without it.
            if not connected:
                resp.status = falcon.HTTP_500
                resp.media = {"error": "Connection to DB failed. Please check your credentials."}
                return

            permissions = adapter.check_permissions()
            if permissions["status"] == "error":
                resp.status = falcon.HTTP_403
                resp.media = permissions
                return

            table_names = list(adapter.get_schemas_from_db().keys())
            # NOTE(review): table order is randomised here; confirm intended.
            random.shuffle(table_names)

            resp.status = falcon.HTTP_200
            resp.media = {"message": "Connection to DB was successful!", "tables": table_names}
        except Exception:
            # Keep the traceback in api.log instead of printing to stdout.
            logger.exception("Database connection test failed")
            resp.status = falcon.HTTP_500
            resp.media = {"error": "Database connection failed."}
247  
class ServerControlResource:
    """Handle ``POST /api/start`` and ``POST /api/stop``.

    Calls ``start_process`` or ``stop_process`` according to the path suffix
    and returns that call's result as the JSON body.
    """

    def on_post(self, req, resp):
        """Run the action matching the path suffix; 404 when none matches."""
        request_path = req.path
        actions = (('/start', start_process), ('/stop', stop_process))
        for suffix, action in actions:
            if request_path.endswith(suffix):
                outcome = action()
                break
        else:
            resp.status = falcon.HTTP_404
            resp.media = {"error": "Unknown command"}
            return

        resp.status = falcon.HTTP_200
        resp.media = outcome
262  
app = falcon.App()

# ServerControlResource, FileResource and EnvResource dispatch on the request
# path suffix, so each of them is mounted on more than one route.
app.add_route("/api/chat", WorkflowResource())
app.add_route("/api/test_db", TestDBResource())
app.add_route("/api/start", ServerControlResource())
app.add_route("/api/stop", ServerControlResource())
app.add_route("/api/get-file-list", FileResource())
app.add_route("/api/get-file", FileResource())
app.add_route("/api/env/settings", EnvResource())
app.add_route("/api/env/db", EnvResource())

# Built UI bundle located in ui/dist next to this module.
static_resource = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui", "dist")
# Serve the UI for every path not claimed by an /api route; paths without a
# matching file fall back to index.html. Falcon static-route prefixes are
# literal strings (no URI templates), so a single "/" prefix covers all paths.
app.add_static_route("/", static_resource, fallback_filename='index.html')
278  
279  
class ThreadingWSGIServer(ThreadingMixIn, WSGIServer):
    """WSGI server that handles each request in its own thread, so
    concurrent requests (e.g. long NDJSON streams) do not block each other.

    Request threads are daemonic: the server does not wait for them when the
    process shuts down.
    """

    daemon_threads = True
283  
def check_env_vars():
    """Exit with status 1 unless SOURCE_FOLDER and OPENAI_API_KEY are set.

    Missing names are printed to stdout after a red cross marker; when
    nothing is missing the function returns ``None``.
    """
    missing = [
        name for name in ("SOURCE_FOLDER", "OPENAI_API_KEY")
        if name not in os.environ
    ]
    if not missing:
        return
    # ANSI red (91) around the marker, then reset (0).
    print("\033[91m" + "❌" + "\033[0m", end=" ")
    print(f"Missing environment variables: {', '.join(missing)}")
    sys.exit(1)
297  
def main():
    """Serve ``app`` with a threading WSGI server on ``API_PORT``.

    ``API_PORT`` defaults to 48001. The function blocks in
    ``serve_forever()``.
    """
    port = int(os.environ.get("API_PORT", "48001"))
    url = f"http://localhost:{port}/"
    server = make_server(
        "",
        port,
        app,
        server_class=ThreadingWSGIServer,
        handler_class=LoggingWSGIRequestHandler,
    )
    with server as httpd:
        logger.info("Serving on %s", url)
        # ANSI green (92) around the marker, then reset (0).
        print("\033[92m" + "✅" + "\033[0m", end=" ")
        print(f"Serving on {url}")
        httpd.serve_forever()


# Start the server only when this module is executed as a script.
if __name__ == "__main__":
    main()