# postgres_adapter.py
"""PostgreSQL implementation of the Revolve database adapter.

Reads connection settings from environment variables (DB_NAME, DB_USER,
DB_PASSWORD, DB_HOST, DB_PORT) and provides schema introspection,
dependency ordering, database cloning, and permission checking.
"""
import datetime
import json
import os
import sys
from collections import defaultdict, deque
from pathlib import Path
from typing import Any, Dict, List, Tuple

import psycopg2
import sqlparse
from psycopg2 import sql, errors

# Add parent directory to sys.path so the `revolve` package resolves when this
# module is executed directly (not installed as a package).
sys.path.append(str(Path(__file__).resolve().parent.parent))
from revolve.utils import log
from revolve.db.adapter import DatabaseAdapter, db_tool


class PostgresAdapter(DatabaseAdapter):
    """DatabaseAdapter backed by a PostgreSQL server configured via env vars."""

    def __init__(self):
        # Values may be None if the env vars are unset; psycopg2 then falls
        # back to libpq defaults. NOTE(review): several methods re-read the
        # environment instead of using this dict — presumably so that env
        # changes (e.g. clone_db) take effect; confirm before consolidating.
        self.config = {
            "dbname": os.getenv("DB_NAME"),
            "user": os.getenv("DB_USER"),
            "password": os.getenv("DB_PASSWORD"),
            "host": os.getenv("DB_HOST"),
            "port": os.getenv("DB_PORT"),
        }

    def get_raw_schemas(self):
        """Return the raw query result (JSON string) mapping each table name to
        a list of per-column metadata objects (type, nullability, enum values,
        foreign key target, default value)."""
        schemas_raw = self.run_query_on_db("""
            SELECT jsonb_object_agg(
                table_name,
                columns
            ) AS schema_dict
            FROM (
                SELECT
                    c.table_name,
                    jsonb_agg(
                        jsonb_strip_nulls(
                            jsonb_build_object(
                                'column_name', c.column_name,
                                'data_type', c.data_type,
                                'data_type_s',
                                CASE
                                    WHEN def.column_default LIKE 'nextval(%' AND pt.typname = 'int4' THEN 'serial'
                                    WHEN def.column_default LIKE 'nextval(%' AND pt.typname = 'int8' THEN 'bigserial'
                                    WHEN c.data_type IN ('character varying', 'varchar') AND c.character_maximum_length IS NOT NULL
                                        THEN 'varchar(' || c.character_maximum_length || ')'
                                    WHEN c.data_type = 'numeric' AND c.numeric_precision IS NOT NULL
                                        THEN 'numeric(' || c.numeric_precision ||
                                             COALESCE(', ' || c.numeric_scale, '') || ')'
                                    ELSE c.data_type
                                END,
                                'is_nullable', c.is_nullable,
                                'character_max_length', c.character_maximum_length,
                                'numeric_precision', c.numeric_precision,
                                'numeric_scale', c.numeric_scale,
                                'enum_values', ev.enum_values,
                                'foreign_key', jsonb_build_object(
                                    'foreign_table', ccu.table_name,
                                    'foreign_column', ccu.column_name
                                ),
                                'default_value', def.column_default
                            )
                        )
                    ) AS columns
                FROM information_schema.columns c
                LEFT JOIN information_schema.key_column_usage kcu
                    ON c.table_name = kcu.table_name
                    AND c.column_name = kcu.column_name
                    AND c.table_schema = kcu.table_schema
                LEFT JOIN information_schema.table_constraints tc
                    ON kcu.constraint_name = tc.constraint_name
                    AND tc.constraint_type = 'FOREIGN KEY'
                LEFT JOIN information_schema.constraint_column_usage ccu
                    ON tc.constraint_name = ccu.constraint_name
                LEFT JOIN LATERAL (
                    SELECT jsonb_agg(e.enumlabel) AS enum_values
                    FROM pg_type t
                    JOIN pg_enum e ON t.oid = e.enumtypid
                    JOIN pg_namespace ns ON ns.oid = t.typnamespace
                    WHERE t.typname = c.udt_name
                ) ev ON true
                LEFT JOIN LATERAL (
                    SELECT
                        pg_get_expr(ad.adbin, ad.adrelid) AS column_default
                    FROM pg_attribute a
                    JOIN pg_class cls ON cls.oid = a.attrelid
                    JOIN pg_attrdef ad ON ad.adrelid = a.attrelid AND ad.adnum = a.attnum
                    WHERE cls.relname = c.table_name AND a.attname = c.column_name
                    LIMIT 1
                ) def ON true
                LEFT JOIN pg_type pt ON pt.typname = c.udt_name
                WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema')
                GROUP BY c.table_name
            ) AS sub;
        """)
        return schemas_raw

    def get_table_dependencies(self):
        """Return a JSON string mapping each child table to its FK columns,
        each annotated with the target table and a relationship type
        ('one-to-one' / 'many-to-one' / 'uncertain')."""
        query_result = self.run_query_on_db("""
            WITH fk_info AS (
                SELECT
                    kcu.table_name AS from_table,
                    kcu.column_name AS from_column,
                    ccu.table_name AS to_table,
                    ccu.column_name AS to_column,
                    -- Check if from_column is unique or part of primary key
                    EXISTS (
                        SELECT 1 FROM information_schema.table_constraints tc2
                        JOIN information_schema.key_column_usage kcu2
                            ON tc2.constraint_name = kcu2.constraint_name
                        WHERE tc2.table_name = kcu.table_name
                        AND kcu2.column_name = kcu.column_name
                        AND tc2.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
                    ) AS is_from_unique,
                    -- Check if to_column is unique or primary
                    EXISTS (
                        SELECT 1 FROM information_schema.table_constraints tc3
                        JOIN information_schema.key_column_usage kcu3
                            ON tc3.constraint_name = kcu3.constraint_name
                        WHERE tc3.table_name = ccu.table_name
                        AND kcu3.column_name = ccu.column_name
                        AND tc3.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
                    ) AS is_to_unique
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                    ON tc.constraint_name = kcu.constraint_name
                JOIN information_schema.constraint_column_usage ccu
                    ON tc.constraint_name = ccu.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY'
                AND tc.table_schema NOT IN ('pg_catalog', 'information_schema')
            )
            SELECT jsonb_object_agg(
                from_table,
                column_mappings
            ) AS fk_relationships
            FROM (
                SELECT
                    from_table,
                    jsonb_object_agg(
                        from_column,
                        jsonb_build_object(
                            'links_to_table', to_table,
                            'reltype',
                            CASE
                                WHEN is_from_unique AND is_to_unique THEN 'one-to-one'
                                WHEN NOT is_from_unique AND is_to_unique THEN 'many-to-one'
                                ELSE 'uncertain' -- fallback
                            END
                        )
                    ) AS column_mappings
                FROM fk_info
                GROUP BY from_table
            ) AS rels;
        """)
        return query_result

    def get_schemas_from_db(self):
        """Fetch table schemas and merge foreign-key relationship info
        (reltype / links_to_table) into each column's metadata dict.

        Returns:
            dict: table name -> list of column metadata dicts.
        """
        schemas_raw = self.get_raw_schemas()
        dependencies_raw = self.get_table_dependencies()

        # Each query result is a JSON array of rows; the aggregate lands in
        # row 0, column 0. Parse the dependency payload once (the original
        # parsed it twice).
        schemas = json.loads(schemas_raw)[0][0]
        deps_value = json.loads(dependencies_raw)[0][0]
        dependencies = deps_value if deps_value is not None else {}

        for table, columns in schemas.items():
            dep_columns = dependencies.get(table, {})
            for column in columns:
                col_name = column.get("column_name")
                if col_name in dep_columns:
                    # Merge the reltype and links_to_table
                    column.update(dep_columns[col_name])

        return schemas
        # NOTE: a large unreachable duplicate of the get_raw_schemas query
        # used to sit here after the return; it was dead code and was removed.

    def order_tables_by_dependencies(self, dependencies: Dict[str, Any]) -> Tuple[Dict[str, List[str]], List[str]]:
        """Build a table -> referenced-tables map and a topological ordering.

        Args:
            dependencies: output shape of get_table_dependencies (parsed):
                child table -> {column -> {'links_to_table': ..., ...}}.

        Returns:
            (dependency_map, topologically_sorted_table_names)
            -- original annotation said List[str] but two values are returned.
        """
        # All referenced parent tables across every FK link.
        referenced_parents = {
            info["links_to_table"]
            for links in dependencies.values()
            for info in links.values()
        }

        # Dependency map: table -> list of tables it links to.
        final_dependency_map = {
            table: [info["links_to_table"] for info in links.values()]
            for table, links in dependencies.items()
        }

        # Drop tables that are only referenced and have no outgoing links.
        for table in list(final_dependency_map):
            if table in referenced_parents and not final_dependency_map[table]:
                del final_dependency_map[table]

        tp_sorted = self.topological_sort(dependencies)

        return final_dependency_map, tp_sorted

    def topological_sort(self, dependencies: Dict[str, Any]) -> List[str]:
        """Return table names sorted parents-first via Kahn's algorithm.

        Raises:
            ValueError: if the FK graph contains a cycle.
        """
        graph = defaultdict(list)  # parent -> list of children
        in_degree = defaultdict(int)
        all_tables = set(dependencies)  # explicitly defined tables

        # Parse dependencies to construct graph and in-degrees.
        for child, columns in dependencies.items():
            if isinstance(columns, dict):
                for col_info in columns.values():
                    if isinstance(col_info, dict) and "links_to_table" in col_info:
                        parent = col_info["links_to_table"]
                        graph[parent].append(child)
                        in_degree[child] += 1
                        all_tables.add(parent)  # include implicit parent tables

        # Start with all tables that have no incoming edges.
        queue = deque([table for table in all_tables if in_degree[table] == 0])
        sorted_tables = []

        while queue:
            table = queue.popleft()
            sorted_tables.append(table)

            for dependent in graph[table]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # If any node was never emitted, a cycle kept its in-degree positive.
        if len(sorted_tables) != len(all_tables):
            raise ValueError("Cycle detected in table dependencies.")

        return sorted_tables

    def get_tables(self) -> List[Dict[str, Any]]:
        """Not implemented for PostgreSQL; callers should use get_schemas_from_db."""
        raise RuntimeError("get_tables is not supported")

    @db_tool
    def run_query_on_db(self, query: str) -> str:
        """
        This function runs the given query on the postgres database.
        Args:
            query (str): The query to be run.
        Returns:
            str: JSON-encoded list of result rows, or an
            "Error running query: ..." message on failure.
        """
        conn = None
        try:
            conn = psycopg2.connect(
                dbname=os.getenv("DB_NAME"),
                user=os.getenv("DB_USER"),
                password=os.getenv("DB_PASSWORD"),
                host=os.getenv("DB_HOST"),
                port=os.getenv("DB_PORT"),
            )
            with conn.cursor() as cur:
                cur.execute(query)
                # fetchall raises for non-row-returning statements; that error
                # is reported through the except branch below, as before.
                result = cur.fetchall()
            conn.commit()
        except Exception as e:
            log(f"Error running query: {e}")
            return f"Error running query: {e}"
        finally:
            # Always release the connection (the original leaked it on error).
            if conn is not None:
                conn.close()

        return json.dumps(result, default=default_serializer)

    def check_db(
        self, db_name: str, db_user: str, db_password: str, db_host: str, db_port: str
    ) -> bool:
        """
        This function tests the database connection.
        Args:
            db_name (str): The name of the database.
            db_user (str): The user of the database.
            db_password (str): The password of the database.
            db_host (str): The host of the database.
            db_port (str): The port of the database.
        Returns:
            bool: True if a connection could be opened (annotation was `str`
            in the original but the method always returned a bool).
        """
        log("Testing database connection...")
        try:
            conn = psycopg2.connect(
                dbname=db_name,
                user=db_user,
                password=db_password,
                host=db_host,
                port=db_port,
            )
            conn.close()
        except Exception as e:
            log(f"Database connection failed: {e}", level="ERROR")
            return False

        return True

    def recreate_database_psycopg2(self, dbname, user, password, host, port):
        """Drop and recreate the target database using psycopg2."""
        conn = psycopg2.connect(
            dbname="postgres",  # connect to control DB
            user=user,
            password=password,
            host=host,
            port=port
        )
        conn.set_session(autocommit=True)
        try:
            with conn.cursor() as cur:
                try:
                    # Use sql.Identifier for proper quoting, consistent with
                    # the rest of the file (the original interpolated dbname
                    # into the statement directly).
                    cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(dbname)))
                    print(f"✅ Dropped database '{dbname}'")
                except Exception as e:
                    print(f"⚠️ Failed to drop: {e}")
                try:
                    cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(dbname)))
                    print(f"✅ Created database '{dbname}'")
                except Exception as e:
                    print(f"❌ Failed to create database: {e}")
                    raise
        finally:
            conn.close()

    def restore_schema_with_psycopg2(
        self,
        dump_file,
        dbname,
        user,
        password,
        host="localhost",
        port=5432,
        recreate_db=False
    ):
        """Replay a SQL dump file statement-by-statement into `dbname`.

        Duplicate-table/object errors are skipped; other statement errors are
        logged and rolled back so the remaining statements still run.
        """
        if recreate_db:
            # BUG FIX: the original passed `self` explicitly to a bound method
            # (self.recreate_database_psycopg2(self, dbname, ...)), shifting
            # every argument by one position.
            self.recreate_database_psycopg2(dbname, user, password, host, port)

        with open(dump_file, "r") as f:
            raw_sql = f.read()

        statements = sqlparse.split(raw_sql)

        conn = psycopg2.connect(
            dbname=dbname,
            user=user,
            password=password,
            host=host,
            port=port
        )
        conn.set_session(autocommit=False)
        try:
            with conn.cursor() as cur:
                for stmt in statements:
                    stmt = stmt.strip()
                    if not stmt:
                        continue
                    try:
                        cur.execute(stmt)
                    except psycopg2.errors.DuplicateTable:
                        print("⚠️ Table already exists. Skipped.")
                        conn.rollback()
                    except psycopg2.errors.DuplicateObject:
                        print("⚠️ Object already exists. Skipped.")
                        conn.rollback()
                    except Exception as e:
                        print(f"❌ Error executing statement:\n{stmt[:200]}...\n{e}")
                        conn.rollback()
                    else:
                        # Commit per statement so one bad statement does not
                        # discard earlier successful ones.
                        conn.commit()
        finally:
            conn.close()
        print(f"✅ Schema restored to '{dbname}'")

    def generate_create_table_sql(self, table_name: str, columns: List[Dict]) -> str:
        """Build a CREATE TABLE DDL string from column metadata.

        Emits prerequisite statements first (CREATE EXTENSION "uuid-ossp"
        when a default uses uuid_generate_v4(); CREATE TYPE for enum columns),
        then the CREATE TABLE itself.
        """
        col_lines = []
        enum_defs = []
        requires_uuid_ossp = False

        for col in columns:
            data_type = col["data_type_s"]
            column_name = col["column_name"]

            # Handle user-defined enums: synthesize a per-table enum type name.
            if data_type == "USER-DEFINED" and col.get("enum_values"):
                enum_type_name = f"{table_name}_{column_name}_enum"
                data_type = enum_type_name
                enum_vals = ', '.join(f"'{v}'" for v in col["enum_values"])
                enum_defs.append(
                    f"DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = '{enum_type_name}') "
                    f"THEN CREATE TYPE {enum_type_name} AS ENUM ({enum_vals}); END IF; END $$;"
                )

            elif data_type == "ARRAY":
                data_type = "text[]"  # Customize as needed

            # Build column line
            line = f"    {column_name} {data_type}"

            # Add default value if present (skip nextval defaults on serial,
            # which PostgreSQL generates implicitly).
            default_value = col.get("default_value")
            if default_value and not (data_type.lower() == "serial" and "nextval" in default_value.lower()):
                if (
                    col["data_type_s"] == "USER-DEFINED"
                    and "::" in default_value
                    and col.get("enum_values")
                ):
                    # Replace generic cast with the correct enum type name
                    enum_type_name = f"{table_name}_{column_name}_enum"
                    default_value = default_value.split("::")[0] + f"::{enum_type_name}"
                line += f" DEFAULT {default_value}"
                if "uuid_generate_v4()" in default_value:
                    requires_uuid_ossp = True

            # Add NOT NULL constraint
            if col["is_nullable"] == "NO":
                line += " NOT NULL"

            col_lines.append(line)

        ddl_parts = []

        # Add CREATE EXTENSION if required
        if requires_uuid_ossp:
            ddl_parts.append('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

        # Add ENUM definitions if any
        ddl_parts.extend(enum_defs)

        # Final CREATE TABLE statement
        body = ",\n".join(col_lines)
        ddl_parts.append(f"CREATE TABLE {table_name} (\n{body}\n);")

        return "\n".join(ddl_parts)

    def gen_table_map(self, map: Dict) -> Dict:
        """Map each table name to its generated CREATE TABLE DDL.

        NOTE(review): the parameter shadows the builtin `map`; kept for
        backward compatibility with keyword callers.
        """
        return {
            table_name: self.generate_create_table_sql(table_name, columns)
            for table_name, columns in map.items()
        }

    def create_database_if_not_exists(self, existing_dbname, new_dbname, user, password, host='localhost', port=5432):
        """Create `new_dbname` owned by `user` unless it already exists.

        Raises:
            RuntimeError: wrapping any underlying failure.
        """
        method = "create_database_if_not_exists"
        try:
            conn = psycopg2.connect(
                dbname=existing_dbname,
                user=user,
                password=password,
                host=host,
                port=port
            )
            conn.set_session(autocommit=True)
            try:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (new_dbname,))
                    if cur.fetchone():
                        log(f"ℹ️ Database '{new_dbname}' already exists.")
                    else:
                        log(f"📦 Creating database '{new_dbname}' owned by '{user}'...")
                        cur.execute(
                            sql.SQL("CREATE DATABASE {} OWNER {}")
                            .format(sql.Identifier(new_dbname), sql.Identifier(user))
                        )
                        log(f"✅ Database '{new_dbname}' created.")
            finally:
                # Always close (the original leaked the connection on error).
                conn.close()
        except Exception as e:
            raise RuntimeError(f"[{method}] ❌ Failed to create database '{new_dbname}': {e}") from e

    def apply_create_table_ddls(self, table_ddl_map, existing_dbname, new_dbname, user, password, host='localhost', port=5432, drop_if_exists=False):
        """Execute each CREATE TABLE DDL in `new_dbname`, creating it first.

        Existing tables are skipped, or dropped and recreated when
        drop_if_exists is True. Each table is committed independently.
        """
        method = "apply_create_table_ddls"
        self.create_database_if_not_exists(existing_dbname, new_dbname, user, password, host, port)
        conn = psycopg2.connect(
            dbname=new_dbname,
            user=user,
            password=password,
            host=host,
            port=port
        )
        conn.set_session(autocommit=False)

        try:
            with conn.cursor() as cur:
                for table, ddl in table_ddl_map.items():
                    log(f"▶️ Creating table: {table}")
                    try:
                        cur.execute(ddl)
                    except errors.DuplicateTable:
                        log(f"⚠️ Table '{table}' already exists.")
                        conn.rollback()
                        if drop_if_exists:
                            try:
                                log(f"🔁 Dropping and recreating table '{table}'...")
                                cur.execute(sql.SQL("DROP TABLE IF EXISTS {} CASCADE").format(sql.Identifier(table)))
                                conn.commit()
                                cur.execute(ddl)
                                conn.commit()
                                log(f"✅ Recreated table: {table}")
                            except Exception as drop_err:
                                log(f"❌ Error recreating '{table}': {drop_err}")
                                conn.rollback()
                        else:
                            log(f"⏩ Skipping '{table}' (already exists)")
                    except Exception as e:
                        log(f"❌ Error creating '{table}': {e}")
                        conn.rollback()
                    else:
                        conn.commit()
                        log(f"✅ Created table: {table}")
        finally:
            conn.close()

    def clone_db(self):
        """Clone the current database's schema into '<DB_NAME>_test' and
        record the clone's name in the DB_NAME_TEST env var."""
        ddls = self.get_schemas_from_db()
        tables = self.gen_table_map(ddls)

        new_dbname = os.getenv("DB_NAME") + "_test"
        user = os.getenv("DB_USER")
        password = os.getenv("DB_PASSWORD")
        host = os.getenv("DB_HOST")
        port = os.getenv("DB_PORT")

        self.apply_create_table_ddls(tables, os.getenv("DB_NAME"), new_dbname, user, password, host=host, port=port, drop_if_exists=True)
        os.environ["DB_NAME_TEST"] = new_dbname

    def extract_permissions(self, result_data):
        """Extract the permissions dict from a parsed run_query_on_db result
        (a list of row-lists whose last row's last cell is the dict).

        Raises:
            ValueError: if any level of the expected nesting is missing/empty.
        """
        if not isinstance(result_data, list):
            raise ValueError("Expected result_data to be a list.")
        if not result_data:
            raise ValueError("result_data is an empty list.")

        last_element = result_data[-1]
        if not isinstance(last_element, list):
            raise ValueError("Expected last element of result_data to be a list.")
        if not last_element:
            raise ValueError("The last list inside result_data is empty.")

        last_dict = last_element[-1]
        if not isinstance(last_dict, dict):
            raise ValueError("Expected last element of the nested list to be a dict.")
        if not last_dict:
            raise ValueError("The final dictionary is empty.")

        return last_dict

    def check_permissions(self):
        """Check that DB_USER holds the privileges the adapter needs.

        Returns a status dict; on missing permissions it includes suggested
        GRANT/ALTER statements.
        """
        db_user_name = os.getenv("DB_USER")
        # BUG FIX: the original queried has_database_privilege/has_schema_privilege
        # for the hard-coded role 'postgres' while filtering pg_roles on
        # db_user_name — now both check the same role.
        query = f"""
            SELECT jsonb_build_object(
                'can_connect', has_database_privilege('{db_user_name}', current_database(), 'CONNECT'),
                'can_use_schema', has_schema_privilege('{db_user_name}', 'public', 'USAGE'),
                'can_create_db', rolcreatedb,
                'can_create_role', rolcreaterole,
                'is_superuser', rolsuper
            ) AS permissions
            FROM pg_roles
            WHERE rolname = '{db_user_name}';
        """
        result = self.run_query_on_db(query)
        result_data = json.loads(result)
        permissions = self.extract_permissions(result_data)

        suggested_queries_template = {
            'can_connect': "GRANT CONNECT ON DATABASE {database} TO {username};",
            'is_superuser': "ALTER ROLE {username} WITH SUPERUSER;",
            'can_create_db': "ALTER ROLE {username} WITH CREATEDB;",
            'can_use_schema': "GRANT USAGE ON SCHEMA public TO {username};",
            'can_create_role': "ALTER ROLE {username} WITH CREATEROLE;",
        }
        suggested_queries = {}
        for key, value in permissions.items():
            if value is False:
                suggested_queries[key] = suggested_queries_template[key].format(database=os.getenv("DB_NAME"), username=db_user_name)

        if all(permissions.values()):
            return {
                "status": "success",
                "permissions": permissions,
                "message": "All required permissions are already granted. "
            }
        else:
            suggested_queries_str = "\n".join([f"{key}: {value}" for key, value in suggested_queries.items()])
            return {
                "status": "error",
                "permissions": permissions,
                "error": f"You do not have the necessary permissions to perform this operation. You can use the following "
                         f"SQL statements to grant the required permissions:<pre style={{ whiteSpace: 'pre-wrap', "
                         f"background: '#f6f8fa', padding: 10, borderRadius: 4 }}>{suggested_queries_str}</pre>",
            }


def default_serializer(obj):
    """json.dumps fallback: ISO-format dates/datetimes, reject anything else.

    The original returned unknown objects unchanged, which makes json.dumps
    re-invoke the hook on the same object and fail obscurely; the documented
    contract for `default` is to raise TypeError.
    """
    if isinstance(obj, (datetime.date, datetime.datetime)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")