channel.rs
1 /* This file is part of DarkFi (https://dark.fi) 2 * 3 * Copyright (C) 2020-2025 Dyne.org foundation 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU Affero General Public License as 7 * published by the Free Software Foundation, either version 3 of the 8 * License, or (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU Affero General Public License for more details. 14 * 15 * You should have received a copy of the GNU Affero General Public License 16 * along with this program. If not, see <https://www.gnu.org/licenses/>. 17 */ 18 19 use std::{ 20 collections::HashMap, 21 fmt, 22 sync::{ 23 atomic::{AtomicBool, Ordering::SeqCst}, 24 Arc, 25 }, 26 time::UNIX_EPOCH, 27 }; 28 29 use darkfi_serial::{ 30 async_trait, AsyncDecodable, AsyncEncodable, SerialDecodable, SerialEncodable, VarInt, 31 }; 32 use rand::{rngs::OsRng, Rng}; 33 use smol::{ 34 io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, 35 lock::{Mutex as AsyncMutex, OnceCell}, 36 Executor, 37 }; 38 use tracing::{debug, error, trace, warn}; 39 use url::Url; 40 41 use super::{ 42 dnet::{self, dnetev, DnetEvent}, 43 hosts::{HostColor, HostsPtr}, 44 message, 45 message::{SerializedMessage, VersionMessage, MAX_COMMAND_LENGTH}, 46 message_publisher::{MessageSubscription, MessageSubsystem}, 47 metering::{MeteringConfiguration, MeteringQueue}, 48 p2p::P2pPtr, 49 session::{ 50 Session, SessionBitFlag, SessionWeakPtr, SESSION_ALL, SESSION_INBOUND, SESSION_REFINE, 51 }, 52 transport::PtStream, 53 }; 54 use crate::{ 55 net::BanPolicy, 56 system::{msleep, Publisher, PublisherPtr, StoppableTask, StoppableTaskPtr, Subscription}, 57 util::{logger::verbose, time::NanoTimestamp}, 58 Error, Result, 59 }; 60 61 /// Atomic pointer to async channel 62 pub type ChannelPtr = Arc<Channel>; 63 64 /// Channel debug info 65 #[derive(Clone, Debug, SerialEncodable, SerialDecodable)] 66 pub struct ChannelInfo { 67 pub resolve_addr: Option<Url>, 68 pub connect_addr: Url, 69 pub start_time: u64, 70 pub id: u32, 71 pub transport_mixed: bool, 72 } 73 74 impl ChannelInfo { 75 fn new( 76 resolve_addr: Option<Url>, 77 connect_addr: Url, 78 start_time: u64, 79 transport_mixed: bool, 80 ) -> Self { 81 Self { resolve_addr, connect_addr, start_time, id: OsRng.gen(), transport_mixed } 82 } 83 } 84 85 /// Async channel for communication between nodes. 86 pub struct Channel { 87 /// The reading half of the transport stream 88 reader: AsyncMutex<ReadHalf<Box<dyn PtStream>>>, 89 /// The writing half of the transport stream 90 writer: AsyncMutex<WriteHalf<Box<dyn PtStream>>>, 91 /// The message subsystem instance for this channel 92 message_subsystem: MessageSubsystem, 93 /// Publisher listening for stop signal for closing this channel 94 stop_publisher: PublisherPtr<Error>, 95 /// Task that is listening for the stop signal 96 receive_task: StoppableTaskPtr, 97 /// A boolean marking if this channel is stopped 98 stopped: AtomicBool, 99 /// Weak pointer to respective session 100 pub(in crate::net) session: SessionWeakPtr, 101 /// The version message of the node we are connected to. 102 /// Some if the version exchange has already occurred, None 103 /// otherwise. 104 pub version: OnceCell<Arc<VersionMessage>>, 105 /// Channel debug info 106 pub info: ChannelInfo, 107 /// Map holding a `MeteringQueue` for each [`Message`] to perform 108 /// rate limiting of propagation towards the stream. 109 metering_map: AsyncMutex<HashMap<String, MeteringQueue>>, 110 } 111 112 impl Channel { 113 /// Sets up a new channel. Creates a reader and writer [`PtStream`] and 114 /// the message publisher subsystem. Performs a network handshake on the 115 /// subsystem dispatchers. 116 pub async fn new( 117 stream: Box<dyn PtStream>, 118 resolve_addr: Option<Url>, 119 connect_addr: Url, 120 session: SessionWeakPtr, 121 transport_mixed: bool, 122 ) -> Arc<Self> { 123 let (reader, writer) = io::split(stream); 124 let reader = AsyncMutex::new(reader); 125 let writer = AsyncMutex::new(writer); 126 127 let message_subsystem = MessageSubsystem::new(); 128 Self::setup_dispatchers(&message_subsystem).await; 129 130 let start_time = UNIX_EPOCH.elapsed().unwrap().as_secs(); 131 let info = 132 ChannelInfo::new(resolve_addr, connect_addr.clone(), start_time, transport_mixed); 133 let metering_map = AsyncMutex::new(HashMap::new()); 134 135 Arc::new(Self { 136 reader, 137 writer, 138 message_subsystem, 139 stop_publisher: Publisher::new(), 140 receive_task: StoppableTask::new(), 141 stopped: AtomicBool::new(false), 142 session, 143 version: OnceCell::new(), 144 info, 145 metering_map, 146 }) 147 } 148 149 /// Perform network handshake for message subsystem dispatchers. 150 async fn setup_dispatchers(subsystem: &MessageSubsystem) { 151 subsystem.add_dispatch::<message::VersionMessage>().await; 152 subsystem.add_dispatch::<message::VerackMessage>().await; 153 subsystem.add_dispatch::<message::PingMessage>().await; 154 subsystem.add_dispatch::<message::PongMessage>().await; 155 subsystem.add_dispatch::<message::GetAddrsMessage>().await; 156 subsystem.add_dispatch::<message::AddrsMessage>().await; 157 } 158 159 /// Starts the channel. Runs a receive loop to start receiving messages 160 /// or handles a network failure. 161 pub fn start(self: Arc<Self>, executor: Arc<Executor<'_>>) { 162 debug!(target: "net::channel::start()", "START {self:?}"); 163 164 let self_ = self.clone(); 165 self.receive_task.clone().start( 166 self.clone().main_receive_loop(), 167 |result| self_.handle_stop(result), 168 Error::ChannelStopped, 169 executor, 170 ); 171 172 debug!(target: "net::channel::start()", "END {self:?}"); 173 } 174 175 /// Stops the channel. 176 /// Notifies all publishers that the channel has been closed in `handle_stop()`. 177 pub async fn stop(&self) { 178 debug!(target: "net::channel::stop()", "START {self:?}"); 179 self.receive_task.stop().await; 180 debug!(target: "net::channel::stop()", "END {self:?}"); 181 } 182 183 /// Creates a subscription to a stopped signal. 184 /// If the channel is stopped then this will return a ChannelStopped error. 185 pub async fn subscribe_stop(&self) -> Result<Subscription<Error>> { 186 debug!(target: "net::channel::subscribe_stop()", "START {self:?}"); 187 188 if self.is_stopped() { 189 return Err(Error::ChannelStopped) 190 } 191 192 let sub = self.stop_publisher.clone().subscribe().await; 193 194 debug!(target: "net::channel::subscribe_stop()", "END {self:?}"); 195 196 Ok(sub) 197 } 198 199 pub fn is_stopped(&self) -> bool { 200 self.stopped.load(SeqCst) 201 } 202 203 /// Sends a message across a channel. First it converts the message 204 /// into a `SerializedMessage` and then calls `send_serialized` to send it. 205 /// Returns an error if something goes wrong. 206 pub async fn send<M: message::Message>(&self, message: &M) -> Result<()> { 207 self.send_serialized( 208 &SerializedMessage::new(message).await, 209 &M::METERING_SCORE, 210 &M::METERING_CONFIGURATION, 211 ) 212 .await 213 } 214 215 /// Sends the encoded payload of provided `SerializedMessage` across the channel. 216 /// 217 /// We first check if we should apply some throttling, based on the provided 218 /// `Message` configuration. We always sleep 2x times more than the expected one, 219 /// so we don't flood the peer. 220 /// Then, calls `send_message` that creates a new payload and sends it over the 221 /// network transport as a packet. 222 /// Returns an error if something goes wrong. 223 pub async fn send_serialized( 224 &self, 225 message: &SerializedMessage, 226 metering_score: &u64, 227 metering_config: &MeteringConfiguration, 228 ) -> Result<()> { 229 debug!( 230 target: "net::channel::send()", "[START] command={} {self:?}", 231 message.command, 232 ); 233 234 // Check if we need to initialize a `MeteringQueue` 235 // for this specific `Message`. 236 let mut lock = self.metering_map.lock().await; 237 if !lock.contains_key(&message.command) { 238 lock.insert(message.command.clone(), MeteringQueue::new(metering_config.clone())); 239 } 240 241 // Insert metering information and grab potential sleep time. 242 // It's safe to unwrap here since we initialized the value 243 // previously. 244 let queue = lock.get_mut(&message.command).unwrap(); 245 queue.push(metering_score); 246 let sleep_time = queue.sleep_time(); 247 drop(lock); 248 249 // Check if we need to sleep 250 if let Some(sleep_time) = sleep_time { 251 let sleep_time = 2 * sleep_time; 252 debug!( 253 target: "net::channel::send()", 254 "[P2P] Channel rate limit is active, sleeping before sending for: {sleep_time} (ms)" 255 ); 256 msleep(sleep_time).await; 257 } 258 259 // Check if the channel is stopped, so we can abort 260 if self.is_stopped() { 261 return Err(Error::ChannelStopped) 262 } 263 264 // Catch failure and stop channel, return a net error 265 if let Err(e) = self.send_message(message).await { 266 if self.session.upgrade().unwrap().type_id() & (SESSION_ALL & !SESSION_REFINE) != 0 { 267 error!( 268 target: "net::channel::send()", "[P2P] Channel send error for [{self:?}]: {e}" 269 ); 270 } 271 self.stop().await; 272 return Err(Error::ChannelStopped) 273 } 274 275 debug!( 276 target: "net::channel::send()", "[END] command={} {self:?}", 277 message.command 278 ); 279 280 Ok(()) 281 } 282 283 /// Sends the encoded payload of provided `SerializedMessage` by writing 284 /// the data to the channel async stream. 285 async fn send_message(&self, message: &SerializedMessage) -> Result<()> { 286 assert!(!message.command.is_empty()); 287 288 let stream = &mut *self.writer.lock().await; 289 let mut written: usize = 0; 290 291 dnetev!(self, SendMessage, { 292 chan: self.info.clone(), 293 cmd: message.command.clone(), 294 time: NanoTimestamp::current_time(), 295 }); 296 297 trace!(target: "net::channel::send_message()", "Sending magic..."); 298 let magic_bytes = self.p2p().settings().read().await.magic_bytes.0; 299 written += magic_bytes.encode_async(stream).await?; 300 trace!(target: "net::channel::send_message()", "Sent magic"); 301 302 trace!(target: "net::channel::send_message()", "Sending command..."); 303 written += message.command.encode_async(stream).await?; 304 trace!(target: "net::channel::send_message()", "Sent command: {}", message.command); 305 306 trace!(target: "net::channel::send_message()", "Sending payload..."); 307 // First extract the length of the payload as a VarInt and write it to the stream. 308 written += VarInt(message.payload.len() as u64).encode_async(stream).await?; 309 // Then write the encoded payload itself to the stream. 310 stream.write_all(&message.payload).await?; 311 written += message.payload.len(); 312 313 trace!(target: "net::channel::send_message()", "Sent payload {} bytes, total bytes {written}", 314 message.payload.len()); 315 316 stream.flush().await?; 317 318 Ok(()) 319 } 320 321 /// Returns a decoded Message command. We start by extracting the length 322 /// from the stream, then allocate the precise buffer for this length 323 /// using stream.take(). This manual deserialization provides a basic 324 /// DDOS protection, since it prevents nodes from sending an arbitarily 325 /// large payload. 326 pub async fn read_command<R: AsyncRead + Unpin + Send + Sized>( 327 &self, 328 stream: &mut R, 329 ) -> Result<String> { 330 // Messages should have a 4 byte header of magic digits. 331 // This is used for network debugging. 332 let mut magic = [0u8; 4]; 333 trace!(target: "net::channel::read_command()", "Reading magic..."); 334 stream.read_exact(&mut magic).await?; 335 336 trace!(target: "net::channel::read_command()", "Read magic {magic:?}"); 337 let magic_bytes = self.p2p().settings().read().await.magic_bytes.0; 338 if magic != magic_bytes { 339 error!(target: "net::channel::read_command", "Error: Magic bytes mismatch"); 340 return Err(Error::MalformedPacket) 341 } 342 343 // First extract the length from the stream 344 let cmd_len = VarInt::decode_async(stream).await?.0; 345 if cmd_len > (MAX_COMMAND_LENGTH as u64) { 346 error!(target: "net::channel::read_command", 347 "Error: Command length ({cmd_len}) exceeds configured limit ({MAX_COMMAND_LENGTH}). Dropping..."); 348 return Err(Error::MessageInvalid); 349 } 350 351 // Then extract precisely `cmd_len` items from the stream. 352 let mut take = stream.take(cmd_len); 353 354 // Deserialize into a vector of `cmd_len` size. 355 let mut bytes = vec![0; cmd_len.try_into().unwrap()]; 356 take.read_exact(&mut bytes).await?; 357 358 let command = String::from_utf8(bytes)?; 359 360 Ok(command) 361 } 362 363 /// Subscribe to a message on the message subsystem. 364 pub async fn subscribe_msg<M: message::Message>(&self) -> Result<MessageSubscription<M>> { 365 debug!( 366 target: "net::channel::subscribe_msg()", "[START] command={} {self:?}", 367 M::NAME 368 ); 369 370 let sub = self.message_subsystem.subscribe::<M>().await; 371 372 debug!( 373 target: "net::channel::subscribe_msg()", "[END] command={} {self:?}", 374 M::NAME 375 ); 376 377 sub 378 } 379 380 /// Handle network errors. Panic if error passes silently, otherwise 381 /// broadcast the error. 382 async fn handle_stop(self: Arc<Self>, result: Result<()>) { 383 debug!(target: "net::channel::handle_stop()", "[START] {self:?}"); 384 385 self.stopped.store(true, SeqCst); 386 387 match result { 388 Ok(()) => panic!("Channel task should never complete without error status"), 389 // Send this error to all channel subscribers 390 Err(e) => { 391 self.stop_publisher.notify(Error::ChannelStopped).await; 392 self.message_subsystem.trigger_error(e).await; 393 } 394 } 395 396 debug!(target: "net::channel::handle_stop()", "[END] {self:?}"); 397 } 398 399 /// Run the receive loop. Start receiving messages or handle network failure. 400 async fn main_receive_loop(self: Arc<Self>) -> Result<()> { 401 debug!(target: "net::channel::main_receive_loop()", "[START] {self:?}"); 402 403 // Acquire reader lock 404 let reader = &mut *self.reader.lock().await; 405 406 // Run loop 407 loop { 408 let command = match self.read_command(reader).await { 409 Ok(command) => command, 410 Err(err) => { 411 if Self::is_eof_error(&err) { 412 verbose!( 413 target: "net::channel::main_receive_loop()", 414 "[P2P] Channel {} disconnected", 415 self.display_address() 416 ); 417 } else if let Error::MessageInvalid = err { 418 // The command name length has exceeded the limit, this is possibly a malicious attack so ban it 419 if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy { 420 self.ban().await; 421 } 422 } else if self.session.upgrade().unwrap().type_id() & 423 (SESSION_ALL & !SESSION_REFINE) != 424 0 425 { 426 error!( 427 target: "net::channel::main_receive_loop()", 428 "[P2P] Read error on channel {}: {err}", 429 self.display_address() 430 ); 431 } 432 433 debug!( 434 target: "net::channel::main_receive_loop()", 435 "Stopping channel {self:?}" 436 ); 437 return Err(Error::ChannelStopped) 438 } 439 }; 440 441 dnetev!(self, RecvMessage, { 442 chan: self.info.clone(), 443 cmd: command.clone(), 444 time: NanoTimestamp::current_time(), 445 }); 446 447 // Send result to our publishers 448 match self.message_subsystem.notify(&command, reader).await { 449 Ok(()) => {} 450 Err(Error::MissingDispatcher) | 451 Err(Error::MessageInvalid) | 452 Err(Error::MeteringLimitExceeded) => { 453 // If we're getting messages without dispatchers or its invalid, 454 // it's spam. We therefore ban this channel if: 455 // 456 // 1) This channel is NOT part of a refine session. 457 // 458 // It's possible that nodes can send messages without 459 // dispatchers during the refinery process. If that happens 460 // we simply ignore it. Otherwise, it's spam. 461 // 462 // 2) BanPolicy is set to Strict. 463 // 464 // We only ban if the BanPolicy is set to Strict, which is 465 // the default setting for most nodes. The exception to 466 // this is a seed node like Lilith which has BanPolicy::Relaxed 467 // since it regularly forms connections with nodes sending 468 // messages it does not have dispatchers for. 469 if self.session.upgrade().unwrap().type_id() != SESSION_REFINE { 470 warn!( 471 target: "net::channel::main_receive_loop()", 472 "MissingDispatcher|MessageInvalid|MeteringLimitExceeded for command={command}, channel={self:?}" 473 ); 474 475 if let BanPolicy::Strict = self.p2p().settings().read().await.ban_policy { 476 self.ban().await; 477 } 478 479 return Err(Error::ChannelStopped) 480 } 481 } 482 Err(_) => unreachable!("You added a new error in notify()"), 483 } 484 } 485 } 486 487 /// Ban a malicious peer and stop the channel. 488 pub async fn ban(&self) { 489 debug!(target: "net::channel::ban()", "START {self:?}"); 490 debug!(target: "net::channel::ban()", "Peer: {:?}", self.display_address()); 491 492 // Just store the hostname if this is an inbound session. 493 // This will block all ports from this peer by setting 494 // `hosts.block_all_ports()` to true. 495 let peer = { 496 if self.session_type_id() & SESSION_INBOUND != 0 { 497 if self.address().host().is_none() { 498 error!("[P2P] ban() caught Url without host: {:?}", self.display_address()); 499 return 500 } 501 502 // An inbound Tor connection can't really be banned :) 503 #[cfg(feature = "p2p-tor")] 504 if (self.address().scheme() == "tor" || self.address().scheme() == "tor+tls") && 505 self.p2p().hosts().is_local_host(self.address()) 506 { 507 return 508 } 509 510 if self.address().scheme() == "unix" { 511 return 512 } 513 514 let mut addr = self.address().clone(); 515 addr.set_port(None).unwrap(); 516 addr 517 } else { 518 self.address().clone() 519 } 520 }; 521 522 let last_seen = UNIX_EPOCH.elapsed().unwrap().as_secs(); 523 verbose!(target: "net::channel::ban()", "Blacklisting peer={peer}"); 524 match self.p2p().hosts().move_host(&peer, last_seen, HostColor::Black).await { 525 Ok(()) => { 526 verbose!(target: "net::channel::ban()", "Peer={peer} blacklisted successfully"); 527 } 528 Err(e) => { 529 warn!(target: "net::channel::ban()", "Could not blacklisted peer={peer}, err={e}"); 530 } 531 } 532 self.stop().await; 533 debug!(target: "net::channel::ban()", "STOP {self:?}"); 534 } 535 536 /// Returns the relevant socket address for this connection. If this is 537 /// an outbound connection, the transport-processed resolve_addr will 538 /// be returned except for transport mixed connections, to make sure 539 /// mixed hosts don't enter hostlist. 540 /// Otherwise for inbound connections it will default 541 /// to connect_addr. 542 pub fn address(&self) -> &Url { 543 if !self.info.transport_mixed { 544 if let Some(resolve_addr) = &self.info.resolve_addr { 545 return resolve_addr 546 } 547 } 548 &self.info.connect_addr 549 } 550 551 /// Returns the address used for UI purposes like in logging or tools like dnet. 552 /// For transport_mixed connection shows the mixed address. 553 pub fn display_address(&self) -> &Url { 554 self.info.resolve_addr.as_ref().unwrap_or(&self.info.connect_addr) 555 } 556 557 /// Returns the socket address that has undergone transport 558 /// processing, if it exists. Returns None otherwise. 559 pub fn resolve_addr(&self) -> Option<Url> { 560 self.info.resolve_addr.clone() 561 } 562 563 /// Return the socket address without transport processing. 564 pub fn connect_addr(&self) -> &Url { 565 &self.info.connect_addr 566 } 567 568 /// Set the VersionMessage of the node this channel is connected 569 /// to. Called on receiving a version message in `ProtocolVersion`. 570 pub(crate) async fn set_version(&self, version: Arc<VersionMessage>) { 571 self.version.set(version).await.unwrap(); 572 } 573 /// Should only be called after the version exchange has been completed. 574 pub fn get_version(&self) -> Arc<VersionMessage> { 575 self.version.get().unwrap().clone() 576 } 577 578 /// Returns the inner [`MessageSubsystem`] reference 579 pub fn message_subsystem(&self) -> &MessageSubsystem { 580 &self.message_subsystem 581 } 582 583 fn session(&self) -> Arc<dyn Session> { 584 self.session.upgrade().unwrap() 585 } 586 587 pub fn session_type_id(&self) -> SessionBitFlag { 588 let session = self.session(); 589 session.type_id() 590 } 591 592 #[inline] 593 pub fn p2p(&self) -> P2pPtr { 594 self.session().p2p() 595 } 596 #[inline] 597 pub fn hosts(&self) -> HostsPtr { 598 self.p2p().hosts() 599 } 600 601 fn is_eof_error(err: &Error) -> bool { 602 match err { 603 Error::Io(ioerr) => ioerr == &std::io::ErrorKind::UnexpectedEof, 604 _ => false, 605 } 606 } 607 } 608 609 impl fmt::Debug for Channel { 610 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 611 write!(f, "<Channel addr='{}' id={}>", self.display_address(), self.info.id) 612 } 613 }