// abzu-inference/src/service.rs
use std::sync::Arc;

use abzu_core::node::{Node, ServiceEvent};
use abzu_transport::AbzuFrame;
use tracing::{error, info, warn};

use crate::proof_of_inference::InferenceProof;
use crate::protocol::{InferenceRequest, PartialResult};
use crate::shard::TokenShard;

/// Service that runs on a worker node to accept and process inference requests.
pub struct InferenceService {
    node: Arc<Node>,
    /// Key used to sign inference proofs with this node's identity.
    identity_key: ed25519_dalek::SigningKey,
    /// Optional destination address for payments.
    payment_address: Option<String>,
}

impl InferenceService {
    /// Create a new inference service.
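    ///
    /// # Example
    ///
    /// A minimal construction sketch (not compiled as a doctest); it assumes
    /// a `Node` is already set up elsewhere, a hypothetical `make_node()`
    /// helper, and `ed25519_dalek` with its `rand_core` feature enabled:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use ed25519_dalek::SigningKey;
    /// use rand::rngs::OsRng;
    ///
    /// let node: Arc<Node> = make_node(); // hypothetical helper
    /// let key = SigningKey::generate(&mut OsRng);
    /// let service = InferenceService::new(node, key, Some("payment-addr".into()));
    /// ```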
    pub fn new(
        node: Arc<Node>,
        identity_key: ed25519_dalek::SigningKey,
        payment_address: Option<String>,
    ) -> Self {
        Self {
            node,
            identity_key,
            payment_address,
        }
    }

    /// 🧪 SIMULATION: Manually trigger a job to test the payment loop.
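    ///
    /// A usage sketch (not compiled as a doctest), with `service` built as in
    /// [`InferenceService::new`]:
    ///
    /// ```ignore
    /// let result = service.simulate_processing("hello world").await;
    /// assert_eq!(result.shard_index, 0);
    /// assert!(result.proof.is_some(), "simulation should attach a signed proof");
    /// ```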
    pub async fn simulate_processing(&self, input: &str) -> PartialResult {
        // Construct a mock request (in production this arrives over the network).
        // We assume we are processing shard 0 of 1.
        let mock_shard = TokenShard {
            shard_id: [0u8; 32],
            request_id: rand::random(),
            index: 0,
            total: 1,
            data: input.as_bytes().to_vec(),
            commitment: [0u8; 32],
        };

        let req = InferenceRequest {
            request_id: mock_shard.request_id,
            model_id: "simulation-llama-3-8b".to_string(),
            shard: mock_shard,
            params: crate::protocol::InferenceParams::default(),
            callback: None,
            timestamp_ms: 0,
            transport_tier: 0,
        };

        // Call the internal processing logic directly.
        self.process_request(req).await
    }

    /// Start the inference service loop.
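    ///
    /// Typically spawned as a background task. A sketch assuming a Tokio
    /// runtime and the service wrapped in an `Arc` (not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let service = Arc::new(service);
    /// tokio::spawn({
    ///     let service = Arc::clone(&service);
    ///     async move { service.start().await }
    /// });
    /// ```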
    pub async fn start(&self) {
        info!("Inference Service started. Listening for tasks...");

        // Register as service ID 1 (inference).
        let mut rx = self.node.register_service(1);

        while let Some(event) = rx.recv().await {
            match event {
                ServiceEvent::Request { request_id, requester, payload } => {
                    // Deserialize the request.
                    match postcard::from_bytes::<InferenceRequest>(&payload) {
                        Ok(req) => {
                            info!(
                                request_id = hex::encode(&req.request_id[..8]),
                                model = %req.model_id,
                                "Processing inference request"
                            );

                            // Process the request.
                            let result = self.process_request(req).await;

                            // Serialize and send the response.
                            match postcard::to_allocvec(&result) {
                                Ok(response_payload) => {
                                    let frame = AbzuFrame::ServiceResponse {
                                        service_id: 1,
                                        request_id,
                                        status: 0,
                                        payload: response_payload,
                                    };

                                    if let Err(e) = self.node.send(requester, frame).await {
                                        warn!(error = ?e, "Failed to send inference response");
                                    }
                                }
                                Err(e) => {
                                    error!(error = ?e, "Failed to serialize inference result");
                                }
                            }
                        }
                        Err(e) => {
                            warn!(error = ?e, "Failed to deserialize inference request");
                        }
                    }
                }
                ServiceEvent::Response { .. } => {
                    // This service acts as a server only: it never issues
                    // requests, so inbound responses are ignored.
                }
            }
        }
    }

    /// Process a request (internal helper).
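    ///
    /// Hashes the incoming shard data and the produced output with BLAKE3,
    /// signs an [`InferenceProof`] with this node's identity key, and returns
    /// the output together with its commitment in a [`PartialResult`].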
    async fn process_request(&self, req: InferenceRequest) -> PartialResult {
        info!(
            "Processing shard {} for request {}",
            req.shard.index,
            hex::encode(req.request_id)
        );

        // 1. Simulate work: produce a placeholder output for this shard.
        let output = format!("shard_{}_output", req.shard.index);
        let output_bytes = output.as_bytes().to_vec();

        // 2. Generate the proof: hash the input (prompt shard) and the output.
        let input_hash = *blake3::hash(&req.shard.data).as_bytes();
        let output_hash = *blake3::hash(&output_bytes).as_bytes();

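        // `output_hash` doubles as the result commitment below; a requester
        // can verify it by recomputing the hash over the returned data, e.g.:
        //   assert_eq!(*blake3::hash(&result.data).as_bytes(), result.commitment);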
        let proof = InferenceProof::new(
            &input_hash,
            &output_hash,
            req.model_id,
            &self.identity_key,
            self.payment_address.clone(),
        );

        // 3. Create the result, attaching the signed proof.
        PartialResult {
            request_id: req.request_id,
            shard_index: req.shard.index,
            shard_total: req.shard.total,
            data: output_bytes,
            commitment: output_hash, // Simple commitment for now.
            processing_time_ms: 100, // Placeholder; the simulated work is not timed.
            error: None,
            proof: Some(proof),
        }
    }
}
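
// A minimal test sketch for the commitment scheme above. It exercises only
// the BLAKE3 hashing (constructing a real `Node` is out of scope here) and
// assumes nothing beyond the `blake3` dependency already in use.
#[cfg(test)]
mod tests {
    #[test]
    fn commitment_matches_recomputed_output_hash() {
        // Mirrors process_request: the commitment is the BLAKE3 hash of the
        // output bytes for the shard.
        let output = format!("shard_{}_output", 0);
        let commitment = *blake3::hash(output.as_bytes()).as_bytes();

        // A verifier recomputing the hash over the returned data must agree.
        assert_eq!(*blake3::hash(b"shard_0_output").as_bytes(), commitment);
    }
}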