/ node / router / src / handshake.rs
handshake.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaOS library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::{
 17      ConnectionMode,
 18      NodeType,
 19      PeerPoolHandling,
 20      Router,
 21      messages::{ChallengeRequest, ChallengeResponse, DisconnectReason, Message, MessageCodec, MessageTrait},
 22  };
 23  use alphaos_node_network::{get_repo_commit_hash, log_repo_sha_comparison};
 24  use alphaos_node_tcp::{ConnectionSide, P2P, Tcp};
 25  use alphavm::{
 26      ledger::narwhal::Data,
 27      prelude::{Address, ConsensusVersion, Field, Network, block::Header, error, io_error},
 28  };
 29  
 30  use anyhow::{Result, bail};
 31  use futures::SinkExt;
 32  use rand::{Rng, rngs::OsRng};
 33  use std::{io, net::SocketAddr};
 34  use tokio::net::TcpStream;
 35  use tokio_stream::StreamExt;
 36  use tokio_util::codec::Framed;
 37  
 38  impl<N: Network> P2P for Router<N> {
 39      /// Returns a reference to the TCP instance.
 40      fn tcp(&self) -> &Tcp {
 41          &self.tcp
 42      }
 43  }
 44  
 45  /// A macro unwrapping the expected handshake message or returning an error for unexpected messages.
 46  #[macro_export]
 47  macro_rules! expect_message {
 48      ($msg_ty:path, $framed:expr, $peer_addr:expr) => {{
 49          use alphavm::utilities::io_error;
 50  
 51          match $framed.try_next().await? {
 52              // Received the expected message, proceed.
 53              Some($msg_ty(data)) => {
 54                  trace!("Received '{}' from '{}'", data.name(), $peer_addr);
 55                  data
 56              }
 57              // Received a disconnect message, abort.
 58              Some(Message::Disconnect($crate::messages::Disconnect { reason })) => {
 59                  return Err(io_error(format!("'{}' disconnected: {reason}", $peer_addr)));
 60              }
 61              // Received an unexpected message, abort.
 62              Some(ty) => {
 63                  return Err(io_error(format!(
 64                      "'{}' did not follow the handshake protocol: received {:?} instead of {}",
 65                      $peer_addr,
 66                      ty.name(),
 67                      stringify!($msg_ty),
 68                  )));
 69              }
 70              // Received nothing.
 71              None => {
 72                  return Err(io_error(format!(
 73                      "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
 74                      stringify!($msg_ty),
 75                  )));
 76              }
 77          }
 78      }};
 79  }
 80  
 81  /// Send the given message to the peer.
 82  async fn send<N: Network>(
 83      framed: &mut Framed<&mut TcpStream, MessageCodec<N>>,
 84      peer_addr: SocketAddr,
 85      message: Message<N>,
 86  ) -> io::Result<()> {
 87      trace!("Sending '{}' to '{peer_addr}'", message.name());
 88      framed.send(message).await
 89  }
 90  
 91  impl<N: Network> Router<N> {
 92      /// Executes the handshake protocol.
 93      pub async fn handshake<'a>(
 94          &'a self,
 95          peer_addr: SocketAddr,
 96          stream: &'a mut TcpStream,
 97          peer_side: ConnectionSide,
 98          genesis_header: Header<N>,
 99          restrictions_id: Field<N>,
100      ) -> io::Result<Option<ChallengeRequest<N>>> {
101          // If this is an inbound connection, we log it, but don't know the listening address yet.
102          // Otherwise, we can immediately register the listening address.
103          let mut listener_addr = if peer_side == ConnectionSide::Initiator {
104              debug!("Received a connection request from '{peer_addr}'");
105              None
106          } else {
107              debug!("Shaking hands with '{peer_addr}'...");
108              Some(peer_addr)
109          };
110  
111          // Check (or impose) IP-level bans.
112          #[cfg(not(feature = "test"))]
113          if !self.is_dev() && peer_side == ConnectionSide::Initiator {
114              // If the IP is already banned reject the connection.
115              if self.is_ip_banned(peer_addr.ip()) {
116                  trace!("Rejected a connection request from banned IP '{}'", peer_addr.ip());
117                  return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
118              }
119  
120              let num_attempts =
121                  self.cache.insert_inbound_connection(peer_addr.ip(), Router::<N>::CONNECTION_ATTEMPTS_SINCE_SECS);
122  
123              debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
124              if num_attempts > Router::<N>::MAX_CONNECTION_ATTEMPTS {
125                  self.update_ip_ban(peer_addr.ip());
126                  trace!("Rejected a consecutive connection request from IP '{}'", peer_addr.ip());
127                  return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
128              }
129          }
130  
131          // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
132          let handshake_result = if peer_side == ConnectionSide::Responder {
133              self.handshake_inner_initiator(peer_addr, stream, genesis_header, restrictions_id).await
134          } else {
135              self.handshake_inner_responder(peer_addr, &mut listener_addr, stream, genesis_header, restrictions_id).await
136          };
137  
138          if let Some(addr) = listener_addr {
139              match handshake_result {
140                  Ok(Some(ref cr)) => {
141                      if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
142                          self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, Some(cr.address));
143                          peer.upgrade_to_connected(
144                              peer_addr,
145                              cr.listener_port,
146                              cr.address,
147                              cr.node_type,
148                              cr.version,
149                              ConnectionMode::Router,
150                          );
151                      }
152                      #[cfg(feature = "metrics")]
153                      self.update_metrics();
154                      debug!("Completed the handshake with '{peer_addr}'");
155                  }
156                  Ok(None) => {
157                      return Err(error(format!("Duplicate handshake attempt with '{addr}'")));
158                  }
159                  Err(_) => {
160                      if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
161                          // The peer may only be downgraded if it's a ConnectingPeer.
162                          if peer.is_connecting() {
163                              peer.downgrade_to_candidate(addr);
164                          }
165                      }
166                  }
167              }
168          }
169  
170          handshake_result
171      }
172  
173      /// The connection initiator side of the handshake.
174      async fn handshake_inner_initiator<'a>(
175          &'a self,
176          peer_addr: SocketAddr,
177          stream: &'a mut TcpStream,
178          genesis_header: Header<N>,
179          restrictions_id: Field<N>,
180      ) -> io::Result<Option<ChallengeRequest<N>>> {
181          // Introduce the peer into the peer pool.
182          if !self.add_connecting_peer(peer_addr) {
183              // Return early if already being connected to.
184              return Ok(None);
185          }
186  
187          // Construct the stream.
188          let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
189  
190          // Initialize an RNG.
191          let rng = &mut OsRng;
192  
193          // Determine the AlphaOS SHA to send to the peer.
194          let current_block_height = self.ledger.latest_block_height();
195          let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
196          let alphaos_sha = match (consensus_version >= ConsensusVersion::V12, get_repo_commit_hash()) {
197              (true, Some(sha)) => Some(sha),
198              _ => None,
199          };
200  
201          /* Step 1: Send the challenge request. */
202  
203          // Sample a random nonce.
204          let our_nonce = rng.r#gen();
205          // Send a challenge request to the peer.
206          let our_request =
207              ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce, alphaos_sha);
208          send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
209  
210          /* Step 2: Receive the peer's challenge response followed by the challenge request. */
211  
212          // Listen for the challenge response message.
213          let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
214          // Listen for the challenge request message.
215          let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
216  
217          // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
218          if let Some(reason) = self
219              .verify_challenge_response(
220                  peer_addr,
221                  peer_request.address,
222                  peer_request.node_type,
223                  peer_response,
224                  genesis_header,
225                  restrictions_id,
226                  our_nonce,
227              )
228              .await
229          {
230              send(&mut framed, peer_addr, reason.into()).await?;
231              return Err(io_error(format!("Dropped '{peer_addr}' for reason: {reason}")));
232          }
233          // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
234          if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
235              send(&mut framed, peer_addr, reason.into()).await?;
236              return Err(io_error(format!("Dropped '{peer_addr}' for reason: {reason}")));
237          }
238  
239          /* Step 3: Send the challenge response. */
240  
241          let response_nonce: u64 = rng.r#gen();
242          let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
243          // Sign the counterparty nonce.
244          let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
245              return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
246          };
247          // Send the challenge response.
248          let our_response = ChallengeResponse {
249              genesis_header,
250              restrictions_id,
251              signature: Data::Object(our_signature),
252              nonce: response_nonce,
253          };
254          send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
255  
256          Ok(Some(peer_request))
257      }
258  
259      /// The connection responder side of the handshake.
260      async fn handshake_inner_responder<'a>(
261          &'a self,
262          peer_addr: SocketAddr,
263          listener_addr: &mut Option<SocketAddr>,
264          stream: &'a mut TcpStream,
265          genesis_header: Header<N>,
266          restrictions_id: Field<N>,
267      ) -> io::Result<Option<ChallengeRequest<N>>> {
268          // Construct the stream.
269          let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
270  
271          /* Step 1: Receive the challenge request. */
272  
273          // Listen for the challenge request message.
274          let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
275  
276          // Determine the AlphaOS SHA to send to the peer.
277          let current_block_height = self.ledger.latest_block_height();
278          let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
279          let alphaos_sha = match (consensus_version >= ConsensusVersion::V12, get_repo_commit_hash()) {
280              (true, Some(sha)) => Some(sha),
281              _ => None,
282          };
283  
284          // Obtain the peer's listening address.
285          *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
286          let listener_addr = listener_addr.unwrap();
287  
288          // Knowing the peer's listening address, ensure it is allowed to connect.
289          if let Err(forbidden_message) = self.ensure_peer_is_allowed(listener_addr) {
290              return Err(error(format!("{forbidden_message}")));
291          }
292  
293          // Introduce the peer into the peer pool.
294          if !self.add_connecting_peer(listener_addr) {
295              // Return early if already being connected to.
296              return Ok(None);
297          }
298  
299          // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
300          if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
301              send(&mut framed, peer_addr, reason.into()).await?;
302              return Err(io_error(format!("Dropped '{peer_addr}' for reason: {reason}")));
303          }
304  
305          /* Step 2: Send the challenge response followed by own challenge request. */
306  
307          // Initialize an RNG.
308          let rng = &mut OsRng;
309  
310          // Sign the counterparty nonce.
311          let response_nonce: u64 = rng.r#gen();
312          let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
313          let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
314              return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
315          };
316          // Send the challenge response.
317          let our_response = ChallengeResponse {
318              genesis_header,
319              restrictions_id,
320              signature: Data::Object(our_signature),
321              nonce: response_nonce,
322          };
323          send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
324  
325          // Sample a random nonce.
326          let our_nonce = rng.r#gen();
327          // Send the challenge request.
328          let our_request =
329              ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce, alphaos_sha);
330          send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
331  
332          /* Step 3: Receive the challenge response. */
333  
334          // Listen for the challenge response message.
335          let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
336          // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
337          if let Some(reason) = self
338              .verify_challenge_response(
339                  peer_addr,
340                  peer_request.address,
341                  peer_request.node_type,
342                  peer_response,
343                  genesis_header,
344                  restrictions_id,
345                  our_nonce,
346              )
347              .await
348          {
349              send(&mut framed, peer_addr, reason.into()).await?;
350              return Err(io_error(format!("Dropped '{peer_addr}' for reason: {reason}")));
351          }
352  
353          Ok(Some(peer_request))
354      }
355  
356      /// Ensure the peer is allowed to connect.
357      fn ensure_peer_is_allowed(&self, listener_addr: SocketAddr) -> Result<()> {
358          // Ensure that it's not a self-connect attempt.
359          if self.is_local_ip(listener_addr) {
360              bail!("Dropping connection request from '{listener_addr}' (attempted to self-connect)");
361          }
362          // As a validator, only accept connections from trusted peers and bootstrap nodes.
363          if self.node_type() == NodeType::Validator
364              && !self.is_trusted(listener_addr)
365              && !crate::bootstrap_peers::<N>(self.is_dev()).contains(&listener_addr)
366          {
367              bail!("Dropping connection request from '{listener_addr}' (untrusted)");
368          }
369          // If the node is in trusted peers only mode, ensure the peer is explicitly trusted.
370          if self.trusted_peers_only() && !self.is_trusted(listener_addr) {
371              bail!("Dropping connection request from '{listener_addr}' (untrusted)");
372          }
373          Ok(())
374      }
375  
376      /// Verifies the given challenge request. Returns a disconnect reason if the request is invalid.
377      fn verify_challenge_request(
378          &self,
379          peer_addr: SocketAddr,
380          message: &ChallengeRequest<N>,
381      ) -> Option<DisconnectReason> {
382          // Retrieve the components of the challenge request.
383          let &ChallengeRequest { version, listener_port: _, node_type, address, nonce: _, ref alphaos_sha } = message;
384          log_repo_sha_comparison(peer_addr, alphaos_sha, Self::OWNER);
385  
386          // Ensure the message protocol version is not outdated.
387          if !self.is_valid_message_version(version) {
388              warn!("Dropping '{peer_addr}' on version {version} (outdated)");
389              return Some(DisconnectReason::OutdatedClientVersion);
390          }
391  
392          // Ensure there are no validators connected with the given Aleo address.
393          if self.node_type() == NodeType::Validator
394              && node_type == NodeType::Validator
395              && self.is_connected_address(address)
396          {
397              warn!("Dropping '{peer_addr}' for being already connected ({address})");
398              return Some(DisconnectReason::NoReasonGiven);
399          }
400  
401          None
402      }
403  
404      /// Verifies the given challenge response. Returns a disconnect reason if the response is invalid.
405      #[allow(clippy::too_many_arguments)]
406      async fn verify_challenge_response(
407          &self,
408          peer_addr: SocketAddr,
409          peer_address: Address<N>,
410          peer_node_type: NodeType,
411          response: ChallengeResponse<N>,
412          expected_genesis_header: Header<N>,
413          expected_restrictions_id: Field<N>,
414          expected_nonce: u64,
415      ) -> Option<DisconnectReason> {
416          // Retrieve the components of the challenge response.
417          let ChallengeResponse { genesis_header, restrictions_id, signature, nonce } = response;
418  
419          // Verify the challenge response, by checking that the block header matches.
420          if genesis_header != expected_genesis_header {
421              warn!("Handshake with '{peer_addr}' failed (incorrect block header)");
422              return Some(DisconnectReason::InvalidChallengeResponse);
423          }
424          // Verify the restrictions ID.
425          if !peer_node_type.is_prover() && !self.node_type.is_prover() && restrictions_id != expected_restrictions_id {
426              warn!("Handshake with '{peer_addr}' failed (incorrect restrictions ID)");
427              return Some(DisconnectReason::InvalidChallengeResponse);
428          }
429          // Perform the deferred non-blocking deserialization of the signature.
430          let Ok(signature) = signature.deserialize().await else {
431              warn!("Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
432              return Some(DisconnectReason::InvalidChallengeResponse);
433          };
434          // Verify the signature.
435          if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
436              warn!("Handshake with '{peer_addr}' failed (invalid signature)");
437              return Some(DisconnectReason::InvalidChallengeResponse);
438          }
439          None
440      }
441  }