/ core / src / relay / client.rs
client.rs
  1  //! Relay server client.
  2  //!
  3  //! This module provides a WebSocket-based client for communicating with
  4  //! relay servers.
  5  
  6  use std::time::Duration;
  7  
  8  use futures_util::{SinkExt, StreamExt};
  9  use tokio::net::TcpStream;
 10  use tokio::sync::Mutex;
 11  use tokio::time::timeout;
 12  use tokio_tungstenite::{
 13      connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
 14  };
 15  
 16  use super::protocol::{
 17      GossipForwardedMessage, RelayErrorCode, RelayMessage, StoredMessage, MAX_PAYLOAD_SIZE,
 18  };
 19  
 20  /// Default connection timeout in seconds.
 21  const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30;
 22  
 23  /// Default operation timeout in seconds.
 24  const DEFAULT_OPERATION_TIMEOUT_SECS: u64 = 60;
 25  
 26  /// Errors that can occur during relay operations.
 27  #[derive(Debug, Clone, PartialEq, Eq)]
 28  pub enum RelayError {
 29      /// Failed to connect to the relay server.
 30      ConnectionFailed(String),
 31      /// Connection was closed unexpectedly.
 32      Disconnected,
 33      /// Operation timed out.
 34      Timeout,
 35      /// Invalid message format.
 36      InvalidMessage(String),
 37      /// Payload exceeds maximum size.
 38      PayloadTooLarge,
 39      /// Server returned an error.
 40      ServerError {
 41          code: RelayErrorCode,
 42          message: String,
 43      },
 44      /// Not connected to a relay.
 45      NotConnected,
 46      /// Not registered with the relay.
 47      NotRegistered,
 48      /// Internal error.
 49      Internal(String),
 50  }
 51  
 52  impl std::fmt::Display for RelayError {
 53      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 54          match self {
 55              RelayError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
 56              RelayError::Disconnected => write!(f, "Disconnected from relay"),
 57              RelayError::Timeout => write!(f, "Operation timed out"),
 58              RelayError::InvalidMessage(msg) => write!(f, "Invalid message: {}", msg),
 59              RelayError::PayloadTooLarge => write!(f, "Payload exceeds maximum size"),
 60              RelayError::ServerError { code, message } => {
 61                  write!(f, "Server error ({:?}): {}", code, message)
 62              }
 63              RelayError::NotConnected => write!(f, "Not connected to relay"),
 64              RelayError::NotRegistered => write!(f, "Not registered with relay"),
 65              RelayError::Internal(msg) => write!(f, "Internal error: {}", msg),
 66          }
 67      }
 68  }
 69  
 70  impl std::error::Error for RelayError {}
 71  
 72  /// Result type for relay operations.
 73  pub type RelayResult<T> = Result<T, RelayError>;
 74  
 75  /// Client for communicating with a relay server.
 76  ///
 77  /// The client maintains a WebSocket connection and handles the relay protocol.
 78  pub struct RelayClient {
 79      /// Relay server URL.
 80      url: String,
 81      /// WebSocket connection (if connected).
 82      connection: Mutex<Option<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
 83      /// Assigned mailbox ID (if registered).
 84      mailbox_id: Mutex<Option<[u8; 16]>>,
 85      /// Our public key (for registration).
 86      public_key: [u8; 32],
 87      /// Connection timeout.
 88      connect_timeout: Duration,
 89      /// Operation timeout.
 90      operation_timeout: Duration,
 91  }
 92  
 93  impl RelayClient {
 94      /// Create a new relay client.
 95      ///
 96      /// # Arguments
 97      ///
 98      /// * `url` - WebSocket URL of the relay server (e.g., "wss://relay.example.com")
 99      /// * `public_key` - Our X25519 public key for registration
100      pub fn new(url: &str, public_key: [u8; 32]) -> Self {
101          Self {
102              url: url.to_string(),
103              connection: Mutex::new(None),
104              mailbox_id: Mutex::new(None),
105              public_key,
106              connect_timeout: Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS),
107              operation_timeout: Duration::from_secs(DEFAULT_OPERATION_TIMEOUT_SECS),
108          }
109      }
110  
111      /// Set the connection timeout.
112      pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
113          self.connect_timeout = timeout;
114          self
115      }
116  
117      /// Set the operation timeout.
118      pub fn with_operation_timeout(mut self, timeout: Duration) -> Self {
119          self.operation_timeout = timeout;
120          self
121      }
122  
123      /// Connect to the relay server.
124      pub async fn connect(&self) -> RelayResult<()> {
125          let connect_future = connect_async(&self.url);
126  
127          let (ws_stream, _response) = timeout(self.connect_timeout, connect_future)
128              .await
129              .map_err(|_| RelayError::Timeout)?
130              .map_err(|e| RelayError::ConnectionFailed(e.to_string()))?;
131  
132          let mut conn = self.connection.lock().await;
133          *conn = Some(ws_stream);
134  
135          Ok(())
136      }
137  
138      /// Disconnect from the relay server.
139      pub async fn disconnect(&self) -> RelayResult<()> {
140          let mut conn = self.connection.lock().await;
141          if let Some(mut ws) = conn.take() {
142              let _ = ws.close(None).await;
143          }
144  
145          let mut mailbox = self.mailbox_id.lock().await;
146          *mailbox = None;
147  
148          Ok(())
149      }
150  
151      /// Check if connected to the relay.
152      pub async fn is_connected(&self) -> bool {
153          self.connection.lock().await.is_some()
154      }
155  
156      /// Register with the relay to create a mailbox.
157      ///
158      /// Returns the assigned mailbox ID.
159      pub async fn register(&self) -> RelayResult<[u8; 16]> {
160          let msg = RelayMessage::Register {
161              public_key: self.public_key,
162          };
163  
164          let response = self.send_and_receive(msg).await?;
165  
166          match response {
167              RelayMessage::RegisterAck { mailbox_id } => {
168                  let mut mailbox = self.mailbox_id.lock().await;
169                  *mailbox = Some(mailbox_id);
170                  Ok(mailbox_id)
171              }
172              RelayMessage::Error { code, message } => {
173                  Err(RelayError::ServerError { code, message })
174              }
175              _ => Err(RelayError::InvalidMessage(
176                  "Expected RegisterAck response".to_string(),
177              )),
178          }
179      }
180  
181      /// Send a message to a recipient via the relay.
182      ///
183      /// # Arguments
184      ///
185      /// * `recipient_key` - Recipient's X25519 public key
186      /// * `payload` - Encrypted message payload
187      ///
188      /// # Returns
189      ///
190      /// The message ID assigned by the relay.
191      pub async fn send(&self, recipient_key: [u8; 32], payload: &[u8]) -> RelayResult<[u8; 16]> {
192          if payload.len() > MAX_PAYLOAD_SIZE {
193              return Err(RelayError::PayloadTooLarge);
194          }
195  
196          let msg = RelayMessage::Send {
197              recipient_key,
198              payload: payload.to_vec(),
199          };
200  
201          let response = self.send_and_receive(msg).await?;
202  
203          match response {
204              RelayMessage::SendAck { message_id } => Ok(message_id),
205              RelayMessage::Error { code, message } => {
206                  Err(RelayError::ServerError { code, message })
207              }
208              _ => Err(RelayError::InvalidMessage(
209                  "Expected SendAck response".to_string(),
210              )),
211          }
212      }
213  
214      /// Fetch messages from the mailbox.
215      ///
216      /// # Arguments
217      ///
218      /// * `since` - Unix timestamp - fetch messages newer than this (0 for all)
219      ///
220      /// # Returns
221      ///
222      /// List of stored messages.
223      pub async fn fetch(&self, since: u64) -> RelayResult<Vec<StoredMessage>> {
224          if self.mailbox_id.lock().await.is_none() {
225              return Err(RelayError::NotRegistered);
226          }
227  
228          let msg = RelayMessage::Fetch { since };
229  
230          let response = self.send_and_receive(msg).await?;
231  
232          match response {
233              RelayMessage::Messages { messages } => Ok(messages),
234              RelayMessage::Error { code, message } => {
235                  Err(RelayError::ServerError { code, message })
236              }
237              _ => Err(RelayError::InvalidMessage(
238                  "Expected Messages response".to_string(),
239              )),
240          }
241      }
242  
243      /// Acknowledge receipt of messages.
244      ///
245      /// This allows the relay to delete the acknowledged messages.
246      pub async fn acknowledge(&self, message_ids: Vec<[u8; 16]>) -> RelayResult<()> {
247          if message_ids.is_empty() {
248              return Ok(());
249          }
250  
251          let msg = RelayMessage::Ack { message_ids };
252  
253          // For ack, we don't expect a response
254          self.send_message(msg).await
255      }
256  
257      /// Send a ping to keep the connection alive.
258      pub async fn ping(&self) -> RelayResult<()> {
259          let msg = RelayMessage::Ping;
260          let response = self.send_and_receive(msg).await?;
261  
262          match response {
263              RelayMessage::Pong => Ok(()),
264              _ => Err(RelayError::InvalidMessage(
265                  "Expected Pong response".to_string(),
266              )),
267          }
268      }
269  
270      /// Get the assigned mailbox ID (if registered).
271      pub async fn mailbox_id(&self) -> Option<[u8; 16]> {
272          *self.mailbox_id.lock().await
273      }
274  
275      /// Get the relay URL.
276      pub fn url(&self) -> &str {
277          &self.url
278      }
279  
280      // =========================================================================
281      // GOSSIP PROTOCOL
282      // =========================================================================
283  
284      /// Send a gossip digest to the peer and receive their digest.
285      ///
286      /// Returns the peer's bloom filter and message count.
287      pub async fn send_gossip_digest(
288          &self,
289          bloom: Vec<u8>,
290          message_count: u32,
291          gossip_version: u8,
292      ) -> RelayResult<(Vec<u8>, u32)> {
293          let msg = RelayMessage::GossipDigest {
294              bloom,
295              message_count,
296              gossip_version,
297          };
298  
299          let response = self.send_and_receive(msg).await?;
300  
301          match response {
302              RelayMessage::GossipDigest {
303                  bloom,
304                  message_count,
305                  gossip_version: _,
306              } => Ok((bloom, message_count)),
307              RelayMessage::Error { code, message } => Err(RelayError::ServerError { code, message }),
308              _ => Err(RelayError::InvalidMessage(
309                  "Expected GossipDigest response".to_string(),
310              )),
311          }
312      }
313  
314      /// Send a gossip request for specific recipients.
315      ///
316      /// Returns forwarded messages matching the request.
317      pub async fn send_gossip_request(
318          &self,
319          recipient_hashes: Vec<[u8; 32]>,
320          limit: u32,
321      ) -> RelayResult<Vec<GossipForwardedMessage>> {
322          let msg = RelayMessage::GossipRequest {
323              recipient_hashes,
324              limit,
325          };
326  
327          let response = self.send_and_receive(msg).await?;
328  
329          match response {
330              RelayMessage::GossipResponse { messages } => Ok(messages),
331              RelayMessage::Error { code, message } => Err(RelayError::ServerError { code, message }),
332              _ => Err(RelayError::InvalidMessage(
333                  "Expected GossipResponse".to_string(),
334              )),
335          }
336      }
337  
338      /// Send a gossip response with forwarded messages.
339      pub async fn send_gossip_response(
340          &self,
341          messages: Vec<GossipForwardedMessage>,
342      ) -> RelayResult<()> {
343          let msg = RelayMessage::GossipResponse { messages };
344          self.send_message(msg).await
345      }
346  
347      /// Send a delivery confirmation for a successfully delivered message.
348      pub async fn send_delivery_confirmation(
349          &self,
350          message_hash: [u8; 16],
351          hops_remaining: u8,
352      ) -> RelayResult<()> {
353          let msg = RelayMessage::DeliveryConfirmation {
354              message_hash,
355              hops_remaining,
356          };
357          self.send_message(msg).await
358      }
359  
360      /// Receive the next message from the relay (non-blocking wait).
361      ///
362      /// Used for receiving gossip requests after sending a digest.
363      pub async fn receive_message(&self) -> RelayResult<RelayMessage> {
364          let mut conn_guard = self.connection.lock().await;
365          let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?;
366  
367          let recv_future = conn.next();
368          let response = timeout(self.operation_timeout, recv_future)
369              .await
370              .map_err(|_| RelayError::Timeout)?
371              .ok_or(RelayError::Disconnected)?
372              .map_err(|e| RelayError::Internal(format!("Receive error: {}", e)))?;
373  
374          match response {
375              Message::Binary(data) => RelayMessage::from_bytes(&data)
376                  .map_err(|e| RelayError::InvalidMessage(format!("Deserialization error: {}", e))),
377              Message::Close(_) => Err(RelayError::Disconnected),
378              _ => Err(RelayError::InvalidMessage(
379                  "Expected binary message".to_string(),
380              )),
381          }
382      }
383  
384      /// Send a message and wait for a response.
385      async fn send_and_receive(&self, msg: RelayMessage) -> RelayResult<RelayMessage> {
386          let mut conn_guard = self.connection.lock().await;
387          let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?;
388  
389          // Serialize and send
390          let data = msg
391              .to_bytes()
392              .map_err(|e| RelayError::Internal(format!("Serialization error: {}", e)))?;
393  
394          let send_future = conn.send(Message::Binary(data));
395          timeout(self.operation_timeout, send_future)
396              .await
397              .map_err(|_| RelayError::Timeout)?
398              .map_err(|e| RelayError::Internal(format!("Send error: {}", e)))?;
399  
400          // Receive response
401          let recv_future = conn.next();
402          let response = timeout(self.operation_timeout, recv_future)
403              .await
404              .map_err(|_| RelayError::Timeout)?
405              .ok_or(RelayError::Disconnected)?
406              .map_err(|e| RelayError::Internal(format!("Receive error: {}", e)))?;
407  
408          // Parse response
409          match response {
410              Message::Binary(data) => RelayMessage::from_bytes(&data)
411                  .map_err(|e| RelayError::InvalidMessage(format!("Deserialization error: {}", e))),
412              Message::Close(_) => Err(RelayError::Disconnected),
413              _ => Err(RelayError::InvalidMessage(
414                  "Expected binary message".to_string(),
415              )),
416          }
417      }
418  
419      /// Send a message without waiting for response.
420      async fn send_message(&self, msg: RelayMessage) -> RelayResult<()> {
421          let mut conn_guard = self.connection.lock().await;
422          let conn = conn_guard.as_mut().ok_or(RelayError::NotConnected)?;
423  
424          let data = msg
425              .to_bytes()
426              .map_err(|e| RelayError::Internal(format!("Serialization error: {}", e)))?;
427  
428          let send_future = conn.send(Message::Binary(data));
429          timeout(self.operation_timeout, send_future)
430              .await
431              .map_err(|_| RelayError::Timeout)?
432              .map_err(|e| RelayError::Internal(format!("Send error: {}", e)))?;
433  
434          Ok(())
435      }
436  }
437  
438  impl std::fmt::Debug for RelayClient {
439      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440          f.debug_struct("RelayClient")
441              .field("url", &self.url)
442              .field("public_key", &hex::encode(&self.public_key))
443              .finish()
444      }
445  }
446  
#[cfg(test)]
mod tests {
    use super::*;

