client.rs
1 //! Relay server client. 2 //! 3 //! This module provides a WebSocket-based client for communicating with 4 //! relay servers. 5 6 use std::time::Duration; 7 8 use futures_util::{SinkExt, StreamExt}; 9 use tokio::net::TcpStream; 10 use tokio::sync::Mutex; 11 use tokio::time::timeout; 12 use tokio_tungstenite::{ 13 connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, 14 }; 15 16 use super::protocol::{ 17 GossipForwardedMessage, RelayErrorCode, RelayMessage, StoredMessage, MAX_PAYLOAD_SIZE, 18 }; 19 20 /// Default connection timeout in seconds. 21 const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30; 22 23 /// Default operation timeout in seconds. 24 const DEFAULT_OPERATION_TIMEOUT_SECS: u64 = 60; 25 26 /// Errors that can occur during relay operations. 27 #[derive(Debug, Clone, PartialEq, Eq)] 28 pub enum RelayError { 29 /// Failed to connect to the relay server. 30 ConnectionFailed(String), 31 /// Connection was closed unexpectedly. 32 Disconnected, 33 /// Operation timed out. 34 Timeout, 35 /// Invalid message format. 36 InvalidMessage(String), 37 /// Payload exceeds maximum size. 38 PayloadTooLarge, 39 /// Server returned an error. 40 ServerError { 41 code: RelayErrorCode, 42 message: String, 43 }, 44 /// Not connected to a relay. 45 NotConnected, 46 /// Not registered with the relay. 47 NotRegistered, 48 /// Internal error. 49 Internal(String), 50 } 51 52 impl std::fmt::Display for RelayError { 53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 54 match self { 55 RelayError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg), 56 RelayError::Disconnected => write!(f, "Disconnected from relay"), 57 RelayError::Timeout => write!(f, "Operation timed out"), 58 RelayError::InvalidMessage(msg) => write!(f, "Invalid message: {}", msg), 59 RelayError::PayloadTooLarge => write!(f, "Payload exceeds maximum size"), 60 RelayError::ServerError { code, message } => { 61 write!(f, "Server error ({:?}): {}", code, message) 62 } 63 RelayError::NotConnected => write!(f, "Not connected to relay"), 64 RelayError::NotRegistered => write!(f, "Not registered with relay"), 65 RelayError::Internal(msg) => write!(f, "Internal error: {}", msg), 66 } 67 } 68 } 69 70 impl std::error::Error for RelayError {} 71 72 /// Result type for relay operations. 73 pub type RelayResult<T> = Result<T, RelayError>; 74 75 /// Client for communicating with a relay server. 76 /// 77 /// The client maintains a WebSocket connection and handles the relay protocol. 78 pub struct RelayClient { 79 /// Relay server URL. 80 url: String, 81 /// WebSocket connection (if connected). 82 connection: Mutex<Option<WebSocketStream<MaybeTlsStream<TcpStream>>>>, 83 /// Assigned mailbox ID (if registered). 84 mailbox_id: Mutex<Option<[u8; 16]>>, 85 /// Our public key (for registration). 86 public_key: [u8; 32], 87 /// Connection timeout. 88 connect_timeout: Duration, 89 /// Operation timeout. 90 operation_timeout: Duration, 91 } 92 93 impl RelayClient { 94 /// Create a new relay client. 95 /// 96 /// # Arguments 97 /// 98 /// * `url` - WebSocket URL of the relay server (e.g., "wss://relay.example.com") 99 /// * `public_key` - Our X25519 public key for registration 100 pub fn new(url: &str, public_key: [u8; 32]) -> Self { 101 Self { 102 url: url.to_string(), 103 connection: Mutex::new(None), 104 mailbox_id: Mutex::new(None), 105 public_key, 106 connect_timeout: Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS), 107 operation_timeout: Duration::from_secs(DEFAULT_OPERATION_TIMEOUT_SECS), 108 } 109 } 110 111 /// Set the connection timeout. 112 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self { 113 self.connect_timeout = timeout; 114 self 115 } 116 117 /// Set the operation timeout. 118 pub fn with_operation_timeout(mut self, timeout: Duration) -> Self { 119 self.operation_timeout = timeout; 120 self 121 } 122 123 /// Connect to the relay server. 124 pub async fn connect(&self) -> RelayResult<()> { 125 let connect_future = connect_async(&self.url); 126 127 let (ws_stream, _response) = timeout(self.connect_timeout, connect_future) 128 .await 129 .map_err(|_| RelayError::Timeout)? 130 .map_err(|e| RelayError::ConnectionFailed(e.to_string()))?; 131 132 let mut conn = self.connection.lock().await; 133 *conn = Some(ws_stream); 134 135 Ok(()) 136 } 137 138 /// Disconnect from the relay server. 139 pub async fn disconnect(&self) -> RelayResult<()> { 140 let mut conn = self.connection.lock().await; 141 if let Some(mut ws) = conn.take() { 142 let _ = ws.close(None).await; 143 } 144 145 let mut mailbox = self.mailbox_id.lock().await; 146 *mailbox = None; 147 148 Ok(()) 149 } 150 151 /// Check if connected to the relay. 152 pub async fn is_connected(&self) -> bool { 153 self.connection.lock().await.is_some() 154 } 155 156 /// Register with the relay to create a mailbox. 157 /// 158 /// Returns the assigned mailbox ID. 159 pub async fn register(&self) -> RelayResult<[u8; 16]> { 160 let msg = RelayMessage::Register { 161 public_key: self.public_key, 162 }; 163 164 let response = self.send_and_receive(msg).await?; 165 166 match response { 167 RelayMessage::RegisterAck { mailbox_id } => { 168 let mut mailbox = self.mailbox_id.lock().await; 169 *mailbox = Some(mailbox_id); 170 Ok(mailbox_id) 171 } 172 RelayMessage::Error { code, message } => { 173 Err(RelayError::ServerError { code, message }) 174 } 175 _ => Err(RelayError::InvalidMessage( 176 "Expected RegisterAck response".to_string(), 177 )), 178 } 179 } 180 181 /// Send a message to a recipient via the relay. 182 /// 183 /// # Arguments 184 /// 185 /// * `recipient_key` - Recipient's X25519 public key 186 /// * `payload` - Encrypted message payload 187 /// 188 /// # Returns 189 /// 190 /// The message ID assigned by the relay. 191 pub async fn send(&self, recipient_key: [u8; 32], payload: &[u8]) -> RelayResult<[u8; 16]> { 192 if payload.len() > MAX_PAYLOAD_SIZE { 193 return Err(RelayError::PayloadTooLarge); 194 } 195 196 let msg = RelayMessage::Send { 197 recipient_key, 198 payload: payload.to_vec(), 199 }; 200 201 let response = self.send_and_receive(msg).await?; 202 203 match response { 204 RelayMessage::SendAck { message_id } => Ok(message_id), 205 RelayMessage::Error { code, message } => { 206 Err(RelayError::ServerError { code, message }) 207 } 208 _ => Err(RelayError::InvalidMessage( 209 "Expected SendAck response".to_string(), 210 )), 211 } 212 } 213 214 /// Fetch messages from the mailbox. 215 /// 216 /// # Arguments 217 /// 218 /// * `since` - Unix timestamp - fetch messages newer than this (0 for all) 219 /// 220 /// # Returns 221 /// 222 /// List of stored messages. 223 pub async fn fetch(&self, since: u64) -> RelayResult<Vec<StoredMessage>> { 224 if self.mailbox_id.lock().await.is_none() { 225 return Err(RelayError::NotRegistered); 226 } 227 228 let msg = RelayMessage::Fetch { since }; 229 230 let response = self.send_and_receive(msg).await?; 231 232 match response { 233 RelayMessage::Messages { messages } => Ok(messages), 234 RelayMessage::Error { code, message } => { 235 Err(RelayError::ServerError { code, message }) 236 } 237 _ => Err(RelayError::InvalidMessage( 238 "Expected Messages response".to_string(), 239 )), 240 } 241 } 242 243 /// Acknowledge receipt of messages. 244 /// 245 /// This allows the relay to delete the acknowledged messages. 246 pub async fn acknowledge(&self, message_ids: Vec<[u8; 16]>) -> RelayResult<()> { 247 if message_ids.is_empty() { 248 return Ok(()); 249 } 250 251 let msg = RelayMessage::Ack { message_ids }; 252 253 // For ack, we don't expect a response 254 self.send_message(msg).await 255 } 256 257 /// Send a ping to keep the connection alive. 258 pub async fn ping(&self) -> RelayResult<()> { 259 let msg = RelayMessage::Ping; 260 let response = self.send_and_receive(msg).await?; 261 262 match response { 263 RelayMessage::Pong => Ok(()), 264 _ => Err(RelayError::InvalidMessage( 265 "Expected Pong response".to_string(), 266 )), 267 } 268 } 269 270 /// Get the assigned mailbox ID (if registered). 271 pub async fn mailbox_id(&self) -> Option<[u8; 16]> { 272 *self.mailbox_id.lock().await 273 } 274 275 /// Get the relay URL. 276 pub fn url(&self) -> &str { 277 &self.url 278 } 279 280 // ========================================================================= 281 // GOSSIP PROTOCOL 282 // ========================================================================= 283 284 /// Send a gossip digest to the peer and receive their digest. 285 /// 286 /// Returns the peer's bloom filter and message count. 287 pub async fn send_gossip_digest( 288 &self, 289 bloom: Vec<u8>, 290 message_count: u32, 291 gossip_version: u8, 292 ) -> RelayResult<(Vec<u8>, u32)> { 293 let msg = RelayMessage::GossipDigest { 294 bloom, 295 message_count, 296 gossip_version, 297 }; 298 299 let response = self.send_and_receive(msg).await?; 300 301 match response { 302 RelayMessage::GossipDigest { 303 bloom, 304 message_count, 305 gossip_version: _, 306 } => Ok((bloom, message_count)), 307 RelayMessage::Error { code, message } => Err(RelayError::ServerError { code, message }), 308 _ => Err(RelayError::InvalidMessage( 309 "Expected GossipDigest response".to_string(), 310 )), 311 } 312 } 313 314 /// Send a gossip request for specific recipients. 315 /// 316 /// Returns forwarded messages matching the request. 317 pub async fn send_gossip_request( 318 &self, 319 recipient_hashes: Vec<[u8; 32]>, 320 limit: u32, 321 ) -> RelayResult<Vec<GossipForwardedMessage>> { 322 let msg = RelayMessage::GossipRequest { 323 recipient_hashes, 324 limit, 325 }; 326 327 let response = self.send_and_receive(msg).await?; 328 329 match response { 330 RelayMessage::GossipResponse { messages } => Ok(messages), 331 RelayMessage::Error { code, message } => Err(RelayError::ServerError { code, message }), 332 _ => Err(RelayError::InvalidMessage( 333 "Expected GossipResponse".to_string(), 334 )), 335 } 336 } 337 338 /// Send a gossip response with forwarded messages. 339 pub async fn send_gossip_response( 340 &self, 341 messages: Vec<GossipForwardedMessage>, 342 ) -> RelayResult<()> { 343 let msg = RelayMessage::GossipResponse { messages }; 344 self.send_message(msg).await 345 } 346 347 /// Send a delivery confirmation for a successfully delivered message. 348 pub async fn send_delivery_confirmation( 349 &self, 350 message_hash: [u8; 16], 351 hops_remaining: u8, 352 ) -> RelayResult<()> { 353 let msg = RelayMessage::DeliveryConfirmation { 354 message_hash, 355 hops_remaining, 356 }; 357 self.send_message(msg).await 358 } 359 360 /// Receive the next message from the relay (non-blocking wait). 361 /// 362 /// Used for receiving gossip requests after sending a digest. 363 pub async fn receive_message(&self) -> RelayResult<RelayMessage> { 364 let mut conn_guard = self.connection.lock().await; 365 let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?; 366 367 let recv_future = conn.next(); 368 let response = timeout(self.operation_timeout, recv_future) 369 .await 370 .map_err(|_| RelayError::Timeout)? 371 .ok_or(RelayError::Disconnected)? 372 .map_err(|e| RelayError::Internal(format!("Receive error: {}", e)))?; 373 374 match response { 375 Message::Binary(data) => RelayMessage::from_bytes(&data) 376 .map_err(|e| RelayError::InvalidMessage(format!("Deserialization error: {}", e))), 377 Message::Close(_) => Err(RelayError::Disconnected), 378 _ => Err(RelayError::InvalidMessage( 379 "Expected binary message".to_string(), 380 )), 381 } 382 } 383 384 /// Send a message and wait for a response. 385 async fn send_and_receive(&self, msg: RelayMessage) -> RelayResult<RelayMessage> { 386 let mut conn_guard = self.connection.lock().await; 387 let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?; 388 389 // Serialize and send 390 let data = msg 391 .to_bytes() 392 .map_err(|e| RelayError::Internal(format!("Serialization error: {}", e)))?; 393 394 let send_future = conn.send(Message::Binary(data)); 395 timeout(self.operation_timeout, send_future) 396 .await 397 .map_err(|_| RelayError::Timeout)? 398 .map_err(|e| RelayError::Internal(format!("Send error: {}", e)))?; 399 400 // Receive response 401 let recv_future = conn.next(); 402 let response = timeout(self.operation_timeout, recv_future) 403 .await 404 .map_err(|_| RelayError::Timeout)? 405 .ok_or(RelayError::Disconnected)? 406 .map_err(|e| RelayError::Internal(format!("Receive error: {}", e)))?; 407 408 // Parse response 409 match response { 410 Message::Binary(data) => RelayMessage::from_bytes(&data) 411 .map_err(|e| RelayError::InvalidMessage(format!("Deserialization error: {}", e))), 412 Message::Close(_) => Err(RelayError::Disconnected), 413 _ => Err(RelayError::InvalidMessage( 414 "Expected binary message".to_string(), 415 )), 416 } 417 } 418 419 /// Send a message without waiting for response. 420 async fn send_message(&self, msg: RelayMessage) -> RelayResult<()> { 421 let mut conn_guard = self.connection.lock().await; 422 let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?; 423 424 let data = msg 425 .to_bytes() 426 .map_err(|e| RelayError::Internal(format!("Serialization error: {}", e)))?; 427 428 let send_future = conn.send(Message::Binary(data)); 429 timeout(self.operation_timeout, send_future) 430 .await 431 .map_err(|_| RelayError::Timeout)? 432 .map_err(|e| RelayError::Internal(format!("Send error: {}", e)))?; 433 434 Ok(()) 435 } 436 } 437 438 impl std::fmt::Debug for RelayClient { 439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 440 f.debug_struct("RelayClient") 441 .field("url", &self.url) 442 .field("public_key", &hex::encode(&self.public_key)) 443 .finish() 444 } 445 } 446 447 #[cfg(test)] 448 mod tests { 449 use super::*; 450 451 #[test] 452 fn test_relay_client_new() { 453 let client = RelayClient::new("wss://relay.example.com", [0xAB; 32]); 454 assert_eq!(client.url(), "wss://relay.example.com"); 455 } 456 457 #[test] 458 fn test_relay_error_display() { 459 let err = RelayError::ConnectionFailed("Connection refused".to_string()); 460 assert!(err.to_string().contains("Connection refused")); 461 462 let err = RelayError::ServerError { 463 code: RelayErrorCode::RateLimited, 464 message: "Too many requests".to_string(), 465 }; 466 assert!(err.to_string().contains("RateLimited")); 467 } 468 469 #[tokio::test] 470 async fn test_relay_client_not_connected() { 471 let client = RelayClient::new("wss://relay.example.com", [0xAB; 32]); 472 473 assert!(!client.is_connected().await); 474 assert!(client.mailbox_id().await.is_none()); 475 476 // Operations should fail when not connected 477 let result = client.register().await; 478 assert!(matches!(result, Err(RelayError::NotConnected))); 479 480 let result = client.send([0xCD; 32], b"test").await; 481 assert!(matches!(result, Err(RelayError::NotConnected))); 482 } 483 484 #[tokio::test] 485 async fn test_relay_client_fetch_not_registered() { 486 let client = RelayClient::new("wss://relay.example.com", [0xAB; 32]); 487 488 // Fetch should fail when not registered 489 let result = client.fetch(0).await; 490 assert!(matches!(result, Err(RelayError::NotConnected | RelayError::NotRegistered))); 491 } 492 493 #[test] 494 fn test_relay_client_payload_too_large() { 495 // Test that the constant is reasonable 496 assert!(MAX_PAYLOAD_SIZE >= 1024 * 1024); // At least 1 MB 497 } 498 }