/ tcp / src / protocols / handshake.rs
handshake.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the adnet-core library.
  3  // Derived from snarkOS, originally by Provable Inc.
  4  
  5  // Licensed under the Apache License, Version 2.0 (the "License");
  6  // you may not use this file except in compliance with the License.
  7  // You may obtain a copy of the License at:
  8  
  9  // http://www.apache.org/licenses/LICENSE-2.0
 10  
 11  // Unless required by applicable law or agreed to in writing, software
 12  // distributed under the License is distributed on an "AS IS" BASIS,
 13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  // See the License for the specific language governing permissions and
 15  // limitations under the License.
 16  
 17  use std::{io, time::Duration};
 18  
 19  use tokio::{
 20      io::{AsyncRead, AsyncWrite, split},
 21      net::TcpStream,
 22      sync::{mpsc, oneshot},
 23      time::timeout,
 24  };
 25  use tracing::*;
 26  
 27  use crate::{
 28      Connection, P2P,
 29      protocols::{ProtocolHandler, ReturnableConnection},
 30  };
 31  
 32  /// Can be used to specify and enable network handshakes. Upon establishing a connection, both sides will
 33  /// need to adhere to the specified handshake rules in order to finalize the connection and be able to send
 34  /// or receive any messages.
 35  #[async_trait::async_trait]
 36  pub trait Handshake: P2P
 37  where
 38      Self: Clone + Send + Sync + 'static,
 39  {
 40      /// The maximum time allowed for a connection to perform a handshake before it is rejected.
 41      ///
 42      /// The default value is 3000ms.
 43      const TIMEOUT_MS: u64 = 3_000;
 44  
 45      /// Prepares the node to perform specified network handshakes.
 46      async fn enable_handshake(&self) {
 47          let (from_node_sender, mut from_node_receiver) =
 48              mpsc::unbounded_channel::<ReturnableConnection>();
 49  
 50          // use a channel to know when the handshake task is ready
 51          let (tx, rx) = oneshot::channel();
 52  
 53          // spawn a background task dedicated to handling the handshakes
 54          let self_clone = self.clone();
 55          let handshake_task = tokio::spawn(async move {
 56              trace!(parent: self_clone.tcp().span(), "spawned the Handshake handler task");
 57              tx.send(()).unwrap(); // safe; the channel was just opened
 58  
 59              while let Some((conn, result_sender)) = from_node_receiver.recv().await {
 60                  let addr = conn.addr();
 61  
 62                  let node = self_clone.clone();
 63                  tokio::spawn(async move {
 64                      debug!(parent: node.tcp().span(), "shaking hands with {} as the {:?}", addr, !conn.side());
 65                      let result = timeout(
 66                          Duration::from_millis(Self::TIMEOUT_MS),
 67                          node.perform_handshake(conn),
 68                      )
 69                      .await;
 70  
 71                      let ret = match result {
 72                          Ok(Ok(conn)) => {
 73                              debug!(parent: node.tcp().span(), "successfully handshaken with {}", addr);
 74                              Ok(conn)
 75                          }
 76                          Ok(Err(e)) => {
 77                              debug!(parent: node.tcp().span(), "handshake with {addr} failed: {e}");
 78                              Err(e)
 79                          }
 80                          Err(_) => {
 81                              debug!(parent: node.tcp().span(), "handshake with {} timed out", addr);
 82                              Err(io::ErrorKind::TimedOut.into())
 83                          }
 84                      };
 85  
 86                      // return the Connection to the Tcp, resuming Tcp::adapt_stream
 87                      if result_sender.send(ret).is_err() {
 88                          unreachable!("couldn't return a Connection to the Tcp");
 89                      }
 90                  });
 91              }
 92          });
 93          let _ = rx.await;
 94          self.tcp().tasks.lock().push(handshake_task);
 95  
 96          // register the Handshake handler with the Tcp
 97          let hdl = Box::new(ProtocolHandler(from_node_sender));
 98          assert!(
 99              self.tcp().protocols.handshake.set(hdl).is_ok(),
100              "the Handshake protocol was enabled more than once!"
101          );
102      }
103  
104      /// Performs the handshake; temporarily assumes control of the [`Connection`] and returns it if the handshake is
105      /// successful.
106      async fn perform_handshake(&self, conn: Connection) -> io::Result<Connection>;
107  
108      /// Borrows the full connection stream to be used in the implementation of [`Handshake::perform_handshake`].
109      fn borrow_stream<'a>(&self, conn: &'a mut Connection) -> &'a mut TcpStream {
110          conn.stream.as_mut().unwrap()
111      }
112  
113      /// Assumes full control of a connection's stream in the implementation of [`Handshake::perform_handshake`], by
114      /// the end of which it *must* be followed by [`Handshake::return_stream`].
115      fn take_stream(&self, conn: &mut Connection) -> TcpStream {
116          conn.stream.take().unwrap()
117      }
118  
119      /// This method only needs to be called if [`Handshake::take_stream`] had been called before; it is used to
120      /// return a (potentially modified) stream back to the applicable connection.
121      fn return_stream<T: AsyncRead + AsyncWrite + Send + Sync + 'static>(
122          &self,
123          conn: &mut Connection,
124          stream: T,
125      ) {
126          let (reader, writer) = split(stream);
127          conn.reader = Some(Box::new(reader));
128          conn.writer = Some(Box::new(writer));
129      }
130  }