// handshake.rs
1 // Copyright (c) 2025 ADnet Contributors 2 // This file is part of the adnet-core library. 3 // Derived from snarkOS, originally by Provable Inc. 4 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at: 8 9 // http://www.apache.org/licenses/LICENSE-2.0 10 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 use std::{io, time::Duration}; 18 19 use tokio::{ 20 io::{AsyncRead, AsyncWrite, split}, 21 net::TcpStream, 22 sync::{mpsc, oneshot}, 23 time::timeout, 24 }; 25 use tracing::*; 26 27 use crate::{ 28 Connection, P2P, 29 protocols::{ProtocolHandler, ReturnableConnection}, 30 }; 31 32 /// Can be used to specify and enable network handshakes. Upon establishing a connection, both sides will 33 /// need to adhere to the specified handshake rules in order to finalize the connection and be able to send 34 /// or receive any messages. 35 #[async_trait::async_trait] 36 pub trait Handshake: P2P 37 where 38 Self: Clone + Send + Sync + 'static, 39 { 40 /// The maximum time allowed for a connection to perform a handshake before it is rejected. 41 /// 42 /// The default value is 3000ms. 43 const TIMEOUT_MS: u64 = 3_000; 44 45 /// Prepares the node to perform specified network handshakes. 
46 async fn enable_handshake(&self) { 47 let (from_node_sender, mut from_node_receiver) = 48 mpsc::unbounded_channel::<ReturnableConnection>(); 49 50 // use a channel to know when the handshake task is ready 51 let (tx, rx) = oneshot::channel(); 52 53 // spawn a background task dedicated to handling the handshakes 54 let self_clone = self.clone(); 55 let handshake_task = tokio::spawn(async move { 56 trace!(parent: self_clone.tcp().span(), "spawned the Handshake handler task"); 57 tx.send(()).unwrap(); // safe; the channel was just opened 58 59 while let Some((conn, result_sender)) = from_node_receiver.recv().await { 60 let addr = conn.addr(); 61 62 let node = self_clone.clone(); 63 tokio::spawn(async move { 64 debug!(parent: node.tcp().span(), "shaking hands with {} as the {:?}", addr, !conn.side()); 65 let result = timeout( 66 Duration::from_millis(Self::TIMEOUT_MS), 67 node.perform_handshake(conn), 68 ) 69 .await; 70 71 let ret = match result { 72 Ok(Ok(conn)) => { 73 debug!(parent: node.tcp().span(), "successfully handshaken with {}", addr); 74 Ok(conn) 75 } 76 Ok(Err(e)) => { 77 debug!(parent: node.tcp().span(), "handshake with {addr} failed: {e}"); 78 Err(e) 79 } 80 Err(_) => { 81 debug!(parent: node.tcp().span(), "handshake with {} timed out", addr); 82 Err(io::ErrorKind::TimedOut.into()) 83 } 84 }; 85 86 // return the Connection to the Tcp, resuming Tcp::adapt_stream 87 if result_sender.send(ret).is_err() { 88 unreachable!("couldn't return a Connection to the Tcp"); 89 } 90 }); 91 } 92 }); 93 let _ = rx.await; 94 self.tcp().tasks.lock().push(handshake_task); 95 96 // register the Handshake handler with the Tcp 97 let hdl = Box::new(ProtocolHandler(from_node_sender)); 98 assert!( 99 self.tcp().protocols.handshake.set(hdl).is_ok(), 100 "the Handshake protocol was enabled more than once!" 101 ); 102 } 103 104 /// Performs the handshake; temporarily assumes control of the [`Connection`] and returns it if the handshake is 105 /// successful. 
106 async fn perform_handshake(&self, conn: Connection) -> io::Result<Connection>; 107 108 /// Borrows the full connection stream to be used in the implementation of [`Handshake::perform_handshake`]. 109 fn borrow_stream<'a>(&self, conn: &'a mut Connection) -> &'a mut TcpStream { 110 conn.stream.as_mut().unwrap() 111 } 112 113 /// Assumes full control of a connection's stream in the implementation of [`Handshake::perform_handshake`], by 114 /// the end of which it *must* be followed by [`Handshake::return_stream`]. 115 fn take_stream(&self, conn: &mut Connection) -> TcpStream { 116 conn.stream.take().unwrap() 117 } 118 119 /// This method only needs to be called if [`Handshake::take_stream`] had been called before; it is used to 120 /// return a (potentially modified) stream back to the applicable connection. 121 fn return_stream<T: AsyncRead + AsyncWrite + Send + Sync + 'static>( 122 &self, 123 conn: &mut Connection, 124 stream: T, 125 ) { 126 let (reader, writer) = split(stream); 127 conn.reader = Some(Box::new(reader)); 128 conn.writer = Some(Box::new(writer)); 129 } 130 }