/ node / src / bootstrap_client / handshake.rs
handshake.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaOS library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::{
 17      BootstrapClient,
 18      bft::events::{self, Event},
 19      bootstrap_client::{codec::BootstrapClientCodec, network::MessageOrEvent},
 20      network::{ConnectionMode, NodeType, PeerPoolHandling, log_repo_sha_comparison},
 21      router::messages::{self, Message},
 22      tcp::{Connection, ConnectionSide, protocols::*},
 23  };
 24  use alphavm::{
 25      ledger::narwhal::Data,
 26      prelude::{Address, Network, error},
 27  };
 28  
 29  use futures_util::sink::SinkExt;
 30  use rand::{Rng, rngs::OsRng};
 31  use std::{io, net::SocketAddr};
 32  use tokio::net::TcpStream;
 33  use tokio_stream::StreamExt;
 34  use tokio_util::codec::Framed;
 35  
 36  #[derive(Debug)]
 37  enum HandshakeMessageKind {
 38      ChallengeRequest,
 39      ChallengeResponse,
 40  }
 41  
 42  macro_rules! send_msg {
 43      ($msg:expr, $framed:expr, $peer_addr:expr) => {{
 44          trace!("Sending '{}' to '{}'", $msg.name(), $peer_addr);
 45          $framed.send($msg).await
 46      }};
 47  }
 48  
 49  /// A macro handling incoming handshake messages, rejecting unexpected ones.
 50  macro_rules! expect_handshake_msg {
 51      ($msg_ty:expr, $framed:expr, $peer_addr:expr) => {{
 52          // Read the message as bytes.
 53          let Some(message) = $framed.try_next().await? else {
 54              return Err(error(format!(
 55                  "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
 56                  stringify!($msg_ty),
 57              )));
 58          };
 59  
 60          // Match the expected message type with its expected size or peer type indicator.
 61          match $msg_ty {
 62              HandshakeMessageKind::ChallengeRequest
 63                  if matches!(
 64                      message,
 65                      MessageOrEvent::Message(Message::ChallengeRequest(_))
 66                          | MessageOrEvent::Event(Event::ChallengeRequest(_))
 67                  ) =>
 68              {
 69                  trace!("Received a '{}' from '{}'", stringify!($msg_ty), $peer_addr);
 70                  message
 71              }
 72              HandshakeMessageKind::ChallengeResponse
 73                  if matches!(
 74                      message,
 75                      MessageOrEvent::Message(Message::ChallengeResponse(_))
 76                          | MessageOrEvent::Event(Event::ChallengeResponse(_))
 77                  ) =>
 78              {
 79                  trace!("Received a '{}' from '{}'", stringify!($msg_ty), $peer_addr);
 80                  message
 81              }
 82              _ => {
 83                  let msg_name = match message {
 84                      MessageOrEvent::Message(message) => message.name(),
 85                      MessageOrEvent::Event(event) => event.name(),
 86                  };
 87                  return Err(error(format!(
 88                      "'{}' did not follow the handshake protocol: expected {}, got {msg_name}",
 89                      $peer_addr,
 90                      stringify!($msg_ty),
 91                  )));
 92              }
 93          }
 94      }};
 95  }
 96  
 97  #[async_trait]
 98  impl<N: Network> Handshake for BootstrapClient<N> {
 99      async fn perform_handshake(&self, mut connection: Connection) -> io::Result<Connection> {
100          let peer_addr = connection.addr();
101          let peer_side = connection.side();
102          let stream = self.borrow_stream(&mut connection);
103  
104          // We don't know the listening address yet, as we don't initiate connections.
105          let mut listener_addr = if peer_side == ConnectionSide::Initiator {
106              debug!("Received a connection request from '{peer_addr}'");
107              None
108          } else {
109              unreachable!("The boostrapper clients don't initiate connections");
110          };
111  
112          // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
113          let handshake_result = if peer_side == ConnectionSide::Responder {
114              unreachable!("The boostrapper clients don't initiate connections");
115          } else {
116              self.handshake_inner_responder(peer_addr, &mut listener_addr, stream).await
117          };
118  
119          if let Some(addr) = listener_addr {
120              match handshake_result {
121                  Ok(Some((peer_port, peer_aleo_addr, peer_node_type, peer_version, connection_mode))) => {
122                      if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
123                          // Due to only having a single Resolver, the BootstrapClient only adds an Aleo
124                          // address mapping for Gateway-mode connections, as it is only used there, and
125                          // it could otherwise clash with the Router-mode mapping for validators, which
126                          // may connect in both modes at the same time.
127                          let aleo_addr =
128                              if connection_mode == ConnectionMode::Gateway { Some(peer_aleo_addr) } else { None };
129                          self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, aleo_addr);
130                          peer.upgrade_to_connected(
131                              peer_addr,
132                              peer_port,
133                              peer_aleo_addr,
134                              peer_node_type,
135                              peer_version,
136                              connection_mode,
137                          );
138                      }
139                      debug!("Completed the handshake with '{peer_addr}'");
140                  }
141                  Ok(None) => {
142                      return Err(error(format!("Duplicate handshake attempt with '{addr}'")));
143                  }
144                  Err(error) => {
145                      debug!("Handshake with '{peer_addr}' failed: {error}");
146                      if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
147                          // The peer may only be downgraded if it's a ConnectingPeer.
148                          if peer.is_connecting() {
149                              peer.downgrade_to_candidate(addr);
150                          }
151                      }
152                      return Err(error);
153                  }
154              }
155          }
156  
157          Ok(connection)
158      }
159  }
160  
161  impl<N: Network> BootstrapClient<N> {
162      /// The connection responder side of the handshake.
163      async fn handshake_inner_responder<'a>(
164          &'a self,
165          peer_addr: SocketAddr,
166          listener_addr: &mut Option<SocketAddr>,
167          stream: &'a mut TcpStream,
168      ) -> io::Result<Option<(u16, Address<N>, NodeType, u32, ConnectionMode)>> {
169          // Construct the stream.
170          let mut framed = Framed::new(stream, BootstrapClientCodec::<N>::handshake());
171  
172          /* Step 1: Receive the challenge request. */
173  
174          // Listen for the challenge request message, which can be either from a regular peer, or a validator.
175          let peer_request = expect_handshake_msg!(HandshakeMessageKind::ChallengeRequest, framed, peer_addr);
176          let (peer_port, peer_nonce, peer_aleo_addr, peer_node_type, peer_version, connection_mode) = match peer_request
177          {
178              MessageOrEvent::Message(Message::ChallengeRequest(ref msg)) => {
179                  (msg.listener_port, msg.nonce, msg.address, msg.node_type, msg.version, ConnectionMode::Router)
180              }
181              MessageOrEvent::Event(Event::ChallengeRequest(ref msg)) => {
182                  (msg.listener_port, msg.nonce, msg.address, NodeType::Validator, msg.version, ConnectionMode::Gateway)
183              }
184              _ => unreachable!(),
185          };
186          debug!("Handshake mode: {connection_mode:?}");
187  
188          // Obtain the peer's listening address.
189          *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_port));
190          let listener_addr = listener_addr.unwrap();
191  
192          // Introduce the peer into the peer pool.
193          if !self.add_connecting_peer(listener_addr) {
194              // Return early if already being connected to.
195              return Ok(None);
196          }
197  
198          // Verify the challenge request.
199          if !self.verify_challenge_request(peer_addr, &mut framed, &peer_request).await? {
200              return Err(error(format!("Handshake with '{peer_addr}' failed: invalid challenge request")));
201          };
202  
203          /* Step 2: Send the challenge response followed by own challenge request. */
204  
205          // Initialize an RNG.
206          let rng = &mut OsRng;
207  
208          // Sign the counterparty nonce.
209          let response_nonce: u64 = rng.r#gen();
210          let data = [peer_nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
211          let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
212              return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
213          };
214  
215          // Send the challenge response.
216          if connection_mode == ConnectionMode::Router {
217              let our_response = messages::ChallengeResponse {
218                  genesis_header: self.genesis_header,
219                  restrictions_id: self.restrictions_id,
220                  signature: Data::Object(our_signature),
221                  nonce: response_nonce,
222              };
223              let msg = Message::ChallengeResponse::<N>(our_response);
224              send_msg!(msg, framed, peer_addr)?;
225          } else {
226              let our_response = events::ChallengeResponse {
227                  restrictions_id: self.restrictions_id,
228                  signature: Data::Object(our_signature),
229                  nonce: response_nonce,
230              };
231              let msg = Event::ChallengeResponse::<N>(our_response);
232              send_msg!(msg, framed, peer_addr)?;
233          }
234  
235          // Sample a random nonce.
236          let our_nonce: u64 = rng.r#gen();
237          // Do not send a AlphaOS SHA as the bootstrap client is not aware of height.
238          let alphaos_sha = None;
239          // Send the challenge request.
240          if connection_mode == ConnectionMode::Router {
241              let our_request = messages::ChallengeRequest::new(
242                  self.local_ip().port(),
243                  NodeType::BootstrapClient,
244                  self.account.address(),
245                  our_nonce,
246                  alphaos_sha,
247              );
248              let msg = Message::ChallengeRequest(our_request);
249              send_msg!(msg, framed, peer_addr)?;
250          } else {
251              let our_request =
252                  events::ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce, alphaos_sha);
253              let msg = Event::ChallengeRequest(our_request);
254              send_msg!(msg, framed, peer_addr)?;
255          }
256  
257          /* Step 3: Receive the challenge response. */
258  
259          // Listen for the challenge response message.
260          let peer_response = expect_handshake_msg!(HandshakeMessageKind::ChallengeResponse, framed, peer_addr);
261          // Verify the challenge response.
262          if !self.verify_challenge_response(peer_addr, peer_aleo_addr, our_nonce, &peer_response).await {
263              if connection_mode == ConnectionMode::Router {
264                  let msg = Message::Disconnect::<N>(messages::DisconnectReason::InvalidChallengeResponse.into());
265                  send_msg!(msg, framed, peer_addr)?;
266              } else {
267                  let msg = Event::Disconnect::<N>(events::DisconnectReason::InvalidChallengeResponse.into());
268                  send_msg!(msg, framed, peer_addr)?;
269              }
270              return Err(error(format!("Handshake with '{peer_addr}' failed: invalid challenge response")));
271          }
272  
273          Ok(Some((peer_port, peer_aleo_addr, peer_node_type, peer_version, connection_mode)))
274      }
275  
276      async fn verify_challenge_request(
277          &self,
278          peer_addr: SocketAddr,
279          framed: &mut Framed<&mut TcpStream, BootstrapClientCodec<N>>,
280          request: &MessageOrEvent<N>,
281      ) -> io::Result<bool> {
282          match request {
283              MessageOrEvent::Message(Message::ChallengeRequest(msg)) => {
284                  log_repo_sha_comparison(peer_addr, &msg.alphaos_sha, Self::OWNER);
285  
286                  if msg.version < Message::<N>::latest_message_version() {
287                      let msg = Message::Disconnect::<N>(messages::DisconnectReason::OutdatedClientVersion.into());
288                      send_msg!(msg, framed, peer_addr)?;
289                      return Ok(false);
290                  }
291  
292                  // Reject validators that aren't members of the committee.
293                  if msg.node_type == NodeType::Validator {
294                      if let Some(current_committee) =
295                          self.get_or_update_committee().await.map_err(|_| error("Couldn't load the committee"))?
296                      {
297                          if !current_committee.contains(&msg.address) {
298                              let msg = Message::Disconnect::<N>(messages::DisconnectReason::ProtocolViolation.into());
299                              send_msg!(msg, framed, peer_addr)?;
300                              return Ok(false);
301                          }
302                      }
303                  }
304              }
305              MessageOrEvent::Event(Event::ChallengeRequest(msg)) => {
306                  log_repo_sha_comparison(peer_addr, &msg.alphaos_sha, Self::OWNER);
307  
308                  if msg.version < Event::<N>::VERSION {
309                      let msg = Event::Disconnect::<N>(events::DisconnectReason::OutdatedClientVersion.into());
310                      send_msg!(msg, framed, peer_addr)?;
311                      return Ok(false);
312                  }
313  
314                  // Reject validators that aren't members of the committee.
315                  if let Some(current_committee) =
316                      self.get_or_update_committee().await.map_err(|_| error("Couldn't load the committee"))?
317                  {
318                      if !current_committee.contains(&msg.address) {
319                          let msg = Event::Disconnect::<N>(events::DisconnectReason::ProtocolViolation.into());
320                          send_msg!(msg, framed, peer_addr)?;
321                          return Ok(false);
322                      }
323                  }
324              }
325              _ => unreachable!(),
326          }
327  
328          Ok(true)
329      }
330  
331      async fn verify_challenge_response(
332          &self,
333          peer_addr: SocketAddr,
334          peer_aleo_addr: Address<N>,
335          our_nonce: u64,
336          response: &MessageOrEvent<N>,
337      ) -> bool {
338          let (peer_restrictions_id, peer_signature, peer_nonce) = match response {
339              MessageOrEvent::Message(Message::ChallengeResponse(msg)) => {
340                  (msg.restrictions_id, msg.signature.clone(), msg.nonce)
341              }
342              MessageOrEvent::Event(Event::ChallengeResponse(msg)) => {
343                  (msg.restrictions_id, msg.signature.clone(), msg.nonce)
344              }
345              _ => unreachable!(),
346          };
347  
348          // Verify the restrictions ID.
349          if peer_restrictions_id != self.restrictions_id {
350              warn!("{} Handshake with '{peer_addr}' failed (incorrect restrictions ID)", Self::OWNER);
351              return false;
352          }
353          // Perform the deferred non-blocking deserialization of the signature.
354          let Ok(signature) = peer_signature.deserialize().await else {
355              warn!("{} Handshake with '{peer_addr}' failed (cannot deserialize the signature)", Self::OWNER);
356              return false;
357          };
358          // Verify the signature.
359          if !signature.verify_bytes(&peer_aleo_addr, &[our_nonce.to_le_bytes(), peer_nonce.to_le_bytes()].concat()) {
360              warn!("{} Handshake with '{peer_addr}' failed (invalid signature)", Self::OWNER);
361              return false;
362          }
363  
364          true
365      }
366  }