codec.rs
1 // Copyright (c) 2025 ADnet Contributors 2 // This file is part of the AlphaOS library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::{bft::events::Event, bootstrap_client::network::MessageOrEvent, router::messages::Message}; 17 use alphavm::prelude::{FromBytes, Network, ToBytes}; 18 19 use bytes::{BufMut, BytesMut}; 20 use core::marker::PhantomData; 21 use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; 22 23 /// The maximum size of a message that can be transmitted during the handshake. 24 const MAX_HANDSHAKE_SIZE: usize = 1024 * 1024; // 1 MiB 25 /// The maximum size of a post-handshake message that can be obtained from the network. 26 const MAX_POST_HANDSHAKE_SIZE: usize = 2 * 1024 * 1024; // 2 MiB 27 28 /// The codec used to decode and encode network messages. 29 pub struct BootstrapClientCodec<N: Network> { 30 codec: LengthDelimitedCodec, 31 _phantom: PhantomData<N>, 32 } 33 34 impl<N: Network> BootstrapClientCodec<N> { 35 pub fn handshake() -> Self { 36 let mut codec = Self::default(); 37 codec.codec.set_max_frame_length(MAX_HANDSHAKE_SIZE); 38 codec 39 } 40 } 41 42 impl<N: Network> Default for BootstrapClientCodec<N> { 43 fn default() -> Self { 44 Self { 45 codec: LengthDelimitedCodec::builder() 46 .max_frame_length(MAX_POST_HANDSHAKE_SIZE) 47 .little_endian() 48 .new_codec(), 49 _phantom: Default::default(), 50 } 51 } 52 } 53 54 impl<N: Network> Encoder<Message<N>> for BootstrapClientCodec<N> { 55 type Error = std::io::Error; 56 57 fn encode(&mut self, message: Message<N>, dst: &mut BytesMut) -> Result<(), Self::Error> { 58 // Serialize the payload directly into dst. 59 message 60 .write_le(&mut dst.writer()) 61 // This error should never happen, the conversion is for greater compatibility. 62 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "serialization error"))?; 63 64 let serialized_message = dst.split_to(dst.len()).freeze(); 65 66 self.codec.encode(serialized_message, dst) 67 } 68 } 69 70 impl<N: Network> Encoder<Event<N>> for BootstrapClientCodec<N> { 71 type Error = std::io::Error; 72 73 fn encode(&mut self, event: Event<N>, dst: &mut BytesMut) -> Result<(), Self::Error> { 74 // Serialize the payload directly into dst. 75 event 76 .write_le(&mut dst.writer()) 77 // This error should never happen, the conversion is for greater compatibility. 78 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "serialization error"))?; 79 80 let serialized_event = dst.split_to(dst.len()).freeze(); 81 82 self.codec.encode(serialized_event, dst) 83 } 84 } 85 86 impl<N: Network> Encoder<MessageOrEvent<N>> for BootstrapClientCodec<N> { 87 type Error = std::io::Error; 88 89 fn encode(&mut self, item: MessageOrEvent<N>, dst: &mut BytesMut) -> Result<(), Self::Error> { 90 // Serialize the payload directly into dst. 91 match item { 92 MessageOrEvent::Message(message) => self.encode(message, dst), 93 MessageOrEvent::Event(event) => self.encode(event, dst), 94 } 95 } 96 } 97 98 impl<N: Network> Decoder for BootstrapClientCodec<N> { 99 type Error = std::io::Error; 100 type Item = MessageOrEvent<N>; 101 102 fn decode(&mut self, source: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { 103 // Decode a frame containing bytes belonging to a message. 104 let bytes = match self.codec.decode(source)? { 105 Some(bytes) => bytes, 106 None => return Ok(None), 107 }; 108 109 // Reject invalid/truncated messages. 110 if bytes.len() < 2 { 111 warn!("Failed to deserialize a message: too short"); 112 return Err(std::io::ErrorKind::InvalidData.into()); 113 } 114 115 // Check the ID of the serialized Message or Event. 116 let message_id = u16::from_le_bytes(bytes[..2].try_into().unwrap()); 117 118 // Discard messages that aren't of interest to a bootstrapper node. 119 match message_id { 120 2..=5 => match Message::read_le(&bytes[..]) { 121 Ok(message) => Ok(Some(MessageOrEvent::Message(message))), 122 Err(error) => { 123 warn!("Failed to deserialize a message: {error}"); 124 Err(std::io::ErrorKind::InvalidData.into()) 125 } 126 }, 127 7..=9 | 13 => match Event::read_le(&bytes[..]) { 128 Ok(event) => Ok(Some(MessageOrEvent::Event(event))), 129 Err(error) => { 130 warn!("Failed to deserialize a message: {error}"); 131 Err(std::io::ErrorKind::InvalidData.into()) 132 } 133 }, 134 id => { 135 trace!("Ignoring an unhandled message (ID {id})"); 136 Ok(None) 137 } 138 } 139 } 140 }