    const URL: &str = "wss://relay.example.com";

    #[test]
    fn test_relay_client_new() {
        let client = RelayClient::new(URL, [0xAB; 32]);
        assert_eq!(client.url(), URL);
        assert_eq!(
            client.connect_timeout,
            Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS)
        );
        assert_eq!(
            client.operation_timeout,
            Duration::from_secs(DEFAULT_OPERATION_TIMEOUT_SECS)
        );
    }

    #[test]
    fn test_relay_client_builders() {
        let client = RelayClient::new(URL, [0xAB; 32])
            .with_connect_timeout(Duration::from_secs(5))
            .with_operation_timeout(Duration::from_millis(250));
        assert_eq!(client.connect_timeout, Duration::from_secs(5));
        assert_eq!(client.operation_timeout, Duration::from_millis(250));
    }

    #[test]
    fn test_relay_error_display() {
        let err = RelayError::ConnectionFailed("Connection refused".to_string());
        assert!(err.to_string().contains("Connection refused"));

        let err = RelayError::ServerError {
            code: RelayErrorCode::RateLimited,
            message: "Too many requests".to_string(),
        };
        assert!(err.to_string().contains("RateLimited"));
    }

    #[tokio::test]
    async fn test_relay_client_not_connected() {
        let client = RelayClient::new(URL, [0xAB; 32]);

        assert!(!client.is_connected().await);
        assert!(client.mailbox_id().await.is_none());

        // Every operation that needs the socket must fail fast when not connected.
        let result = client.register().await;
        assert!(matches!(result, Err(RelayError::NotConnected)));

        let result = client.send([0xCD; 32], b"test").await;
        assert!(matches!(result, Err(RelayError::NotConnected)));

        let result = client.ping().await;
        assert!(matches!(result, Err(RelayError::NotConnected)));

        let result = client.acknowledge(vec![[0x01; 16]]).await;
        assert!(matches!(result, Err(RelayError::NotConnected)));

        let result = client.receive_message().await;
        assert!(matches!(result, Err(RelayError::NotConnected)));
    }

    #[tokio::test]
    async fn test_relay_client_fetch_not_registered() {
        let client = RelayClient::new(URL, [0xAB; 32]);

        // Registration is checked before the connection, so this is always NotRegistered.
        let result = client.fetch(0).await;
        assert!(matches!(result, Err(RelayError::NotRegistered)));
    }

    #[tokio::test]
    async fn test_relay_client_acknowledge_empty_is_noop() {
        let client = RelayClient::new(URL, [0xAB; 32]);
        assert_eq!(client.acknowledge(Vec::new()).await, Ok(()));
    }

    #[tokio::test]
    async fn test_relay_client_disconnect_when_idle() {
        let client = RelayClient::new(URL, [0xAB; 32]);
        assert_eq!(client.disconnect().await, Ok(()));
        assert!(!client.is_connected().await);
        assert!(client.mailbox_id().await.is_none());
    }

    #[tokio::test]
    async fn test_relay_client_payload_too_large() {
        let client = RelayClient::new(URL, [0xAB; 32]);

        // The size limit is enforced locally, before the connection is consulted.
        let oversized = vec![0u8; MAX_PAYLOAD_SIZE + 1];
        let result = client.send([0xCD; 32], &oversized).await;
        assert!(matches!(result, Err(RelayError::PayloadTooLarge)));
    }

    #[test]
    fn test_max_payload_size_is_reasonable() {
        assert!(MAX_PAYLOAD_SIZE >= 1024 * 1024); // At least 1 MB
    }
}