// protocol.rs
1 //! Protocol messages for mesh inference. 2 //! 3 //! Defines the request/response types that flow between coordinator and workers. 4 5 use serde::{Deserialize, Serialize}; 6 use crate::shard::TokenShard; 7 8 /// Request for distributed inference. 9 #[derive(Debug, Clone, Serialize, Deserialize)] 10 pub struct InferenceRequest { 11 /// Unique request identifier 12 pub request_id: [u8; 32], 13 14 /// Model identifier (e.g., "llama-7b", "mistral-7b") 15 pub model_id: String, 16 17 /// The sharded prompt data 18 pub shard: TokenShard, 19 20 /// Inference parameters 21 pub params: InferenceParams, 22 23 /// Requesting node's callback address (for async response) 24 pub callback: Option<String>, 25 26 /// Request timestamp (for timeout tracking) 27 pub timestamp_ms: u64, 28 29 /// Transport priority tier (0 = No Jitter/Real-time, 1 = Reliable/Bulk) 30 #[serde(default)] 31 pub transport_tier: u8, 32 } 33 34 /// Parameters for inference generation. 35 #[derive(Debug, Clone, Serialize, Deserialize)] 36 pub struct InferenceParams { 37 /// Maximum tokens to generate 38 pub max_tokens: usize, 39 40 /// Temperature (0.0 = deterministic, 1.0 = creative) 41 pub temperature: f32, 42 43 /// Top-p nucleus sampling 44 pub top_p: f32, 45 46 /// Stop sequences 47 pub stop_sequences: Vec<String>, 48 49 /// Whether to stream partial results 50 pub stream: bool, 51 } 52 53 impl Default for InferenceParams { 54 fn default() -> Self { 55 Self { 56 max_tokens: 256, 57 temperature: 0.7, 58 top_p: 0.9, 59 stop_sequences: vec![], 60 stream: false, 61 } 62 } 63 } 64 65 impl InferenceParams { 66 /// Create deterministic params (temperature = 0) 67 pub fn deterministic() -> Self { 68 Self { 69 temperature: 0.0, 70 top_p: 1.0, 71 ..Default::default() 72 } 73 } 74 75 /// Set max tokens 76 pub fn with_max_tokens(mut self, n: usize) -> Self { 77 self.max_tokens = n; 78 self 79 } 80 } 81 82 /// Partial result from a single shard worker. 
83 #[derive(Debug, Clone, Serialize, Deserialize)] 84 pub struct PartialResult { 85 /// Request ID this result belongs to 86 pub request_id: [u8; 32], 87 88 /// Shard index that produced this result 89 pub shard_index: usize, 90 91 /// Total shards in the request 92 pub shard_total: usize, 93 94 /// The partial computation (sharded) 95 pub data: Vec<u8>, 96 97 /// Commitment for verification 98 pub commitment: [u8; 32], 99 100 /// Processing time in milliseconds 101 pub processing_time_ms: u64, 102 103 /// Any error that occurred 104 pub error: Option<String>, 105 106 /// Cryptographic Proof of Inference 107 pub proof: Option<crate::proof_of_inference::InferenceProof>, 108 } 109 110 impl PartialResult { 111 /// Create an error result 112 pub fn error(request_id: [u8; 32], shard_index: usize, total: usize, msg: String) -> Self { 113 Self { 114 request_id, 115 shard_index, 116 shard_total: total, 117 data: vec![], 118 commitment: [0u8; 32], 119 processing_time_ms: 0, 120 error: Some(msg), 121 proof: None, 122 } 123 } 124 125 /// Check if this result indicates an error 126 pub fn is_error(&self) -> bool { 127 self.error.is_some() 128 } 129 } 130 131 /// Final aggregated inference response. 132 #[derive(Debug, Clone, Serialize, Deserialize)] 133 pub struct InferenceResponse { 134 /// Request ID 135 pub request_id: [u8; 32], 136 137 /// Generated text (reconstructed from shards) 138 pub text: String, 139 140 /// Tokens generated 141 pub tokens_generated: usize, 142 143 /// Total processing time across all shards 144 pub total_time_ms: u64, 145 146 /// Number of shards that participated 147 pub shards_used: usize, 148 149 /// Model that was used 150 pub model_id: String, 151 152 /// Finish reason 153 pub finish_reason: FinishReason, 154 155 /// Cryptographic proofs for payment 156 #[serde(default)] 157 pub proofs: Vec<crate::proof_of_inference::InferenceProof>, 158 } 159 160 /// Why inference completed. 
161 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] 162 pub enum FinishReason { 163 /// Hit max_tokens limit 164 Length, 165 /// Hit a stop sequence 166 Stop, 167 /// Model completed naturally 168 EndOfText, 169 /// Error occurred 170 Error, 171 } 172 173 #[cfg(test)] 174 mod tests { 175 use super::*; 176 177 #[test] 178 fn test_inference_params_default() { 179 let params = InferenceParams::default(); 180 assert_eq!(params.max_tokens, 256); 181 assert!((params.temperature - 0.7).abs() < 0.01); 182 } 183 184 #[test] 185 fn test_deterministic_params() { 186 let params = InferenceParams::deterministic(); 187 assert!((params.temperature - 0.0).abs() < 0.001); 188 } 189 190 #[test] 191 fn test_partial_result_error() { 192 let result = PartialResult::error( 193 [0u8; 32], 194 0, 195 3, 196 "worker crashed".into(), 197 ); 198 assert!(result.is_error()); 199 } 200 201 #[test] 202 fn test_serialization() { 203 let params = InferenceParams::default().with_max_tokens(100); 204 let json = serde_json::to_string(¶ms).unwrap(); 205 let parsed: InferenceParams = serde_json::from_str(&json).unwrap(); 206 assert_eq!(parsed.max_tokens, 100); 207 } 208 }