/ radicle / src / node / routing.rs
routing.rs
  1  use std::collections::HashSet;
  2  use std::path::Path;
  3  use std::{fmt, time};
  4  
  5  use sqlite as sql;
  6  use thiserror::Error;
  7  
  8  use crate::{
  9      prelude::Timestamp,
 10      prelude::{Id, NodeId},
 11      sql::transaction,
 12  };
 13  
 14  /// How long to wait for the database lock to be released before failing a read.
 15  const DB_READ_TIMEOUT: time::Duration = time::Duration::from_secs(3);
 16  /// How long to wait for the database lock to be released before failing a write.
 17  const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_secs(6);
 18  
 19  /// Result of inserting into the routing table.
 20  #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 21  pub enum InsertResult {
 22      /// Nothing was updated.
 23      NotUpdated,
 24      /// The entry's timestamp was updated.
 25      TimeUpdated,
 26      /// A new entry was inserted.
 27      SeedAdded,
 28  }
 29  
 30  /// An error occuring in peer-to-peer networking code.
 31  #[derive(Error, Debug)]
 32  pub enum Error {
 33      /// An Internal error.
 34      #[error("internal error: {0}")]
 35      Internal(#[from] sql::Error),
 36      /// Internal unit overflow.
 37      #[error("the unit overflowed")]
 38      UnitOverflow,
 39  }
 40  
 41  /// Persistent file storage for a routing table.
 42  pub struct Table {
 43      db: sql::Connection,
 44  }
 45  
 46  impl fmt::Debug for Table {
 47      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 48          write!(f, "Table(..)")
 49      }
 50  }
 51  
 52  impl Table {
 53      const SCHEMA: &str = include_str!("routing/schema.sql");
 54  
 55      /// Open a routing file store at the given path. Creates a new empty store
 56      /// if an existing store isn't found.
 57      pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
 58          let mut db = sql::Connection::open(path)?;
 59          db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
 60          db.execute(Self::SCHEMA)?;
 61  
 62          Ok(Self { db })
 63      }
 64  
 65      /// Same as [`Self::open`], but in read-only mode. This is useful to have multiple
 66      /// open databases, as no locking is required.
 67      pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
 68          let mut db =
 69              sql::Connection::open_with_flags(path, sqlite::OpenFlags::new().set_read_only())?;
 70          db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
 71          db.execute(Self::SCHEMA)?;
 72  
 73          Ok(Self { db })
 74      }
 75  
 76      /// Create a new in-memory routing table.
 77      pub fn memory() -> Result<Self, Error> {
 78          let db = sql::Connection::open(":memory:")?;
 79          db.execute(Self::SCHEMA)?;
 80  
 81          Ok(Self { db })
 82      }
 83  }
 84  
 85  /// Backing store for a routing table.
 86  pub trait Store {
 87      /// Get the nodes seeding the given id.
 88      fn get(&self, id: &Id) -> Result<HashSet<NodeId>, Error>;
 89      /// Get the resources seeded by the given node.
 90      fn get_resources(&self, node_id: &NodeId) -> Result<HashSet<Id>, Error>;
 91      /// Get a specific entry.
 92      fn entry(&self, id: &Id, node: &NodeId) -> Result<Option<Timestamp>, Error>;
 93      /// Checks if any entries are available.
 94      fn is_empty(&self) -> Result<bool, Error> {
 95          Ok(self.len()? == 0)
 96      }
 97      /// Add a new node seeding the given id.
 98      fn insert<'a>(
 99          &mut self,
100          ids: impl IntoIterator<Item = &'a Id>,
101          node: NodeId,
102          time: Timestamp,
103      ) -> Result<Vec<(Id, InsertResult)>, Error>;
104      /// Remove a node for the given id.
105      fn remove(&mut self, id: &Id, node: &NodeId) -> Result<bool, Error>;
106      /// Iterate over all entries in the routing table.
107      fn entries(&self) -> Result<Box<dyn Iterator<Item = (Id, NodeId)>>, Error>;
108      /// Get the total number of routing entries.
109      fn len(&self) -> Result<usize, Error>;
110      /// Prune entries older than the given timestamp.
111      fn prune(&mut self, oldest: Timestamp, limit: Option<usize>) -> Result<usize, Error>;
112      /// Count the number of routes for a specific repo RID.
113      fn count(&self, id: &Id) -> Result<usize, Error>;
114  }
115  
116  impl Store for Table {
117      fn get(&self, id: &Id) -> Result<HashSet<NodeId>, Error> {
118          let mut stmt = self
119              .db
120              .prepare("SELECT (node) FROM routing WHERE resource = ?")?;
121          stmt.bind((1, id))?;
122  
123          let mut nodes = HashSet::new();
124          for row in stmt.into_iter() {
125              nodes.insert(row?.read::<NodeId, _>("node"));
126          }
127          Ok(nodes)
128      }
129  
130      fn get_resources(&self, node: &NodeId) -> Result<HashSet<Id>, Error> {
131          let mut stmt = self
132              .db
133              .prepare("SELECT resource FROM routing WHERE node = ?")?;
134          stmt.bind((1, node))?;
135  
136          let mut resources = HashSet::new();
137          for row in stmt.into_iter() {
138              resources.insert(row?.read::<Id, _>("resource"));
139          }
140          Ok(resources)
141      }
142  
143      fn entry(&self, id: &Id, node: &NodeId) -> Result<Option<Timestamp>, Error> {
144          let mut stmt = self
145              .db
146              .prepare("SELECT (time) FROM routing WHERE resource = ? AND node = ?")?;
147  
148          stmt.bind((1, id))?;
149          stmt.bind((2, node))?;
150  
151          if let Some(Ok(row)) = stmt.into_iter().next() {
152              return Ok(Some(row.read::<i64, _>("time") as Timestamp));
153          }
154          Ok(None)
155      }
156  
157      fn insert<'a>(
158          &mut self,
159          ids: impl IntoIterator<Item = &'a Id>,
160          node: NodeId,
161          time: Timestamp,
162      ) -> Result<Vec<(Id, InsertResult)>, Error> {
163          let time: i64 = time.try_into().map_err(|_| Error::UnitOverflow)?;
164          let mut results = Vec::new();
165  
166          transaction(&self.db, |db| {
167              for id in ids.into_iter() {
168                  let mut stmt =
169                      db.prepare("SELECT (time) FROM routing WHERE resource = ? AND node = ?")?;
170  
171                  stmt.bind((1, id))?;
172                  stmt.bind((2, &node))?;
173  
174                  let existed = stmt.into_iter().next().is_some();
175                  let mut stmt = db.prepare(
176                      "INSERT INTO routing (resource, node, time)
177                       VALUES (?, ?, ?)
178                       ON CONFLICT DO UPDATE
179                       SET time = ?3
180                       WHERE time < ?3",
181                  )?;
182  
183                  stmt.bind((1, id))?;
184                  stmt.bind((2, &node))?;
185                  stmt.bind((3, time))?;
186                  stmt.next()?;
187  
188                  let result = match (self.db.change_count() > 0, existed) {
189                      (true, true) => InsertResult::TimeUpdated,
190                      (true, false) => InsertResult::SeedAdded,
191                      (false, _) => InsertResult::NotUpdated,
192                  };
193                  results.push((*id, result));
194              }
195              Ok(results)
196          })
197          .map_err(Error::from)
198      }
199  
200      fn entries(&self) -> Result<Box<dyn Iterator<Item = (Id, NodeId)>>, Error> {
201          let mut stmt = self
202              .db
203              .prepare("SELECT resource, node FROM routing ORDER BY resource")?
204              .into_iter();
205          let mut entries = Vec::new();
206  
207          while let Some(Ok(row)) = stmt.next() {
208              let id = row.read("resource");
209              let node = row.read("node");
210  
211              entries.push((id, node));
212          }
213          Ok(Box::new(entries.into_iter()))
214      }
215  
216      fn remove(&mut self, id: &Id, node: &NodeId) -> Result<bool, Error> {
217          let mut stmt = self
218              .db
219              .prepare("DELETE FROM routing WHERE resource = ? AND node = ?")?;
220  
221          stmt.bind((1, id))?;
222          stmt.bind((2, node))?;
223          stmt.next()?;
224  
225          Ok(self.db.change_count() > 0)
226      }
227  
228      fn len(&self) -> Result<usize, Error> {
229          let stmt = self.db.prepare("SELECT COUNT(1) FROM routing")?;
230          let count: i64 = stmt
231              .into_iter()
232              .next()
233              .expect("COUNT will always return a single row")?
234              .read(0);
235          let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
236          Ok(count)
237      }
238  
239      fn prune(&mut self, oldest: Timestamp, limit: Option<usize>) -> Result<usize, Error> {
240          let oldest: i64 = oldest.try_into().map_err(|_| Error::UnitOverflow)?;
241          let limit: i64 = limit
242              .unwrap_or(i64::MAX as usize)
243              .try_into()
244              .map_err(|_| Error::UnitOverflow)?;
245  
246          let mut stmt = self.db.prepare(
247              "DELETE FROM routing WHERE rowid IN
248              (SELECT rowid FROM routing WHERE time < ? LIMIT ?)",
249          )?;
250          stmt.bind((1, oldest))?;
251          stmt.bind((2, limit))?;
252          stmt.next()?;
253  
254          Ok(self.db.change_count())
255      }
256  
257      fn count(&self, id: &Id) -> Result<usize, Error> {
258          let mut stmt = self
259              .db
260              .prepare("SELECT COUNT(*) FROM routing WHERE resource = ?")?;
261  
262          stmt.bind((1, id))?;
263  
264          let count: i64 = stmt
265              .into_iter()
266              .next()
267              .expect("COUNT will always return a single row")?
268              .read(0);
269  
270          let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
271  
272          Ok(count)
273      }
274  }
275  
#[cfg(test)]
mod test {
    use localtime::LocalTime;

