/ src / net / channel.rs
channel.rs
  1  /* This file is part of DarkFi (https://dark.fi)
  2   *
  3   * Copyright (C) 2020-2025 Dyne.org foundation
  4   *
  5   * This program is free software: you can redistribute it and/or modify
  6   * it under the terms of the GNU Affero General Public License as
  7   * published by the Free Software Foundation, either version 3 of the
  8   * License, or (at your option) any later version.
  9   *
 10   * This program is distributed in the hope that it will be useful,
 11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
 12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13   * GNU Affero General Public License for more details.
 14   *
 15   * You should have received a copy of the GNU Affero General Public License
 16   * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17   */
 18  
 19  use std::{
 20      collections::HashMap,
 21      fmt,
 22      sync::{
 23          atomic::{AtomicBool, Ordering::SeqCst},
 24          Arc,
 25      },
 26      time::UNIX_EPOCH,
 27  };
 28  
 29  use darkfi_serial::{
 30      async_trait, AsyncDecodable, AsyncEncodable, SerialDecodable, SerialEncodable, VarInt,
 31  };
 32  use rand::{rngs::OsRng, Rng};
 33  use smol::{
 34      io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
 35      lock::{Mutex as AsyncMutex, OnceCell},
 36      Executor,
 37  };
 38  use tracing::{debug, error, trace, warn};
 39  use url::Url;
 40  
 41  use super::{
 42      dnet::{self, dnetev, DnetEvent},
 43      hosts::{HostColor, HostsPtr},
 44      message,
 45      message::{SerializedMessage, VersionMessage, MAX_COMMAND_LENGTH},
 46      message_publisher::{MessageSubscription, MessageSubsystem},
 47      metering::{MeteringConfiguration, MeteringQueue},
 48      p2p::P2pPtr,
 49      session::{
 50          Session, SessionBitFlag, SessionWeakPtr, SESSION_ALL, SESSION_INBOUND, SESSION_REFINE,
 51      },
 52      transport::PtStream,
 53  };
 54  use crate::{
 55      net::BanPolicy,
 56      system::{msleep, Publisher, PublisherPtr, StoppableTask, StoppableTaskPtr, Subscription},
 57      util::{logger::verbose, time::NanoTimestamp},
 58      Error, Result,
 59  };
 60  
/// Atomic pointer to async channel.
///
/// Reference-counted (`Arc`) handle to a [`Channel`], shareable between tasks.
pub type ChannelPtr = Arc<Channel>;
 63  
/// Channel debug info
///
/// Serializable summary of a channel. It is attached as `chan` to the dnet
/// `SendMessage`/`RecvMessage` events emitted by [`Channel`].
/// NOTE: field order defines the serialized layout; do not reorder.
#[derive(Clone, Debug, SerialEncodable, SerialDecodable)]
pub struct ChannelInfo {
    /// Address after transport processing, if any.
    pub resolve_addr: Option<Url>,
    /// Address without transport processing.
    pub connect_addr: Url,
    /// Channel creation time, in seconds since the UNIX epoch.
    pub start_time: u64,
    /// Random identifier drawn from the OS RNG when this info is created.
    pub id: u32,
    /// Whether the connection uses a mixed transport. For such channels
    /// `Channel::address()` ignores `resolve_addr`.
    pub transport_mixed: bool,
}
 73  
 74  impl ChannelInfo {
 75      fn new(
 76          resolve_addr: Option<Url>,
 77          connect_addr: Url,
 78          start_time: u64,
 79          transport_mixed: bool,
 80      ) -> Self {
 81          Self { resolve_addr, connect_addr, start_time, id: OsRng.gen(), transport_mixed }
 82      }
 83  }
 84  
