// abzu-inference/src/protocol.rs
//! Protocol messages for mesh inference.
//!
//! Defines the request/response types that flow between coordinator and workers.

use crate::shard::TokenShard;
use serde::{Deserialize, Serialize};

/// Request for distributed inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Unique request identifier
    pub request_id: [u8; 32],

    /// Model identifier (e.g., "llama-7b", "mistral-7b")
    pub model_id: String,

    /// The sharded prompt data
    pub shard: TokenShard,

    /// Inference parameters
    pub params: InferenceParams,

    /// Requesting node's callback address (for async response)
    pub callback: Option<String>,

    /// Request timestamp (for timeout tracking)
    pub timestamp_ms: u64,

    /// Transport priority tier (0 = No Jitter/Real-time, 1 = Reliable/Bulk)
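    ///
    /// Defaults to 0 when the field is absent (via `#[serde(default)]`), so
    /// requests that omit it still deserialize.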
    #[serde(default)]
    pub transport_tier: u8,
}

/// Parameters for inference generation.
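///
/// The builder-style helpers below chain; a minimal doc example (the path
/// assumes `protocol` is a public module of the `abzu_inference` crate):
///
/// ```
/// use abzu_inference::protocol::InferenceParams;
///
/// let params = InferenceParams::deterministic().with_max_tokens(64);
/// assert_eq!(params.max_tokens, 64);
/// ```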
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceParams {
    /// Maximum tokens to generate
    pub max_tokens: usize,

    /// Sampling temperature (0.0 = greedy/deterministic; higher values produce more varied output)
    pub temperature: f32,

    /// Top-p (nucleus) sampling threshold
    pub top_p: f32,

    /// Stop sequences
    pub stop_sequences: Vec<String>,

    /// Whether to stream partial results
    pub stream: bool,
}

impl Default for InferenceParams {
    fn default() -> Self {
        Self {
            max_tokens: 256,
            temperature: 0.7,
            top_p: 0.9,
            stop_sequences: vec![],
            stream: false,
        }
    }
}

impl InferenceParams {
    /// Create deterministic params (temperature = 0.0, top_p = 1.0)
    pub fn deterministic() -> Self {
        Self {
            temperature: 0.0,
            top_p: 1.0,
            ..Default::default()
        }
    }

    /// Set max tokens
    pub fn with_max_tokens(mut self, n: usize) -> Self {
        self.max_tokens = n;
        self
    }
}

/// Partial result from a single shard worker.
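///
/// A minimal doc example using the error constructor (the path assumes
/// `protocol` is a public module of the `abzu_inference` crate):
///
/// ```
/// use abzu_inference::protocol::PartialResult;
///
/// let failed = PartialResult::error([0u8; 32], 0, 3, "worker crashed".into());
/// assert!(failed.is_error());
/// ```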
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialResult {
    /// Request ID this result belongs to
    pub request_id: [u8; 32],

    /// Shard index that produced this result
    pub shard_index: usize,

    /// Total shards in the request
    pub shard_total: usize,

    /// The partial computation (sharded)
    pub data: Vec<u8>,

    /// Commitment for verification
    pub commitment: [u8; 32],

    /// Processing time in milliseconds
    pub processing_time_ms: u64,

    /// Any error that occurred
    pub error: Option<String>,

    /// Cryptographic proof of inference, if the worker produced one
    pub proof: Option<crate::proof_of_inference::InferenceProof>,
}

impl PartialResult {
    /// Create an error result
    pub fn error(request_id: [u8; 32], shard_index: usize, total: usize, msg: String) -> Self {
        Self {
            request_id,
            shard_index,
            shard_total: total,
            data: vec![],
            commitment: [0u8; 32],
            processing_time_ms: 0,
            error: Some(msg),
            proof: None,
        }
    }

    /// Check if this result indicates an error
    pub fn is_error(&self) -> bool {
        self.error.is_some()
    }
}

/// Final aggregated inference response.
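///
/// A construction sketch with illustrative field values (again assuming the
/// `abzu_inference::protocol` path):
///
/// ```
/// use abzu_inference::protocol::{FinishReason, InferenceResponse};
///
/// let resp = InferenceResponse {
///     request_id: [0u8; 32],
///     text: "hello".into(),
///     tokens_generated: 1,
///     total_time_ms: 12,
///     shards_used: 3,
///     model_id: "llama-7b".into(),
///     finish_reason: FinishReason::EndOfText,
///     proofs: vec![],
/// };
/// assert_eq!(resp.finish_reason, FinishReason::EndOfText);
/// ```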
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Request ID
    pub request_id: [u8; 32],

    /// Generated text (reconstructed from shards)
    pub text: String,

    /// Tokens generated
    pub tokens_generated: usize,

    /// Total processing time across all shards
    pub total_time_ms: u64,

    /// Number of shards that participated
    pub shards_used: usize,

    /// Model that was used
    pub model_id: String,

    /// Finish reason
    pub finish_reason: FinishReason,

    /// Cryptographic proofs for payment
    #[serde(default)]
    pub proofs: Vec<crate::proof_of_inference::InferenceProof>,
}

/// Why inference completed.
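///
/// With serde's default externally tagged representation, these unit variants
/// serialize as bare JSON strings (doc example assumes `serde_json` is
/// available, as in the tests below):
///
/// ```
/// use abzu_inference::protocol::FinishReason;
///
/// let json = serde_json::to_string(&FinishReason::Length).unwrap();
/// assert_eq!(json, "\"Length\"");
/// ```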
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum FinishReason {
    /// Hit max_tokens limit
    Length,
    /// Hit a stop sequence
    Stop,
    /// Model completed naturally
    EndOfText,
    /// Error occurred
    Error,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_params_default() {
        let params = InferenceParams::default();
        assert_eq!(params.max_tokens, 256);
        assert!((params.temperature - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_deterministic_params() {
        let params = InferenceParams::deterministic();
        assert!((params.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_partial_result_error() {
        let result = PartialResult::error(
            [0u8; 32],
            0,
            3,
            "worker crashed".into(),
        );
        assert!(result.is_error());
    }
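
    // Added check (sketch): literal construction of a successful result,
    // exercising the non-error path of `is_error`.
    #[test]
    fn test_partial_result_success() {
        let result = PartialResult {
            request_id: [1u8; 32],
            shard_index: 2,
            shard_total: 3,
            data: vec![0xAB; 16],
            commitment: [0u8; 32],
            processing_time_ms: 42,
            error: None,
            proof: None,
        };
        assert!(!result.is_error());
    }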

    #[test]
    fn test_serialization() {
        let params = InferenceParams::default().with_max_tokens(100);
        let json = serde_json::to_string(&params).unwrap();
        let parsed: InferenceParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_tokens, 100);
    }
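
    // Added check (sketch): FinishReason round-trips through serde's default
    // externally tagged representation.
    #[test]
    fn test_finish_reason_roundtrip() {
        let json = serde_json::to_string(&FinishReason::Stop).unwrap();
        let parsed: FinishReason = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, FinishReason::Stop);
    }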
}