    use super::*;
    use crate::test::arbitrary;

    #[test]
    fn test_insert_and_get() {
        let ids = arbitrary::set::<Id>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
            assert_eq!(
                db.insert(&ids, *node, 0).unwrap(),
                ids.iter()
                    .map(|id| (*id, InsertResult::SeedAdded))
                    .collect::<Vec<_>>()
            );
        }

        for id in &ids {
            let seeds = db.get(id).unwrap();
            for node in &nodes {
                assert!(seeds.contains(node));
            }
        }
    }

    #[test]
    fn test_insert_and_get_resources() {
        let ids = arbitrary::set::<Id>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
            db.insert(&ids, *node, 0).unwrap();
        }

        for node in &nodes {
            let projects = db.get_resources(node).unwrap();
            for id in &ids {
                assert!(projects.contains(id));
            }
        }
    }

    #[test]
    fn test_entries() {
        let ids = arbitrary::set::<Id>(6..9);
        let nodes = arbitrary::set::<NodeId>(6..9);
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
            assert!(db
                .insert(&ids, *node, 0)
                .unwrap()
                .iter()
                .all(|(_, r)| *r == InsertResult::SeedAdded));
        }

        let results = db.entries().unwrap().collect::<Vec<_>>();
        assert_eq!(results.len(), ids.len() * nodes.len());

        let mut results_ids = results.iter().map(|(id, _)| *id).collect::<Vec<_>>();
        results_ids.dedup();

        assert_eq!(results_ids.len(), ids.len(), "Entries are grouped by id");
    }

    #[test]
    fn test_insert_and_remove() {
        let ids = arbitrary::set::<Id>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
            db.insert(&ids, *node, 0).unwrap();
        }
        for id in &ids {
            for node in &nodes {
                assert!(db.remove(id, node).unwrap());
            }
        }
        for id in &ids {
            assert!(db.get(id).unwrap().is_empty());
        }
    }

    #[test]
    fn test_insert_duplicate() {
        let id = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

        assert_eq!(
            db.insert([&id], node, 0).unwrap(),
            vec![(id, InsertResult::SeedAdded)]
        );
        assert_eq!(
            db.insert([&id], node, 0).unwrap(),
            vec![(id, InsertResult::NotUpdated)]
        );
        assert_eq!(
            db.insert([&id], node, 0).unwrap(),
            vec![(id, InsertResult::NotUpdated)]
        );
    }

    #[test]
    fn test_insert_existing_updated_time() {
        let id = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

        assert_eq!(
            db.insert([&id], node, 0).unwrap(),
            vec![(id, InsertResult::SeedAdded)]
        );
        assert_eq!(
            db.insert([&id], node, 1).unwrap(),
            vec![(id, InsertResult::TimeUpdated)]
        );
        assert_eq!(db.entry(&id, &node).unwrap(), Some(1));
    }

    #[test]
    fn test_update_existing_multi() {
        let id1 = arbitrary::gen::<Id>(1);
        let id2 = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

        assert_eq!(
            db.insert([&id1], node, 0).unwrap(),
            vec![(id1, InsertResult::SeedAdded)]
        );
        assert_eq!(
            db.insert([&id1, &id2], node, 0).unwrap(),
            vec![
                (id1, InsertResult::NotUpdated),
                (id2, InsertResult::SeedAdded)
            ]
        );
        assert_eq!(
            db.insert([&id1, &id2], node, 1).unwrap(),
            vec![
                (id1, InsertResult::TimeUpdated),
                (id2, InsertResult::TimeUpdated)
            ]
        );
    }

    #[test]
    fn test_remove_redundant() {
        let id = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

        assert_eq!(
            db.insert([&id], node, 0).unwrap(),
            vec![(id, InsertResult::SeedAdded)]
        );
        assert!(db.remove(&id, &node).unwrap());
        assert!(!db.remove(&id, &node).unwrap());
    }

    #[test]
    fn test_len() {
        let mut db = Table::open(":memory:").unwrap();
        let ids = arbitrary::vec::<Id>(10);
        let node = arbitrary::gen(1);

        db.insert(&ids, node, LocalTime::now().as_millis()).unwrap();

        assert_eq!(10, db.len().unwrap(), "correct number of rows in table");
    }

    #[test]
    fn test_entry_missing() {
        let id = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let db = Table::open(":memory:").unwrap();

        assert_eq!(db.entry(&id, &node).unwrap(), None);
        assert!(db.is_empty().unwrap());
    }

    #[test]
    fn test_prune() {
        let mut rng = fastrand::Rng::new();
        let now = LocalTime::now();
        let mut db = Table::open(":memory:").unwrap();

        // Entries strictly older than `now`: all of these must be pruned.
        let old_ids = arbitrary::vec::<Id>(10);
        let old_nodes = arbitrary::vec::<NodeId>(10);

        for node in &old_nodes {
            let time = rng.u64(..now.as_millis());
            db.insert(&old_ids, *node, time).unwrap();
        }

        // Entries at or after `now`: all of these must survive.
        let ids = arbitrary::vec::<Id>(10);
        let nodes = arbitrary::vec::<NodeId>(10);

        for node in &nodes {
            let time = rng.u64(now.as_millis()..i64::MAX as u64);
            db.insert(&ids, *node, time).unwrap();
        }

        let pruned = db.prune(now.as_millis(), None).unwrap();
        assert_eq!(pruned, old_ids.len() * old_nodes.len());

        for id in &old_ids {
            for node in &old_nodes {
                assert_eq!(db.entry(id, node).unwrap(), None);
            }
        }
        for id in &ids {
            for node in &nodes {
                let t = db.entry(id, node).unwrap().unwrap();
                assert!(t >= now.as_millis());
            }
        }
    }

    #[test]
    fn test_prune_limit() {
        let ids = arbitrary::vec::<Id>(10);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

        db.insert(&ids, node, 0).unwrap();

        // Only `limit` stale rows are removed per call.
        assert_eq!(db.prune(1, Some(3)).unwrap(), 3);
        assert_eq!(db.len().unwrap(), ids.len() - 3);
        // Without a limit, every remaining stale row goes.
        assert_eq!(db.prune(1, None).unwrap(), ids.len() - 3);
        assert!(db.is_empty().unwrap());
    }

    #[test]
    fn test_count() {
        let id = arbitrary::gen::<Id>(1);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
            db.insert([&id], *node, 0).unwrap();
        }
        assert_eq!(db.count(&id).unwrap(), nodes.len());
    }
}