/// Async channel for communication between nodes.
///
/// Wraps a transport stream split into separately locked read and write
/// halves. Incoming messages are dispatched through a [`MessageSubsystem`],
/// and outgoing messages are rate limited per command.
pub struct Channel {
    /// The reading half of the transport stream. The receive loop holds
    /// this lock for its entire lifetime.
    reader: AsyncMutex<ReadHalf<Box<dyn PtStream>>>,
    /// The writing half of the transport stream
    writer: AsyncMutex<WriteHalf<Box<dyn PtStream>>>,
    /// The message subsystem instance for this channel
    message_subsystem: MessageSubsystem,
    /// Publisher notified with `Error::ChannelStopped` when this channel stops
    stop_publisher: PublisherPtr<Error>,
    /// Stoppable task running this channel's receive loop
    receive_task: StoppableTaskPtr,
    /// Set to `true` once the channel has been stopped
    stopped: AtomicBool,
    /// Weak pointer to respective session
    pub(in crate::net) session: SessionWeakPtr,
    /// The version message of the node we are connected to.
    /// Set once the version exchange has occurred; empty otherwise.
    pub version: OnceCell<Arc<VersionMessage>>,
    /// Channel debug info
    pub info: ChannelInfo,
    /// Map from message command name to the `MeteringQueue` used to rate
    /// limit sending that [`message::Message`] type towards the stream.
    metering_map: AsyncMutex<HashMap<String, MeteringQueue>>,
}
111  
112  impl Channel {
113      /// Sets up a new channel. Creates a reader and writer [`PtStream`] and
114      /// the message publisher subsystem. Performs a network handshake on the
115      /// subsystem dispatchers.
116      pub async fn new(
117          stream: Box<dyn PtStream>,
118          resolve_addr: Option<Url>,
119          connect_addr: Url,
120          session: SessionWeakPtr,
121          transport_mixed: bool,
122      ) -> Arc<Self> {
123          let (reader, writer) = io::split(stream);
124          let reader = AsyncMutex::new(reader);
125          let writer = AsyncMutex::new(writer);
126  
127          let message_subsystem = MessageSubsystem::new();
128          Self::setup_dispatchers(&message_subsystem).await;
129  
130          let start_time = UNIX_EPOCH.elapsed().unwrap().as_secs();
131          let info =
132              ChannelInfo::new(resolve_addr, connect_addr.clone(), start_time, transport_mixed);
133          let metering_map = AsyncMutex::new(HashMap::new());
134  
135          Arc::new(Self {
136              reader,
137              writer,
138              message_subsystem,
139              stop_publisher: Publisher::new(),
140              receive_task: StoppableTask::new(),
141              stopped: AtomicBool::new(false),
142              session,
143              version: OnceCell::new(),
144              info,
145              metering_map,
146          })
147      }
148  
149      /// Perform network handshake for message subsystem dispatchers.
150      async fn setup_dispatchers(subsystem: &MessageSubsystem) {
151          subsystem.add_dispatch::<message::VersionMessage>().await;
152          subsystem.add_dispatch::<message::VerackMessage>().await;
153          subsystem.add_dispatch::<message::PingMessage>().await;
154          subsystem.add_dispatch::<message::PongMessage>().await;
155          subsystem.add_dispatch::<message::GetAddrsMessage>().await;
156          subsystem.add_dispatch::<message::AddrsMessage>().await;
157      }
158  
159      /// Starts the channel. Runs a receive loop to start receiving messages
160      /// or handles a network failure.
161      pub fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) {
162          debug!(target: "net::channel::start()", "START {self:?}");
163  
164          let self_ = self.clone();
165          self.receive_task.clone().start(
166              self.clone().main_receive_loop(),
167              |result| self_.handle_stop(result),
168              Error::ChannelStopped,
169              executor,
170          );
171  
172          debug!(target: "net::channel::start()", "END {self:?}");
173      }
174  
175      /// Stops the channel.
176      /// Notifies all publishers that the channel has been closed in `handle_stop()`.
177      pub async fn stop(&self) {
178          debug!(target: "net::channel::stop()", "START {self:?}");
179          self.receive_task.stop().await;
180          debug!(target: "net::channel::stop()", "END {self:?}");
181      }
182  
183      /// Creates a subscription to a stopped signal.
184      /// If the channel is stopped then this will return a ChannelStopped error.
185      pub async fn subscribe_stop(&self) -> Result<Subscription<Error>> {
186          debug!(target: "net::channel::subscribe_stop()", "START {self:?}");
187  
188          if self.is_stopped() {
189              return Err(Error::ChannelStopped)
190          }
191  
192          let sub = self.stop_publisher.clone().subscribe().await;
193  
194          debug!(target: "net::channel::subscribe_stop()", "END {self:?}");
195  
196          Ok(sub)
197      }
198  
199      pub fn is_stopped(&self) -> bool {
200          self.stopped.load(SeqCst)
201      }
202  
203      /// Sends a message across a channel. First it converts the message
204      /// into a `SerializedMessage` and then calls `send_serialized` to send it.
205      /// Returns an error if something goes wrong.
206      pub async fn send<M: message::Message>(&self, message: &M) -> Result<()> {
207          self.send_serialized(
208              &SerializedMessage::new(message).await,
209              &M::METERING_SCORE,
210              &M::METERING_CONFIGURATION,
211          )
212          .await
213      }
214  
215      /// Sends the encoded payload of provided `SerializedMessage` across the channel.
216      ///
217      /// We first check if we should apply some throttling, based on the provided
218      /// `Message` configuration. We always sleep 2x times more than the expected one,
219      /// so we don't flood the peer.
220      /// Then, calls `send_message` that creates a new payload and sends it over the
221      /// network transport as a packet.
222      /// Returns an error if something goes wrong.
223      pub async fn send_serialized(
224          &self,
225          message: &SerializedMessage,
226          metering_score: &u64,
227          metering_config: &MeteringConfiguration,
228      ) -> Result<()> {
229          debug!(
230               target: "net::channel::send()", "[START] command={} {self:?}",
231               message.command,
232          );
233  
234          // Check if we need to initialize a `MeteringQueue`
235          // for this specific `Message`.
236          let mut lock = self.metering_map.lock().await;
237          if !lock.contains_key(&message.command) {
238              lock.insert(message.command.clone(), MeteringQueue::new(metering_config.clone()));
239          }
240  
241          // Insert metering information and grab potential sleep time.
242          // It's safe to unwrap here since we initialized the value
243          // previously.
244          let queue = lock.get_mut(&message.command).unwrap();
245          queue.push(metering_score);
246          let sleep_time = queue.sleep_time();
247          drop(lock);
248  
249          // Check if we need to sleep
250          if let Some(sleep_time) = sleep_time {
251              let sleep_time = 2 * sleep_time;
252              debug!(
253                  target: "net::channel::send()",
254                  "[P2P] Channel rate limit is active, sleeping before sending for: {sleep_time} (ms)"
255              );
256              msleep(sleep_time).await;
257          }
258  
259          // Check if the channel is stopped, so we can abort
260          if self.is_stopped() {
261              return Err(Error::ChannelStopped)
262          }
263  
264          // Catch failure and stop channel, return a net error
265          if let Err(e) = self.send_message(message).await {
266              if self.session.upgrade().unwrap().type_id() & (SESSION_ALL & !SESSION_REFINE) != 0 {
267                  error!(
268                      target: "net::channel::send()", "[P2P] Channel send error for [{self:?}]: {e}"
269                  );
270              }
271              self.stop().await;
272              return Err(Error::ChannelStopped)
273          }
274  
275          debug!(
276              target: "net::channel::send()", "[END] command={} {self:?}",
277              message.command
278          );
279  
280          Ok(())
281      }
282  
283      /// Sends the encoded payload of provided `SerializedMessage` by writing
284      /// the data to the channel async stream.
285      async fn send_message(&self, message: &SerializedMessage) -> Result<()> {
286          assert!(!message.command.is_empty());
287  
288          let stream = &mut *self.writer.lock().await;
289          let mut written: usize = 0;
290  
291          dnetev!(self, SendMessage, {
292              chan: self.info.clone(),
293              cmd: message.command.clone(),
294              time: NanoTimestamp::current_time(),
295          });
296  
297          trace!(target: "net::channel::send_message()", "Sending magic...");
298          let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
299          written += magic_bytes.encode_async(stream).await?;
300          trace!(target: "net::channel::send_message()", "Sent magic");
301  
302          trace!(target: "net::channel::send_message()", "Sending command...");
303          written += message.command.encode_async(stream).await?;
304          trace!(target: "net::channel::send_message()", "Sent command: {}", message.command);
305  
306          trace!(target: "net::channel::send_message()", "Sending payload...");
307          // First extract the length of the payload as a VarInt and write it to the stream.
308          written += VarInt(message.payload.len() as u64).encode_async(stream).await?;
309          // Then write the encoded payload itself to the stream.
310          stream.write_all(&message.payload).await?;
311          written += message.payload.len();
312  
313          trace!(target: "net::channel::send_message()", "Sent payload {} bytes, total bytes {written}",
314              message.payload.len());
315  
316          stream.flush().await?;
317  
318          Ok(())
319      }
320  
321      /// Returns a decoded Message command. We start by extracting the length
322      /// from the stream, then allocate the precise buffer for this length
323      /// using stream.take(). This manual deserialization provides a basic
324      /// DDOS protection, since it prevents nodes from sending an arbitarily
325      /// large payload.
326      pub async fn read_command<R: AsyncRead + Unpin + Send + Sized>(
327          &self,
328          stream: &mut R,
329      ) -> Result<String> {
330          // Messages should have a 4 byte header of magic digits.
331          // This is used for network debugging.
332          let mut magic = [0u8; 4];
333          trace!(target: "net::channel::read_command()", "Reading magic...");
334          stream.read_exact(&mut magic).await?;
335  
336          trace!(target: "net::channel::read_command()", "Read magic {magic:?}");
337          let magic_bytes = self.p2p().settings().read().await.magic_bytes.0;
338          if magic != magic_bytes {
339              error!(target: "net::channel::read_command", "Error: Magic bytes mismatch");
340              return Err(Error::MalformedPacket)
341          }
342  
343          // First extract the length from the stream
344          let cmd_len = VarInt::decode_async(stream).await?.0;
345          if cmd_len > (MAX_COMMAND_LENGTH as u64) {
346              error!(target: "net::channel::read_command",
347                  "Error: Command length ({cmd_len}) exceeds configured limit ({MAX_COMMAND_LENGTH}). Dropping...");
348              return Err(Error::MessageInvalid);
349          }
350  
351          // Then extract precisely `cmd_len` items from the stream.
352          let mut take = stream.take(cmd_len);
353  
354          // Deserialize into a vector of `cmd_len` size.
355          let mut bytes = vec![0; cmd_len.try_into().unwrap()];
356          take.read_exact(&mut bytes).await?;
357  
358          let command = String::from_utf8(bytes)?;
359  
360          Ok(command)
361      }
362  
363      /// Subscribe to a message on the message subsystem.
364      pub async fn subscribe_msg<M: message::Message>(&self) -> Result<MessageSubscription<M>> {
365          debug!(
366              target: "net::channel::subscribe_msg()", "[START] command={} {self:?}",
367              M::NAME
368          );
369  
370          let sub = self.message_subsystem.subscribe::<M>().await;
371  
372          debug!(
373              target: "net::channel::subscribe_msg()", "[END] command={} {self:?}",
374              M::NAME
375          );
376  
377          sub
378      }
379  
380      /// Handle network errors. Panic if error passes silently, otherwise
381      /// broadcast the error.
382      async fn handle_stop(self: Arc<Self>, result: Result<()>) {
383          debug!(target: "net::channel::handle_stop()", "[START] {self:?}");
384  
385          self.stopped.store(true, SeqCst);
386  
387          match result {
388              Ok(()) => panic!("Channel task should never complete without error status"),
389              // Send this error to all channel subscribers
390              Err(e) => {
391                  self.stop_publisher.notify(Error::ChannelStopped).await;
392                  self.message_subsystem.trigger_error(e).await;
393              }
394          }
395  
396          debug!(target: "net::channel::handle_stop()", "[END] {self:?}");
397      }
398  
399      /// Run the receive loop. Start receiving messages or handle network failure.
400      async fn main_receive_loop(self: Arc<Self>) -> Result<()> {
401          debug!(target: "net::channel::main_receive_loop()", "[START] {self:?}");
402  
403          // Acquire reader lock
404          let reader = &mut *self.reader.lock().await;
405  
406          // Run loop
407          loop {
408              let command = match self.read_command(reader).await {
409                  Ok(command) => command,
410                  Err(err) => {
411                      if Self::is_eof_error(&err) {
412                          verbose!(
413                              target: "net::channel::main_receive_loop()",
414                              "[P2P] Channel {} disconnected",
415                              self.display_address()
416                          );
417                      } else if let Error::MessageInvalid = err {
418                          // The command name length has exceeded the limit, this is possibly a malicious attack so ban it
419                          if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
420                              self.ban().await;
421                          }
422                      } else if self.session.upgrade().unwrap().type_id() &
423                          (SESSION_ALL & !SESSION_REFINE) !=
424                          0
425                      {
426                          error!(
427                              target: "net::channel::main_receive_loop()",
428                              "[P2P] Read error on channel {}: {err}",
429                              self.display_address()
430                          );
431                      }
432  
433                      debug!(
434                          target: "net::channel::main_receive_loop()",
435                          "Stopping channel {self:?}"
436                      );
437                      return Err(Error::ChannelStopped)
438                  }
439              };
440  
441              dnetev!(self, RecvMessage, {
442                  chan: self.info.clone(),
443                  cmd: command.clone(),
444                  time: NanoTimestamp::current_time(),
445              });
446  
447              // Send result to our publishers
448              match self.message_subsystem.notify(&command, reader).await {
449                  Ok(()) => {}
450                  Err(Error::MissingDispatcher) |
451                  Err(Error::MessageInvalid) |
452                  Err(Error::MeteringLimitExceeded) => {
453                      // If we're getting messages without dispatchers or its invalid,
454                      // it's spam. We therefore ban this channel if:
455                      //
456                      // 1) This channel is NOT part of a refine session.
457                      //
458                      // It's possible that nodes can send messages without
459                      // dispatchers during the refinery process. If that happens
460                      // we simply ignore it. Otherwise, it's spam.
461                      //
462                      // 2) BanPolicy is set to Strict.
463                      //
464                      // We only ban if the BanPolicy is set to Strict, which is
465                      // the default setting for most nodes. The exception to
466                      // this is a seed node like Lilith which has BanPolicy::Relaxed
467                      // since it regularly forms connections with nodes sending
468                      // messages it does not have dispatchers for.
469                      if self.session.upgrade().unwrap().type_id() != SESSION_REFINE {
470                          warn!(
471                          target: "net::channel::main_receive_loop()",
472                          "MissingDispatcher|MessageInvalid|MeteringLimitExceeded for command={command}, channel={self:?}"
473                          );
474  
475                          if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy {
476                              self.ban().await;
477                          }
478  
479                          return Err(Error::ChannelStopped)
480                      }
481                  }
482                  Err(_) => unreachable!("You added a new error in notify()"),
483              }
484          }
485      }
486  
487      /// Ban a malicious peer and stop the channel.
488      pub async fn ban(&self) {
489          debug!(target: "net::channel::ban()", "START {self:?}");
490          debug!(target: "net::channel::ban()", "Peer: {:?}", self.display_address());
491  
492          // Just store the hostname if this is an inbound session.
493          // This will block all ports from this peer by setting
494          // `hosts.block_all_ports()` to true.
495          let peer = {
496              if self.session_type_id() & SESSION_INBOUND != 0 {
497                  if self.address().host().is_none() {
498                      error!("[P2P] ban() caught Url without host: {:?}", self.display_address());
499                      return
500                  }
501  
502                  // An inbound Tor connection can't really be banned :)
503                  #[cfg(feature = "p2p-tor")]
504                  if (self.address().scheme() == "tor" || self.address().scheme() == "tor+tls") &&
505                      self.p2p().hosts().is_local_host(self.address())
506                  {
507                      return
508                  }
509  
510                  if self.address().scheme() == "unix" {
511                      return
512                  }
513  
514                  let mut addr = self.address().clone();
515                  addr.set_port(None).unwrap();
516                  addr
517              } else {
518                  self.address().clone()
519              }
520          };
521  
522          let last_seen = UNIX_EPOCH.elapsed().unwrap().as_secs();
523          verbose!(target: "net::channel::ban()", "Blacklisting peer={peer}");
524          match self.p2p().hosts().move_host(&peer, last_seen, HostColor::Black).await {
525              Ok(()) => {
526                  verbose!(target: "net::channel::ban()", "Peer={peer} blacklisted successfully");
527              }
528              Err(e) => {
529                  warn!(target: "net::channel::ban()", "Could not blacklisted peer={peer}, err={e}");
530              }
531          }
532          self.stop().await;
533          debug!(target: "net::channel::ban()", "STOP {self:?}");
534      }
535  
536      /// Returns the relevant socket address for this connection. If this is
537      /// an outbound connection, the transport-processed resolve_addr will
538      /// be returned except for transport mixed connections, to make sure
539      /// mixed hosts don't enter hostlist.
540      /// Otherwise for inbound connections it will default
541      /// to connect_addr.
542      pub fn address(&self) -> &Url {
543          if !self.info.transport_mixed {
544              if let Some(resolve_addr) = &self.info.resolve_addr {
545                  return resolve_addr
546              }
547          }
548          &self.info.connect_addr
549      }
550  
551      /// Returns the address used for UI purposes like in logging or tools like dnet.
552      /// For transport_mixed connection shows the mixed address.
553      pub fn display_address(&self) -> &Url {
554          self.info.resolve_addr.as_ref().unwrap_or(&self.info.connect_addr)
555      }
556  
557      /// Returns the socket address that has undergone transport
558      /// processing, if it exists. Returns None otherwise.
559      pub fn resolve_addr(&self) -> Option<Url> {
560          self.info.resolve_addr.clone()
561      }
562  
563      /// Return the socket address without transport processing.
564      pub fn connect_addr(&self) -> &Url {
565          &self.info.connect_addr
566      }
567  
568      /// Set the VersionMessage of the node this channel is connected
569      /// to. Called on receiving a version message in `ProtocolVersion`.
570      pub(crate) async fn set_version(&self, version: Arc<VersionMessage>) {
571          self.version.set(version).await.unwrap();
572      }
573      /// Should only be called after the version exchange has been completed.
574      pub fn get_version(&self) -> Arc<VersionMessage> {
575          self.version.get().unwrap().clone()
576      }
577  
578      /// Returns the inner [`MessageSubsystem`] reference
579      pub fn message_subsystem(&self) -> &MessageSubsystem {
580          &self.message_subsystem
581      }
582  
583      fn session(&self) -> Arc<dyn Session> {
584          self.session.upgrade().unwrap()
585      }
586  
587      pub fn session_type_id(&self) -> SessionBitFlag {
588          let session = self.session();
589          session.type_id()
590      }
591  
592      #[inline]
593      pub fn p2p(&self) -> P2pPtr {
594          self.session().p2p()
595      }
596      #[inline]
597      pub fn hosts(&self) -> HostsPtr {
598          self.p2p().hosts()
599      }
600  
601      fn is_eof_error(err: &Error) -> bool {
602          match err {
603              Error::Io(ioerr) => ioerr == &std::io::ErrorKind::UnexpectedEof,
604              _ => false,
605          }
606      }
607  }
608  
609  impl fmt::Debug for Channel {
610      fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
611          write!(f, "<Channel addr='{}' id={}>", self.display_address(), self.info.id)
612      }
613  }