/ tools / scripts / dfu / bm_load_img_to_flash.py
bm_load_img_to_flash.py
import argparse
import os
import struct
from base64 import b64encode
from collections import namedtuple
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Optional

import crcmod
import serial
 11  
 12  CLI_WRITE_SIZE = 128
 13  CHUNK_SIZE = 512
 14  
 15  # Header Payload
 16  DFU_HEADER = namedtuple( # NOTE: Must be in sync with bm_dfu_message_structs.h
 17      "DFU_HEADER",
 18      "img_size chunk_size img_crc maj min filter_key gitSHA",
 19  )
 20  # https://docs.python.org/3/library/struct.html#format-characters
 21  DFU_HEADER_STRUCT_ENCODING = "<LHHBBLL"
 22  
 23  NVM_DFU_WRITE_CMD_STR = "nvm b64write dfu"
 24  NVM_DFU_CRC16_CMD_STR = "nvm crc16 dfu"
 25  
 26  def getVersionFromBin(filename:str, offset:int=0) -> Dict[str, Any]:
 27      VersionHeader = namedtuple(
 28          "VersionHeader", "magic gitSHA maj min"
 29      )
 30      magic = 0xDF7F9AFDEC06627C
 31      header_format = "QIBB"
 32      header_len = struct.calcsize(header_format)
 33      version = None
 34      with open(filename, "rb") as binfile:
 35          file = binfile.read()
 36  
 37          for offset in range(offset, len(file) - header_len):
 38              header = VersionHeader._make(
 39                  struct.unpack_from(header_format, file, offset=offset)
 40              )
 41  
 42              if magic == header.magic:
 43                  version = {
 44                      "sha": f"{header.gitSHA:08X}",
 45                      "version": f"{header.maj}.{header.min}",
 46                  }
 47                  break
 48      return version
 49  
 50  def main(img_path:str, port:str, baud:int) -> None:
 51      abs_path = os.path.realpath(img_path)
 52  
 53      # Validity Checks / port open
 54      if not os.path.exists(abs_path):
 55          print("File does not exist")
 56          return
 57      
 58      if not os.path.exists(port):
 59          print("Port does not exist")
 60          return
 61  
 62      fw_ver = getVersionFromBin(abs_path)
 63      major = int(fw_ver["version"].split(".")[0])
 64      minor = int(fw_ver["version"].split(".")[1])
 65      gitSHA = int(fw_ver["sha"],16)
 66      ser = serial.Serial(
 67          port=port,
 68          baudrate=baud,
 69          parity=serial.PARITY_NONE,
 70          stopbits=serial.STOPBITS_ONE,
 71          bytesize=serial.EIGHTBITS,
 72          timeout=30,
 73      )
 74  
 75      # Get image + crc
 76      crc16 = crcmod.predefined.mkCrcFun("kermit")
 77      img_data = Path(abs_path).read_bytes()
 78  
 79      # Get image characteristics
 80      img_size = len(img_data)
 81      img_crc = crc16(img_data)
 82      print(f"Image Info: - size: {img_size}, crc16:{img_crc}, major: {major}, minor: {minor}, gitSHA: {gitSHA}")
 83  
 84      # send header
 85      header = DFU_HEADER(
 86          img_size,
 87          CHUNK_SIZE,
 88          img_crc,
 89          major,
 90          minor,
 91          0,
 92          gitSHA
 93      )
 94      header_payload = struct.pack(DFU_HEADER_STRUCT_ENCODING, *header)
 95      b64_header_payload = b64encode(header_payload).decode()
 96      ser.write(f"{NVM_DFU_WRITE_CMD_STR} 0 {b64_header_payload}\n".encode())
 97      if "#" not in ser.read_until(b'#').decode():
 98          print("Failed to write to flash.")
 99          return
100      ser.flush()
101  
102      # send image 
103      print("Sending image")
104      bytes_to_send = img_size
105      header_size = len(header_payload)
106      while bytes_to_send:
107          img_offset = img_size - bytes_to_send
108          dest_offset = header_size + img_offset
109          size_to_send = CLI_WRITE_SIZE if(bytes_to_send > CLI_WRITE_SIZE) else bytes_to_send
110          bin_payload = bytearray(img_data[img_offset : img_offset+size_to_send])
111          b64_payload = b64encode(bin_payload).decode()
112          print(f"Writing offset {img_offset} out of {img_size}: {b64_payload}")
113          ser.write(f"{NVM_DFU_WRITE_CMD_STR} {dest_offset} {b64_payload}\n".encode())
114          if "#" not in ser.read_until(b'#').decode():
115              print("Failed to write to flash.")
116              return
117          ser.flush()
118          bytes_to_send -= size_to_send
119  
120      # verify image 
121      print("Verifying Image")
122      ser.write(f"{NVM_DFU_CRC16_CMD_STR} {header_size} {img_size}\n".encode())
123      result_bytes = ser.read_until(b'#')
124      ser.flush()
125      if not result_bytes:
126          print("Unable to read crc16 from serial")
127          return
128      
129      result_str = result_bytes.decode()
130      comp_crc16 = int(result_str.split("<crc16>")[1].split('#')[0],16)
131      if comp_crc16 == img_crc:
132          print("Validation success! Image loaded in flash!")
133      else:
134          print(f"Validation failed crc16: {img_crc} computed crc16: {comp_crc16}")
135  
136  if __name__ == "__main__":
137      parser = argparse.ArgumentParser(
138          formatter_class=argparse.ArgumentDefaultsHelpFormatter
139      )
140      parser.add_argument(
141          "-i", "--image", dest="image", required=True, help="absolute path to BM Image"
142      )
143      parser.add_argument(
144          "-p", "--port", dest="port", required=True, help="Absolute path to Port"
145      )
146      parser.add_argument(
147          "-b", "--baud", dest="baud", required=False, help="Baudrate", default=921600
148      )
149      args = parser.parse_args()
150      try:
151          main(args.image, args.port, args.baud)
152      except Exception as e:
153          print("Failed to load image.")
154          print(str(e))