keychain.rs
  1  use std::collections::BTreeMap;
  2  use std::io::Write;
  3  
  4  use aleph_bft::Keychain as KeychainTrait;
  5  use fedimint_core::encoding::Encodable;
  6  use fedimint_core::session_outcome::SchnorrSignature;
  7  use fedimint_core::{secp256k1, BitcoinHash, NumPeersExt, PeerId};
  8  use secp256k1::hashes::sha256;
  9  use secp256k1::{schnorr, KeyPair, Message, PublicKey};
 10  
 11  use crate::config::ServerConfig;
 12  
 13  #[derive(Clone, Debug)]
 14  pub struct Keychain {
 15      peer_id: PeerId,
 16      public_keys: BTreeMap<PeerId, PublicKey>,
 17      keypair: KeyPair,
 18  }
 19  
 20  impl Keychain {
 21      pub fn new(cfg: &ServerConfig) -> Self {
 22          Keychain {
 23              peer_id: cfg.local.identity,
 24              public_keys: cfg.consensus.broadcast_public_keys.clone(),
 25              keypair: cfg
 26                  .private
 27                  .broadcast_secret_key
 28                  .keypair(secp256k1::SECP256K1),
 29          }
 30      }
 31  
 32      pub fn peer_id(&self) -> PeerId {
 33          self.peer_id
 34      }
 35  
 36      pub fn peer_count(&self) -> usize {
 37          self.public_keys.total()
 38      }
 39  
 40      pub fn threshold(&self) -> usize {
 41          self.public_keys.threshold()
 42      }
 43  
 44      fn tagged_hash(&self, message: &[u8]) -> Message {
 45          let mut engine = sha256::HashEngine::default();
 46  
 47          let public_key_tag = self.public_keys.consensus_hash::<sha256::Hash>();
 48  
 49          engine
 50              .write_all(public_key_tag.as_ref())
 51              .expect("Writing to a hash engine can not fail");
 52  
 53          engine
 54              .write_all(message)
 55              .expect("Writing to a hash engine can not fail");
 56  
 57          let hash = sha256::Hash::from_engine(engine);
 58  
 59          Message::from(hash)
 60      }
 61  }
 62  
 63  impl aleph_bft::Index for Keychain {
 64      fn index(&self) -> aleph_bft::NodeIndex {
 65          self.peer_id.to_usize().into()
 66      }
 67  }
 68  
 69  #[async_trait::async_trait]
 70  impl aleph_bft::Keychain for Keychain {
 71      type Signature = SchnorrSignature;
 72  
 73      fn node_count(&self) -> aleph_bft::NodeCount {
 74          self.peer_count().into()
 75      }
 76  
 77      fn sign(&self, message: &[u8]) -> Self::Signature {
 78          SchnorrSignature(
 79              self.keypair
 80                  .sign_schnorr(self.tagged_hash(message))
 81                  .as_ref()
 82                  .to_owned(),
 83          )
 84      }
 85  
 86      fn verify(
 87          &self,
 88          message: &[u8],
 89          signature: &Self::Signature,
 90          node_index: aleph_bft::NodeIndex,
 91      ) -> bool {
 92          let peer_id = super::to_peer_id(node_index);
 93  
 94          if let Some(public_key) = self.public_keys.get(&peer_id) {
 95              if let Ok(sig) = schnorr::Signature::from_slice(&signature.0) {
 96                  return secp256k1::SECP256K1
 97                      .verify_schnorr(
 98                          &sig,
 99                          &self.tagged_hash(message),
100                          &public_key.x_only_public_key().0,
101                      )
102                      .is_ok();
103              }
104          }
105  
106          false
107      }
108  }
109  
110  impl aleph_bft::MultiKeychain for Keychain {
111      type PartialMultisignature = aleph_bft::NodeMap<SchnorrSignature>;
112  
113      fn bootstrap_multi(
114          &self,
115          signature: &Self::Signature,
116          index: aleph_bft::NodeIndex,
117      ) -> Self::PartialMultisignature {
118          let mut partial = aleph_bft::NodeMap::with_size(self.peer_count().into());
119          partial.insert(index, signature.clone());
120          partial
121      }
122  
123      fn is_complete(&self, msg: &[u8], partial: &Self::PartialMultisignature) -> bool {
124          if partial.iter().count() < self.threshold() {
125              return false;
126          }
127  
128          partial.iter().all(|(i, sgn)| self.verify(msg, sgn, i))
129      }
130  }