// src/praisonai-ts/src/integrations/postgres.ts
/**
 * Natural Language Postgres
 *
 * Provides natural-language-to-SQL (NL->SQL) capabilities for querying
 * PostgreSQL databases.
 */
  6  
  7  export interface PostgresConfig {
  8    /** Database connection URL */
  9    connectionUrl: string;
 10    /** Schema to use (default: public) */
 11    schema?: string;
 12    /** Read-only mode (default: true for safety) */
 13    readOnly?: boolean;
 14    /** Allowed tables (whitelist) */
 15    allowedTables?: string[];
 16    /** Blocked tables (blacklist) */
 17    blockedTables?: string[];
 18    /** Maximum rows to return (default: 100) */
 19    maxRows?: number;
 20    /** Query timeout in ms (default: 30000) */
 21    timeout?: number;
 22  }
 23  
 24  export interface TableSchema {
 25    name: string;
 26    columns: ColumnSchema[];
 27    primaryKey?: string[];
 28    foreignKeys?: ForeignKey[];
 29  }
 30  
 31  export interface ColumnSchema {
 32    name: string;
 33    type: string;
 34    nullable: boolean;
 35    defaultValue?: string;
 36    description?: string;
 37  }
 38  
 39  export interface ForeignKey {
 40    columns: string[];
 41    referencedTable: string;
 42    referencedColumns: string[];
 43  }
 44  
 45  export interface QueryResult {
 46    rows: any[];
 47    rowCount: number;
 48    fields: { name: string; type: string }[];
 49    executionTime: number;
 50  }
 51  
 52  export interface NLQueryResult {
 53    /** Natural language query */
 54    query: string;
 55    /** Generated SQL */
 56    sql: string;
 57    /** Query result */
 58    result: QueryResult;
 59    /** Explanation of the query */
 60    explanation?: string;
 61  }
 62  
 63  /**
 64   * Create a Natural Language Postgres client.
 65   * 
 66   * @example Basic usage
 67   * ```typescript
 68   * import { createNLPostgres } from 'praisonai/integrations/postgres';
 69   * 
 70   * const db = await createNLPostgres({
 71   *   connectionUrl: process.env.DATABASE_URL!,
 72   *   readOnly: true
 73   * });
 74   * 
 75   * // Query with natural language
 76   * const result = await db.query('Show me all users who signed up last month');
 77   * console.log(result.rows);
 78   * ```
 79   * 
 80   * @example With schema introspection
 81   * ```typescript
 82   * const db = await createNLPostgres({ connectionUrl: '...' });
 83   * 
 84   * // Get schema information
 85   * const schema = await db.getSchema();
 86   * console.log('Tables:', schema.map(t => t.name));
 87   * 
 88   * // Query with context
 89   * const result = await db.query('How many orders were placed today?');
 90   * ```
 91   */
 92  export async function createNLPostgres(config: PostgresConfig): Promise<NLPostgresClient> {
 93    const client = new NLPostgresClient(config);
 94    await client.connect();
 95    return client;
 96  }
 97  
 98  export class NLPostgresClient {
 99    private config: PostgresConfig;
100    private pool: any = null;
101    private schema: TableSchema[] = [];
102    private connected = false;
103  
104    constructor(config: PostgresConfig) {
105      this.config = {
106        schema: 'public',
107        readOnly: true,
108        maxRows: 100,
109        timeout: 30000,
110        ...config,
111      };
112    }
113  
114    /**
115     * Connect to the database.
116     */
117    async connect(): Promise<void> {
118      if (this.connected) return;
119  
120      try {
121        // @ts-ignore - Optional dependency
122        const { Pool } = await import('pg');
123        this.pool = new Pool({
124          connectionString: this.config.connectionUrl,
125          statement_timeout: this.config.timeout,
126        });
127  
128        // Test connection
129        await this.pool.query('SELECT 1');
130        this.connected = true;
131  
132        // Load schema
133        await this.loadSchema();
134      } catch (error: any) {
135        throw new Error(
136          `Failed to connect to PostgreSQL: ${error.message}. ` +
137          'Install with: npm install pg'
138        );
139      }
140    }
141  
142    /**
143     * Disconnect from the database.
144     */
145    async disconnect(): Promise<void> {
146      if (this.pool) {
147        await this.pool.end();
148        this.connected = false;
149      }
150    }
151  
152    /**
153     * Load database schema.
154     */
155    private async loadSchema(): Promise<void> {
156      const schemaQuery = `
157        SELECT 
158          t.table_name,
159          c.column_name,
160          c.data_type,
161          c.is_nullable,
162          c.column_default,
163          pgd.description
164        FROM information_schema.tables t
165        JOIN information_schema.columns c 
166          ON t.table_name = c.table_name 
167          AND t.table_schema = c.table_schema
168        LEFT JOIN pg_catalog.pg_statio_all_tables st
169          ON st.relname = t.table_name
170        LEFT JOIN pg_catalog.pg_description pgd
171          ON pgd.objoid = st.relid
172          AND pgd.objsubid = c.ordinal_position
173        WHERE t.table_schema = $1
174          AND t.table_type = 'BASE TABLE'
175        ORDER BY t.table_name, c.ordinal_position
176      `;
177  
178      const result = await this.pool.query(schemaQuery, [this.config.schema]);
179      
180      const tableMap = new Map<string, TableSchema>();
181      
182      for (const row of result.rows) {
183        // Check allowed/blocked tables
184        if (this.config.allowedTables && !this.config.allowedTables.includes(row.table_name)) {
185          continue;
186        }
187        if (this.config.blockedTables && this.config.blockedTables.includes(row.table_name)) {
188          continue;
189        }
190  
191        if (!tableMap.has(row.table_name)) {
192          tableMap.set(row.table_name, {
193            name: row.table_name,
194            columns: [],
195          });
196        }
197  
198        const table = tableMap.get(row.table_name)!;
199        table.columns.push({
200          name: row.column_name,
201          type: row.data_type,
202          nullable: row.is_nullable === 'YES',
203          defaultValue: row.column_default,
204          description: row.description,
205        });
206      }
207  
208      this.schema = Array.from(tableMap.values());
209    }
210  
211    /**
212     * Get the database schema.
213     */
214    getSchema(): TableSchema[] {
215      return this.schema;
216    }
217  
218    /**
219     * Get schema as a string for LLM context.
220     */
221    getSchemaContext(): string {
222      let context = 'Database Schema:\n\n';
223      
224      for (const table of this.schema) {
225        context += `Table: ${table.name}\n`;
226        context += 'Columns:\n';
227        for (const col of table.columns) {
228          context += `  - ${col.name} (${col.type}${col.nullable ? ', nullable' : ''})`;
229          if (col.description) {
230            context += ` - ${col.description}`;
231          }
232          context += '\n';
233        }
234        context += '\n';
235      }
236      
237      return context;
238    }
239  
240    /**
241     * Execute a raw SQL query.
242     */
243    async executeSQL(sql: string): Promise<QueryResult> {
244      if (!this.connected) {
245        throw new Error('Not connected to database');
246      }
247  
248      // Safety check for read-only mode
249      if (this.config.readOnly) {
250        const normalizedSQL = sql.trim().toLowerCase();
251        const writeOperations = ['insert', 'update', 'delete', 'drop', 'alter', 'create', 'truncate'];
252        
253        for (const op of writeOperations) {
254          if (normalizedSQL.startsWith(op)) {
255            throw new Error(`Write operation '${op}' not allowed in read-only mode`);
256          }
257        }
258      }
259  
260      // Add LIMIT if not present
261      const hasLimit = /\blimit\s+\d+/i.test(sql);
262      const finalSQL = hasLimit ? sql : `${sql} LIMIT ${this.config.maxRows}`;
263  
264      const startTime = Date.now();
265      const result = await this.pool.query(finalSQL);
266      const executionTime = Date.now() - startTime;
267  
268      return {
269        rows: result.rows,
270        rowCount: result.rowCount,
271        fields: result.fields?.map((f: any) => ({ name: f.name, type: f.dataTypeID?.toString() })) || [],
272        executionTime,
273      };
274    }
275  
276    /**
277     * Query the database using natural language.
278     */
279    async query(naturalLanguageQuery: string, options?: { model?: string }): Promise<NLQueryResult> {
280      const model = options?.model || 'gpt-4o-mini';
281      
282      // Generate SQL from natural language
283      const { generateText } = await import('../ai/generate-text');
284      
285      const schemaContext = this.getSchemaContext();
286      const prompt = `You are a SQL expert. Convert the following natural language query to PostgreSQL SQL.
287  
288  ${schemaContext}
289  
290  Rules:
291  1. Only generate SELECT queries (read-only)
292  2. Use proper PostgreSQL syntax
293  3. Include appropriate JOINs when needed
294  4. Add reasonable LIMIT if not specified
295  5. Return ONLY the SQL query, no explanations
296  
297  Natural language query: ${naturalLanguageQuery}
298  
299  SQL:`;
300  
301      const result = await generateText({
302        model,
303        prompt,
304        temperature: 0,
305      });
306  
307      // Extract SQL from response
308      let sql = result.text.trim();
309      
310      // Remove markdown code blocks if present
311      if (sql.startsWith('```')) {
312        sql = sql.replace(/```sql?\n?/g, '').replace(/```/g, '').trim();
313      }
314  
315      // Execute the SQL
316      const queryResult = await this.executeSQL(sql);
317  
318      return {
319        query: naturalLanguageQuery,
320        sql,
321        result: queryResult,
322      };
323    }
324  
325    /**
326     * Chat with the database (conversational interface).
327     */
328    async chat(message: string, options?: { model?: string; history?: Array<{ role: string; content: string }> }): Promise<string> {
329      const model = options?.model || 'gpt-4o-mini';
330      const history = options?.history || [];
331      
332      const { generateText } = await import('../ai/generate-text');
333      
334      const schemaContext = this.getSchemaContext();
335      const systemPrompt = `You are a helpful database assistant. You can query a PostgreSQL database to answer questions.
336  
337  ${schemaContext}
338  
339  When the user asks a question that requires database data:
340  1. Generate a SQL query to get the data
341  2. Execute it using the query tool
342  3. Summarize the results in natural language
343  
344  Always be helpful and explain your findings clearly.`;
345  
346      const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [
347        { role: 'system', content: systemPrompt },
348        ...history.map(h => ({ role: h.role as 'user' | 'assistant', content: h.content })),
349        { role: 'user', content: message },
350      ];
351  
352      // Define the query tool
353      const tools = {
354        query_database: {
355          description: 'Execute a SQL query on the PostgreSQL database',
356          parameters: {
357            type: 'object',
358            properties: {
359              sql: {
360                type: 'string',
361                description: 'The SQL query to execute (SELECT only)',
362              },
363            },
364            required: ['sql'],
365          },
366          execute: async ({ sql }: { sql: string }) => {
367            try {
368              const result = await this.executeSQL(sql);
369              return JSON.stringify({
370                rowCount: result.rowCount,
371                rows: result.rows.slice(0, 10), // Limit for context
372                hasMore: result.rowCount > 10,
373              });
374            } catch (error: any) {
375              return JSON.stringify({ error: error.message });
376            }
377          },
378        },
379      };
380  
381      const result = await generateText({
382        model,
383        messages,
384        tools,
385        maxSteps: 3,
386      });
387  
388      return result.text;
389    }
390  
391    /**
392     * Inspect the database structure.
393     */
394    async inspect(): Promise<{
395      tables: number;
396      schema: TableSchema[];
397      sampleData: Record<string, any[]>;
398    }> {
399      const sampleData: Record<string, any[]> = {};
400      
401      for (const table of this.schema) {
402        try {
403          const result = await this.executeSQL(`SELECT * FROM ${table.name} LIMIT 3`);
404          sampleData[table.name] = result.rows;
405        } catch {
406          sampleData[table.name] = [];
407        }
408      }
409  
410      return {
411        tables: this.schema.length,
412        schema: this.schema,
413        sampleData,
414      };
415    }
416  }
417  
/**
 * Create a tool for querying Postgres with natural language.
 *
 * The returned object has `name`, `description` (listing the tables known
 * when the tool was created), JSON-schema `parameters`, an `execute` function
 * and a `close` function that disconnects the underlying client.
 *
 * `execute` resolves to a JSON string holding either `sql`, `rowCount` and
 * `rows`, or `error` with the failure message when the query fails.
 *
 * @param config - Connection and safety options for the underlying client.
 * @returns The tool definition.
 * @throws Whatever `createNLPostgres` throws when connecting fails.
 *
 * @example Use with an agent
 * ```typescript
 * import { Agent } from 'praisonai';
 * import { createPostgresTool } from 'praisonai/integrations/postgres';
 *
 * const dbTool = await createPostgresTool({
 *   connectionUrl: process.env.DATABASE_URL!
 * });
 *
 * const agent = new Agent({
 *   instructions: 'You can query the database',
 *   tools: [dbTool]
 * });
 *
 * // When the tool is no longer needed:
 * await dbTool.close();
 * ```
 */
export async function createPostgresTool(config: PostgresConfig): Promise<any> {
  const client = await createNLPostgres(config);
  const tableNames = client.getSchema().map(t => t.name).join(', ');

  return {
    name: 'query_database',
    description: `Query the PostgreSQL database using natural language. Available tables: ${tableNames}`,
    parameters: {
      type: 'object',
      properties: {
        query: {
          type: 'string',
          description: 'Natural language query describing what data you want',
        },
      },
      required: ['query'],
    },
    execute: async ({ query }: { query: string }) => {
      try {
        const answer = await client.query(query);
        return JSON.stringify({
          sql: answer.sql,
          rowCount: answer.result.rowCount,
          rows: answer.result.rows,
        });
      } catch (error: unknown) {
        // Report the failure as a tool result, as the client's chat tool
        // does, so the calling model can see it and adjust its request.
        return JSON.stringify({ error: error instanceof Error ? error.message : String(error) });
      }
    },
    // Lets the owner release the connection pool held by this tool.
    close: () => client.disconnect(),
  };
}