/ bin / drk / src / walletdb.rs
walletdb.rs
  1  /* This file is part of DarkFi (https://dark.fi)
  2   *
  3   * Copyright (C) 2020-2025 Dyne.org foundation
  4   *
  5   * This program is free software: you can redistribute it and/or modify
  6   * it under the terms of the GNU Affero General Public License as
  7   * published by the Free Software Foundation, either version 3 of the
  8   * License, or (at your option) any later version.
  9   *
 10   * This program is distributed in the hope that it will be useful,
 11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
 12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13   * GNU Affero General Public License for more details.
 14   *
 15   * You should have received a copy of the GNU Affero General Public License
 16   * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17   */
 18  
 19  use std::{
 20      path::PathBuf,
 21      sync::{Arc, Mutex},
 22  };
 23  
 24  use rusqlite::{
 25      types::{ToSql, Value},
 26      Connection,
 27  };
 28  use tracing::{debug, error};
 29  
 30  use crate::error::{WalletDbError, WalletDbResult};
 31  
 32  pub type WalletPtr = Arc<WalletDb>;
 33  
 34  /// Structure representing base wallet database operations.
 35  pub struct WalletDb {
 36      /// Connection to the SQLite database.
 37      pub conn: Mutex<Connection>,
 38  }
 39  
 40  impl WalletDb {
 41      /// Create a new wallet database handler. If `path` is `None`, create it in memory.
 42      pub fn new(path: Option<PathBuf>, password: Option<&str>) -> WalletDbResult<WalletPtr> {
 43          let Ok(conn) = (match path.clone() {
 44              Some(p) => Connection::open(p),
 45              None => Connection::open_in_memory(),
 46          }) else {
 47              return Err(WalletDbError::ConnectionFailed);
 48          };
 49  
 50          if let Some(password) = password {
 51              if let Err(e) = conn.pragma_update(None, "key", password) {
 52                  error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
 53                  return Err(WalletDbError::PragmaUpdateError);
 54              };
 55          }
 56          if let Err(e) = conn.pragma_update(None, "foreign_keys", "ON") {
 57              error!(target: "walletdb::new", "[WalletDb] Pragma update failed: {e}");
 58              return Err(WalletDbError::PragmaUpdateError);
 59          };
 60  
 61          debug!(target: "walletdb::new", "[WalletDb] Opened Sqlite connection at \"{path:?}\"");
 62          Ok(Arc::new(Self { conn: Mutex::new(conn) }))
 63      }
 64  
 65      /// This function executes a given SQL query that contains multiple SQL statements,
 66      /// that don't contain any parameters.
 67      pub fn exec_batch_sql(&self, query: &str) -> WalletDbResult<()> {
 68          debug!(target: "walletdb::exec_batch_sql", "[WalletDb] Executing batch SQL query:\n{query}");
 69          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
 70          if let Err(e) = conn.execute_batch(query) {
 71              error!(target: "walletdb::exec_batch_sql", "[WalletDb] Query failed: {e}");
 72              return Err(WalletDbError::QueryExecutionFailed)
 73          };
 74  
 75          Ok(())
 76      }
 77  
 78      /// This function executes a given SQL query, but isn't able to return anything.
 79      /// Therefore it's best to use it for initializing a table or similar things.
 80      pub fn exec_sql(&self, query: &str, params: &[&dyn ToSql]) -> WalletDbResult<()> {
 81          debug!(target: "walletdb::exec_sql", "[WalletDb] Executing SQL query:\n{query}");
 82          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
 83  
 84          // If no params are provided, execute directly
 85          if params.is_empty() {
 86              if let Err(e) = conn.execute(query, ()) {
 87                  error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
 88                  return Err(WalletDbError::QueryExecutionFailed)
 89              };
 90              return Ok(())
 91          }
 92  
 93          // First we prepare the query
 94          let Ok(mut stmt) = conn.prepare(query) else {
 95              return Err(WalletDbError::QueryPreparationFailed)
 96          };
 97  
 98          // Execute the query using provided params
 99          if let Err(e) = stmt.execute(params) {
100              error!(target: "walletdb::exec_sql", "[WalletDb] Query failed: {e}");
101              return Err(WalletDbError::QueryExecutionFailed)
102          };
103  
104          // Finalize query and drop connection lock
105          if let Err(e) = stmt.finalize() {
106              error!(target: "walletdb::exec_sql", "[WalletDb] Query finalization failed: {e}");
107              return Err(WalletDbError::QueryFinalizationFailed)
108          };
109          drop(conn);
110  
111          Ok(())
112      }
113  
114      /// Generate a new statement for provided query and bind the provided params,
115      /// returning the raw SQL query as a string.
116      pub fn create_prepared_statement(
117          &self,
118          query: &str,
119          params: &[&dyn ToSql],
120      ) -> WalletDbResult<String> {
121          debug!(target: "walletdb::create_prepared_statement", "[WalletDb] Preparing statement for SQL query:\n{query}");
122          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
123  
124          // First we prepare the query
125          let Ok(mut stmt) = conn.prepare(query) else {
126              return Err(WalletDbError::QueryPreparationFailed)
127          };
128  
129          // Bind all provided params
130          for (index, param) in params.iter().enumerate() {
131              if stmt.raw_bind_parameter(index + 1, param).is_err() {
132                  return Err(WalletDbError::QueryPreparationFailed)
133              };
134          }
135  
136          // Grab the raw SQL
137          let query = stmt.expanded_sql().unwrap();
138  
139          // Drop statement and the connection lock
140          drop(stmt);
141          drop(conn);
142  
143          Ok(query)
144      }
145  
146      /// Generate a `SELECT` query for provided table from selected column names and
147      /// provided `WHERE` clauses. Named parameters are supported in the `WHERE` clauses,
148      /// assuming they follow the normal formatting ":{column_name}".
149      fn generate_select_query(
150          &self,
151          table: &str,
152          col_names: &[&str],
153          params: &[(&str, &dyn ToSql)],
154      ) -> String {
155          let mut query = if col_names.is_empty() {
156              format!("SELECT * FROM {table}")
157          } else {
158              format!("SELECT {} FROM {table}", col_names.join(", "))
159          };
160          if params.is_empty() {
161              return query
162          }
163  
164          let mut where_str = Vec::with_capacity(params.len());
165          for (k, _) in params {
166              let col = &k[1..];
167              where_str.push(format!("{col} = {k}"));
168          }
169          query.push_str(&format!(" WHERE {}", where_str.join(" AND ")));
170  
171          query
172      }
173  
174      /// Query provided table from selected column names and provided `WHERE` clauses,
175      /// for a single row.
176      pub fn query_single(
177          &self,
178          table: &str,
179          col_names: &[&str],
180          params: &[(&str, &dyn ToSql)],
181      ) -> WalletDbResult<Vec<Value>> {
182          // Generate `SELECT` query
183          let query = self.generate_select_query(table, col_names, params);
184          debug!(target: "walletdb::query_single", "[WalletDb] Executing SQL query:\n{query}");
185  
186          // First we prepare the query
187          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
188  
189          let Ok(mut stmt) = conn.prepare(&query) else {
190              return Err(WalletDbError::QueryPreparationFailed)
191          };
192  
193          // Execute the query using provided params
194          let Ok(mut rows) = stmt.query(params) else {
195              return Err(WalletDbError::QueryExecutionFailed)
196          };
197  
198          // Check if row exists
199          let Ok(next) = rows.next() else { return Err(WalletDbError::QueryExecutionFailed) };
200          let row = match next {
201              Some(row_result) => row_result,
202              None => return Err(WalletDbError::RowNotFound),
203          };
204  
205          // Grab returned values
206          let mut result = vec![];
207          if col_names.is_empty() {
208              let mut idx = 0;
209              loop {
210                  let Ok(value) = row.get(idx) else { break };
211                  result.push(value);
212                  idx += 1;
213              }
214          } else {
215              for col in col_names {
216                  let Ok(value) = row.get(*col) else {
217                      return Err(WalletDbError::ParseColumnValueError)
218                  };
219                  result.push(value);
220              }
221          }
222  
223          Ok(result)
224      }
225  
226      /// Query provided table from selected column names and provided `WHERE` clauses,
227      /// for multiple rows.
228      pub fn query_multiple(
229          &self,
230          table: &str,
231          col_names: &[&str],
232          params: &[(&str, &dyn ToSql)],
233      ) -> WalletDbResult<Vec<Vec<Value>>> {
234          // Generate `SELECT` query
235          let query = self.generate_select_query(table, col_names, params);
236          debug!(target: "walletdb::query_multiple", "[WalletDb] Executing SQL query:\n{query}");
237  
238          // First we prepare the query
239          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
240          let Ok(mut stmt) = conn.prepare(&query) else {
241              return Err(WalletDbError::QueryPreparationFailed)
242          };
243  
244          // Execute the query using provided converted params
245          let Ok(mut rows) = stmt.query(params) else {
246              return Err(WalletDbError::QueryExecutionFailed)
247          };
248  
249          // Loop over returned rows and parse them
250          let mut result = vec![];
251          loop {
252              // Check if an error occured
253              let row = match rows.next() {
254                  Ok(r) => r,
255                  Err(_) => return Err(WalletDbError::QueryExecutionFailed),
256              };
257  
258              // Check if no row was returned
259              let row = match row {
260                  Some(r) => r,
261                  None => break,
262              };
263  
264              // Grab row returned values
265              let mut row_values = vec![];
266              if col_names.is_empty() {
267                  let mut idx = 0;
268                  loop {
269                      let Ok(value) = row.get(idx) else { break };
270                      row_values.push(value);
271                      idx += 1;
272                  }
273              } else {
274                  for col in col_names {
275                      let Ok(value) = row.get(*col) else {
276                          return Err(WalletDbError::ParseColumnValueError)
277                      };
278                      row_values.push(value);
279                  }
280              }
281              result.push(row_values);
282          }
283  
284          Ok(result)
285      }
286  
287      /// Query provided table using provided query for multiple rows.
288      pub fn query_custom(
289          &self,
290          query: &str,
291          params: &[&dyn ToSql],
292      ) -> WalletDbResult<Vec<Vec<Value>>> {
293          debug!(target: "walletdb::query_custom", "[WalletDb] Executing SQL query:\n{query}");
294  
295          // First we prepare the query
296          let Ok(conn) = self.conn.lock() else { return Err(WalletDbError::FailedToAquireLock) };
297          let Ok(mut stmt) = conn.prepare(query) else {
298              return Err(WalletDbError::QueryPreparationFailed)
299          };
300  
301          // Execute the query using provided converted params
302          let Ok(mut rows) = stmt.query(params) else {
303              return Err(WalletDbError::QueryExecutionFailed)
304          };
305  
306          // Loop over returned rows and parse them
307          let mut result = vec![];
308          loop {
309              // Check if an error occured
310              let row = match rows.next() {
311                  Ok(r) => r,
312                  Err(_) => return Err(WalletDbError::QueryExecutionFailed),
313              };
314  
315              // Check if no row was returned
316              let row = match row {
317                  Some(r) => r,
318                  None => break,
319              };
320  
321              // Grab row returned values
322              let mut row_values = vec![];
323              let mut idx = 0;
324              loop {
325                  let Ok(value) = row.get(idx) else { break };
326                  row_values.push(value);
327                  idx += 1;
328              }
329              result.push(row_values);
330          }
331  
332          Ok(result)
333      }
334  }
335  
/// Variant of `rusqlite::named_params!` that takes arbitrary expressions (not just
/// literals) as parameter names, and prepends the `:` named-parameter prefix to each.
///
/// Expands to a `&[(&str, &dyn rusqlite::types::ToSql)]` slice, e.g.
/// `convert_named_params! {("why", why), ("are", are)}` binds `:why` and `:are`.
/// The prefixed names are temporary `String`s built with `format!`. The expansion is
/// therefore intended to be passed directly as a call argument. Binding it with `let`
/// is likely to fail to borrow-check.
#[macro_export]
macro_rules! convert_named_params {
    () => {
        &[] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
    ($(($name:expr, $value:expr)),+ $(,)?) => {
        &[
            $((
                format!(":{}", $name).as_str(),
                &$value as &dyn rusqlite::types::ToSql,
            )),+
        ] as &[(&str, &dyn rusqlite::types::ToSql)]
    };
}
347  
#[cfg(test)]
mod tests {
    use rusqlite::types::Value;

    use crate::{error::WalletDbError, walletdb::WalletDb};

    #[test]
    fn test_mem_wallet() {
        let wallet = WalletDb::new(None, Some("foobar")).unwrap();
        wallet
            .exec_batch_sql(
                "CREATE TABLE mista ( numba INTEGER ); INSERT INTO mista ( numba ) VALUES ( 42 );",
            )
            .unwrap();

        // Compare whole `Value`s rather than extracting with a sentinel, so a
        // wrong type or value shows up directly in the assertion failure.
        let ret = wallet.query_single("mista", &["numba"], &[]).unwrap();
        assert_eq!(ret, vec![Value::Integer(42)]);

        let ret = wallet.query_custom("SELECT numba FROM mista;", &[]).unwrap();
        assert_eq!(ret, vec![vec![Value::Integer(42)]]);
    }

    #[test]
    fn test_query_single() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();

        let ret = wallet.query_single("mista", &["why", "are", "you", "gae"], &[]).unwrap();
        assert_eq!(ret.len(), 4);
        assert_eq!(ret[0], Value::Integer(why));
        assert_eq!(ret[1], Value::Text(are.clone()));
        assert_eq!(ret[2], Value::Integer(you));
        assert_eq!(ret[3], Value::Blob(gae.clone()));
        let ret = wallet.query_custom("SELECT why, are, you, gae FROM mista;", &[]).unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0].len(), 4);
        assert_eq!(ret[0][0], Value::Integer(why));
        assert_eq!(ret[0][1], Value::Text(are.clone()));
        assert_eq!(ret[0][2], Value::Integer(you));
        assert_eq!(ret[0][3], Value::Blob(gae.clone()));

        let ret = wallet
            .query_single(
                "mista",
                &["gae"],
                rusqlite::named_params! {":why": why, ":are": are, ":you": you},
            )
            .unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0], Value::Blob(gae.clone()));
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret.len(), 1);
        assert_eq!(ret[0].len(), 1);
        assert_eq!(ret[0][0], Value::Blob(gae));
    }

    #[test]
    fn test_query_multi() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet
            .exec_batch_sql("CREATE TABLE mista ( why INTEGER, are TEXT, you INTEGER, gae BLOB );")
            .unwrap();

        let why = 42;
        let are = "are".to_string();
        let you = 69;
        let gae = vec![42u8; 32];

        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();
        wallet
            .exec_sql(
                "INSERT INTO mista ( why, are, you, gae ) VALUES (?1, ?2, ?3, ?4);",
                rusqlite::params![why, are, you, gae],
            )
            .unwrap();

        let ret = wallet.query_multiple("mista", &[], &[]).unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 4);
            assert_eq!(row[0], Value::Integer(why));
            assert_eq!(row[1], Value::Text(are.clone()));
            assert_eq!(row[2], Value::Integer(you));
            assert_eq!(row[3], Value::Blob(gae.clone()));
        }
        let ret = wallet.query_custom("SELECT * FROM mista;", &[]).unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 4);
            assert_eq!(row[0], Value::Integer(why));
            assert_eq!(row[1], Value::Text(are.clone()));
            assert_eq!(row[2], Value::Integer(you));
            assert_eq!(row[3], Value::Blob(gae.clone()));
        }

        let ret = wallet
            .query_multiple(
                "mista",
                &["gae"],
                convert_named_params! {("why", why), ("are", are), ("you", you)},
            )
            .unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 1);
            assert_eq!(row[0], Value::Blob(gae.clone()));
        }
        let ret = wallet
            .query_custom(
                "SELECT gae FROM mista WHERE why = ?1 AND are = ?2 AND you = ?3;",
                rusqlite::params![why, are, you],
            )
            .unwrap();
        assert_eq!(ret.len(), 2);
        for row in ret {
            assert_eq!(row.len(), 1);
            assert_eq!(row[0], Value::Blob(gae.clone()));
        }
    }

    #[test]
    fn test_query_no_rows() {
        let wallet = WalletDb::new(None, None).unwrap();
        wallet.exec_batch_sql("CREATE TABLE mista ( numba INTEGER );").unwrap();

        // A single-row query on an empty table reports `RowNotFound`...
        let ret = wallet.query_single("mista", &["numba"], &[]);
        assert!(matches!(ret, Err(WalletDbError::RowNotFound)));

        // ...while the multi-row queries return empty results.
        let ret = wallet.query_multiple("mista", &[], &[]).unwrap();
        assert!(ret.is_empty());
        let ret = wallet.query_custom("SELECT numba FROM mista;", &[]).unwrap();
        assert!(ret.is_empty());
    }

    #[test]
    fn test_create_prepared_statement() {
        let wallet = WalletDb::new(None, None).unwrap();
        let sql = wallet
            .create_prepared_statement("SELECT ?1, ?2", rusqlite::params![42, "mista"])
            .unwrap();
        assert_eq!(sql, "SELECT 42, 'mista'");
    }
}