// File: src/hardware/ota.rs
//! OTA (Over-The-Air) firmware update functionality

use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use anyhow::{bail, Result};
use esp_idf_svc::http::client::{Configuration as HttpConfig, EspHttpConnection};
use esp_idf_svc::ota::EspOta;
use log::{error, info, warn};

use crate::network::auth;

 11  const OTA_BUFFER_SIZE: usize = 1024;
 12  const OTA_CHECK_INTERVAL_SECS: u64 = 86400; // 24 hours
 13  const OTA_TIMEOUT_SECS: u64 = 300; // 5 minutes
 14  
 15  /// OTA update manager
 16  pub struct OtaManager {
 17      update_url: Arc<Mutex<String>>,
 18      last_check: Arc<Mutex<u64>>,
 19      auto_update_enabled: Arc<Mutex<bool>>,
 20      device_id: Arc<Mutex<String>>,
 21      device_secret: Arc<Mutex<String>>,
 22  }
 23  
 24  impl OtaManager {
 25      /// Create new OTA manager
 26      pub fn new(
 27          update_url: String,
 28          auto_update_enabled: bool,
 29          device_id: String,
 30          device_secret: String,
 31      ) -> Self {
 32          Self {
 33              update_url: Arc::new(Mutex::new(update_url)),
 34              last_check: Arc::new(Mutex::new(0)),
 35              auto_update_enabled: Arc::new(Mutex::new(auto_update_enabled)),
 36              device_id: Arc::new(Mutex::new(device_id)),
 37              device_secret: Arc::new(Mutex::new(device_secret)),
 38          }
 39      }
 40  
 41      /// Update device credentials
 42      #[allow(dead_code)]
 43      pub fn update_credentials(&self, device_id: String, device_secret: String) {
 44          *self.device_id.lock().unwrap() = device_id;
 45          *self.device_secret.lock().unwrap() = device_secret;
 46      }
 47  
 48      /// Set the update URL
 49      pub fn set_update_url(&self, url: String) {
 50          *self.update_url.lock().unwrap() = url;
 51      }
 52  
 53      /// Enable or disable automatic updates
 54      pub fn set_auto_update(&self, enabled: bool) {
 55          *self.auto_update_enabled.lock().unwrap() = enabled;
 56      }
 57  
 58      /// Check if it's time for automatic update check
 59      pub fn should_check_for_updates(&self) -> bool {
 60          if !*self.auto_update_enabled.lock().unwrap() {
 61              return false;
 62          }
 63  
 64          let now = SystemTime::now()
 65              .duration_since(UNIX_EPOCH)
 66              .unwrap_or_default()
 67              .as_secs();
 68          let last = *self.last_check.lock().unwrap();
 69  
 70          now - last >= OTA_CHECK_INTERVAL_SECS
 71      }
 72  
 73      /// Update last check timestamp
 74      fn update_last_check(&self) {
 75          let now = SystemTime::now()
 76              .duration_since(UNIX_EPOCH)
 77              .unwrap_or_default()
 78              .as_secs();
 79          *self.last_check.lock().unwrap() = now;
 80      }
 81  
 82      /// Check for updates from remote server
 83      pub fn check_and_update(&self) -> Result<bool> {
 84          self.update_last_check();
 85  
 86          let url = self.update_url.lock().unwrap().clone();
 87          if url.is_empty() {
 88              info!("OTA update URL not configured, skipping update check");
 89              return Ok(false);
 90          }
 91  
 92          info!("Checking for firmware updates from: {}", url);
 93  
 94          // Check version first (optional - implement version endpoint)
 95          match self.download_and_install(&url) {
 96              Ok(updated) => {
 97                  if updated {
 98                      info!("Firmware update successful! Rebooting...");
 99                      unsafe {
100                          esp_idf_svc::sys::esp_restart();
101                      }
102                  }
103                  Ok(updated)
104              }
105              Err(e) => {
106                  error!("OTA update failed: {:?}", e);
107                  Err(e)
108              }
109          }
110      }
111  
112      /// Download and install firmware from URL with HMAC authentication
113      pub fn download_and_install(&self, url: &str) -> Result<bool> {
114          info!("Starting OTA download from: {}", url);
115  
116          let device_id = self.device_id.lock().unwrap().clone();
117          let device_secret = self.device_secret.lock().unwrap().clone();
118  
119          if device_secret.is_empty() {
120              error!("Device secret not configured - cannot authenticate OTA request");
121              bail!("Device not provisioned for OTA updates");
122          }
123  
124          // For OTA GET request, body is empty
125          let empty_body: &[u8] = &[];
126  
127          // Generate HMAC authentication headers
128          let auth_headers = auth::generate_auth_headers(&device_id, &device_secret, url, empty_body)
129              .map_err(|e| anyhow::anyhow!("Failed to generate auth headers: {}", e))?;
130  
131          // Initialize HTTP client
132          let http_config = HttpConfig {
133              timeout: Some(Duration::from_secs(OTA_TIMEOUT_SECS)),
134              buffer_size: Some(OTA_BUFFER_SIZE),
135              ..Default::default()
136          };
137  
138          let mut client = EspHttpConnection::new(&http_config)?;
139  
140          // Prepare headers for GET request
141          let mut headers: Vec<(&str, &str)> = vec![("User-Agent", "ESP32-Harrastila-OTA/1.0")];
142  
143          // Add authentication headers
144          let auth_header_refs: Vec<(&str, &str)> = auth_headers
145              .iter()
146              .map(|(k, v)| (k.as_ref(), v.as_ref()))
147              .collect();
148          headers.extend(auth_header_refs);
149  
150          // Make GET request
151          client.initiate_request(embedded_svc::http::Method::Get, url, &headers)?;
152  
153          client.initiate_response()?;
154  
155          let status = client.status();
156          if status == 401 || status == 403 {
157              error!("OTA authentication failed (status {})", status);
158              bail!("Authentication failed - device may not be registered");
159          } else if status != 200 {
160              bail!("HTTP error: {}", status);
161          }
162  
163          // Get content length if available
164          let content_length = client
165              .header("Content-Length")
166              .and_then(|h| h.parse::<usize>().ok());
167  
168          if let Some(len) = content_length {
169              info!("Firmware size: {} bytes", len);
170          }
171  
172          // Initialize OTA
173          let mut ota = EspOta::new()?;
174          let mut ota_update = ota.initiate_update()?;
175  
176          info!("OTA partition prepared, starting download...");
177  
178          // Download and write firmware
179          let mut buffer = vec![0u8; OTA_BUFFER_SIZE];
180          let mut total_read = 0usize;
181  
182          loop {
183              match client.read(&mut buffer) {
184                  Ok(0) => break, // EOF
185                  Ok(size) => {
186                      ota_update.write(&buffer[..size])?;
187                      total_read += size;
188  
189                      if let Some(total) = content_length {
190                          let progress = (total_read as f32 / total as f32 * 100.0) as u32;
191                          if total_read % (100 * 1024) == 0 {
192                              info!(
193                                  "Download progress: {}% ({}/{})",
194                                  progress, total_read, total
195                              );
196                          }
197                      } else if total_read % (100 * 1024) == 0 {
198                          info!("Downloaded: {} bytes", total_read);
199                      }
200                  }
201                  Err(e) => {
202                      error!("Download error: {:?}", e);
203                      bail!("Failed to download firmware");
204                  }
205              }
206          }
207  
208          info!("Download complete: {} bytes", total_read);
209  
210          // Finalize and validate
211          if total_read == 0 {
212              bail!("No data received");
213          }
214  
215          info!("Finalizing OTA update...");
216          ota_update.complete()?;
217  
218          info!("OTA update completed successfully");
219          Ok(true)
220      }
221  
222      /// Install firmware from uploaded data
223      pub fn install_from_data(&self, firmware_data: &[u8]) -> Result<()> {
224          info!(
225              "Installing firmware from uploaded data ({} bytes)",
226              firmware_data.len()
227          );
228  
229          if firmware_data.is_empty() {
230              bail!("Empty firmware data");
231          }
232  
233          // Initialize OTA
234          let mut ota = EspOta::new()?;
235          let mut ota_update = ota.initiate_update()?;
236  
237          info!("OTA partition prepared, writing firmware...");
238  
239          // Write firmware data
240          let chunk_size = 4096;
241          for (i, chunk) in firmware_data.chunks(chunk_size).enumerate() {
242              ota_update.write(chunk)?;
243  
244              if i % 100 == 0 {
245                  let progress =
246                      ((i * chunk_size) as f32 / firmware_data.len() as f32 * 100.0) as u32;
247                  info!("Write progress: {}%", progress);
248              }
249          }
250  
251          info!("Firmware written, finalizing...");
252          ota_update.complete()?;
253  
254          info!("OTA update from upload completed successfully");
255          Ok(())
256      }
257  
258      /// Get current firmware version
259      pub fn get_current_version(&self) -> String {
260          // Read from app description if available
261          unsafe {
262              let app_desc = esp_idf_svc::sys::esp_app_get_description();
263              if !app_desc.is_null() {
264                  let version = std::ffi::CStr::from_ptr((*app_desc).version.as_ptr())
265                      .to_string_lossy()
266                      .to_string();
267                  if !version.is_empty() {
268                      return version;
269                  }
270              }
271          }
272          "unknown".to_string()
273      }
274  
275      /// Get partition information
276      pub fn get_partition_info(&self) -> Result<OtaPartitionInfo> {
277          unsafe {
278              let running = esp_idf_svc::sys::esp_ota_get_running_partition();
279              if running.is_null() {
280                  bail!("Failed to get running partition");
281              }
282  
283              let boot = esp_idf_svc::sys::esp_ota_get_boot_partition();
284              let next = esp_idf_svc::sys::esp_ota_get_next_update_partition(std::ptr::null());
285  
286              let running_label = if !running.is_null() {
287                  std::ffi::CStr::from_ptr((*running).label.as_ptr())
288                      .to_string_lossy()
289                      .to_string()
290              } else {
291                  "unknown".to_string()
292              };
293  
294              let next_label = if !next.is_null() {
295                  std::ffi::CStr::from_ptr((*next).label.as_ptr())
296                      .to_string_lossy()
297                      .to_string()
298              } else {
299                  "unknown".to_string()
300              };
301  
302              Ok(OtaPartitionInfo {
303                  running: running_label.clone(),
304                  boot: if !boot.is_null() && boot == running {
305                      running_label
306                  } else {
307                      "unknown".to_string()
308                  },
309                  next: next_label,
310              })
311          }
312      }
313  
314      /// Mark current firmware as valid (confirms successful boot)
315      /// Should be called after verifying the system is stable
316      pub fn mark_valid(&self) -> Result<()> {
317          info!("Marking current firmware as valid...");
318  
319          unsafe {
320              let running = esp_idf_svc::sys::esp_ota_get_running_partition();
321              if running.is_null() {
322                  bail!("Failed to get running partition");
323              }
324  
325              let result = esp_idf_svc::sys::esp_ota_mark_app_valid_cancel_rollback();
326              if result != esp_idf_svc::sys::ESP_OK {
327                  bail!("Failed to mark firmware as valid");
328              }
329          }
330  
331          info!("Firmware marked as valid - rollback canceled");
332          Ok(())
333      }
334  
335      /// Check if current firmware is pending validation
336      /// Returns true if this is first boot after OTA update
337      pub fn is_pending_validation(&self) -> bool {
338          unsafe {
339              let running = esp_idf_svc::sys::esp_ota_get_running_partition();
340              if running.is_null() {
341                  return false;
342              }
343  
344              let boot = esp_idf_svc::sys::esp_ota_get_boot_partition();
345              if boot.is_null() {
346                  return false;
347              }
348  
349              // If running partition != boot partition, we're pending validation
350              running != boot
351          }
352      }
353  
354      /// Initiate rollback to previous firmware
355      /// This will reboot the device into the previous firmware partition
356      pub fn rollback_and_reboot(&self) -> Result<()> {
357          info!("Initiating firmware rollback...");
358  
359          unsafe {
360              let running = esp_idf_svc::sys::esp_ota_get_running_partition();
361              if running.is_null() {
362                  bail!("Failed to get running partition");
363              }
364  
365              // Mark current firmware as invalid
366              let result = esp_idf_svc::sys::esp_ota_mark_app_invalid_rollback_and_reboot();
367  
368              // This function should not return if successful (device reboots)
369              if result != esp_idf_svc::sys::ESP_OK {
370                  bail!("Failed to initiate rollback");
371              }
372          }
373  
374          // Should never reach here
375          Ok(())
376      }
377  
378      /// Get last invalid firmware reason (if available)
379      #[allow(dead_code)]
380      pub fn get_last_invalid_reason(&self) -> Option<String> {
381          // This would require storing reason in NVS - implement if needed
382          None
383      }
384  }
385  
/// Labels of the OTA-relevant flash partitions.
///
/// Values are partition-table labels as produced by `OtaManager::get_partition_info`. The
/// placeholder `"unknown"` may appear when a label is not available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OtaPartitionInfo {
    /// Partition the current firmware is executing from.
    pub running: String,
    /// Partition currently configured as the boot app.
    pub boot: String,
    /// Partition the next OTA update would be written to.
    pub next: String,
}