"""minimal_ssz.py: a minimal SimpleSerialize (SSZ) implementation.

SSZ types are described with plain Python values:

* ``'uintN'`` (N in 8, 16, 32, 64, 128, 256), ``'bool'`` and the alias
  ``'byte'`` (= ``'uint8'``) are basic types;
* ``'bytes'`` is a variable-size byte string, ``'bytesN'`` a fixed-size one;
* ``[elem_type]`` is a variable-length list, ``[elem_type, length]`` a vector;
* classes produced by :func:`SSZType` are containers (they carry ``fields``).
"""
from typing import Any

from .hash_function import hash

BYTES_PER_CHUNK = 32
BYTES_PER_LENGTH_OFFSET = 4
ZERO_CHUNK = b'\x00' * BYTES_PER_CHUNK


def SSZType(fields):
    """Create a container class whose fields and field types are given by ``fields``.

    :param fields: Mapping of field name -> SSZ type; its iteration order is the
        field order used for serialization and hashing.
    :return: A new class. Instances take field values as keyword arguments;
        fields not supplied get ``get_zero_value`` of their type, and keyword
        arguments that are not fields are ignored.
    """
    class SSZObject():
        def __init__(self, **kwargs):
            for f, t in fields.items():
                if f not in kwargs:
                    setattr(self, f, get_zero_value(t))
                else:
                    setattr(self, f, kwargs[f])

        def __eq__(self, other):
            # Let Python fall back to its default comparison for non-SSZ operands
            # instead of raising AttributeError on ``other.fields``.
            if not hasattr(other, 'fields'):
                return NotImplemented
            return self.fields == other.fields and self.serialize() == other.serialize()

        def __hash__(self):
            return int.from_bytes(self.hash_tree_root(), byteorder="little")

        def __str__(self):
            return "\n".join(f'{field}: {getattr(self, field)}' for field in self.fields)

        def serialize(self):
            return serialize_value(self, self.__class__)

        def hash_tree_root(self):
            return hash_tree_root(self, self.__class__)

    SSZObject.fields = fields
    return SSZObject


class Vector():
    """Thin wrapper marking a sequence as a fixed-length SSZ vector.

    The length is captured at construction; items may be replaced in place.
    """

    def __init__(self, items):
        self.items = items
        self.length = len(items)

    def __getitem__(self, key):
        return self.items[key]

    def __setitem__(self, key, value):
        self.items[key] = value

    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return self.length


def is_basic(typ):
    """Return True if ``typ`` is a basic SSZ type ('uintN', 'bool' or 'byte')."""
    # if not a string, it is a complex, and cannot be basic
    if not isinstance(typ, str):
        return False
    # "uintN": N-bit unsigned integer (where N in [8, 16, 32, 64, 128, 256])
    elif typ[:4] == 'uint' and typ[4:] in ['8', '16', '32', '64', '128', '256']:
        return True
    # "bool": True or False
    elif typ == 'bool':
        return True
    # alias: "byte" -> "uint8"
    elif typ == 'byte':
        return True
    # default
    else:
        return False


def is_constant_sized(typ):
    """Return True if every value of ``typ`` serializes to the same number of bytes.

    :raises Exception: if ``typ`` is not a recognized SSZ type.
    """
    # basic objects are fixed size by definition
    if is_basic(typ):
        return True
    # dynamic size array type, "list": [elem_type].
    # Not constant size by definition.
    elif isinstance(typ, list) and len(typ) == 1:
        return False
    # fixed size array type, "vector": [elem_type, length]
    # Constant size, but only if the elements are.
    elif isinstance(typ, list) and len(typ) == 2:
        return is_constant_sized(typ[0])
    # bytes array (fixed or dynamic size)
    elif isinstance(typ, str) and typ[:5] == 'bytes':
        # if no length suffix, it has a dynamic size
        return typ != 'bytes'
    # containers are only constant-size if all of the fields are constant size.
    elif hasattr(typ, 'fields'):
        return all(is_constant_sized(subtype) for subtype in typ.fields.values())
    else:
        raise Exception("Type not recognized")


def coerce_to_bytes(x):
    """Return ``x`` as bytes.

    A ``str`` is UTF-8 encoded and must be ASCII-only (the encoded length has to
    equal the character count); ``bytes`` pass through unchanged.

    :raises Exception: if ``x`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(x, str):
        o = x.encode('utf-8')
        assert len(o) == len(x)
        return o
    elif isinstance(x, bytes):
        return x
    else:
        raise Exception("Expecting bytes")


def encode_series(values, types):
    """Serialize ``values`` (``values[i]`` of SSZ type ``types[i]``) as one SSZ series.

    Fixed-size parts are laid out first, in order. A variable-size part is
    represented there by a little-endian offset of BYTES_PER_LENGTH_OFFSET bytes,
    measured from the start of the output. The variable-size serializations
    follow, in order.
    """
    assert len(values) == len(types)
    # Recursively serialize
    parts = [(is_constant_sized(typ), serialize_value(value, typ)) for value, typ in zip(values, types)]

    # Compute and check lengths
    fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
                     for (constant_size, serialized) in parts]
    variable_lengths = [len(serialized) if not constant_size else 0
                        for (constant_size, serialized) in parts]

    # Every offset must fit in BYTES_PER_LENGTH_OFFSET bytes.
    assert sum(fixed_lengths + variable_lengths) < 2 ** (BYTES_PER_LENGTH_OFFSET * 8)

    # Interleave offsets of variable-size parts with fixed-size parts.
    # Running offset avoids quadratic re-summing of preceding lengths.
    offset = sum(fixed_lengths)
    variable_parts = []
    fixed_parts = []
    for (constant_size, serialized) in parts:
        if constant_size:
            fixed_parts.append(serialized)
        else:
            fixed_parts.append(offset.to_bytes(BYTES_PER_LENGTH_OFFSET, 'little'))
            variable_parts.append(serialized)
            offset += len(serialized)

    # Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts
    return b"".join(fixed_parts + variable_parts)


def serialize_value(value, typ=None):
    """Serialize ``value`` as SSZ type ``typ`` (inferred via ``infer_type`` if None).

    :raises Exception: if ``typ`` is not a recognized SSZ type.
    """
    if typ is None:
        typ = infer_type(value)
    # "byte" is the alias of "uint8" (see is_basic); normalize it so it is serializable.
    if isinstance(typ, str) and typ == 'byte':
        typ = 'uint8'
    # "uintN"
    if isinstance(typ, str) and typ[:4] == 'uint':
        length = int(typ[4:])
        assert length in (8, 16, 32, 64, 128, 256)
        return value.to_bytes(length // 8, 'little')
    # "bool"
    elif isinstance(typ, str) and typ == 'bool':
        assert value in (True, False)
        # Truthiness, not identity: the assert above also admits 1 and 0.
        return b'\x01' if value else b'\x00'
    # Vector
    elif isinstance(typ, list) and len(typ) == 2:
        # (regardless of element type, sanity-check if the length reported in the vector type matches the value length)
        assert len(value) == typ[1]
        return encode_series(value, [typ[0]] * len(value))
    # List
    elif isinstance(typ, list) and len(typ) == 1:
        return encode_series(value, [typ[0]] * len(value))
    # "bytes" (variable size)
    elif isinstance(typ, str) and typ == 'bytes':
        return coerce_to_bytes(value)
    # "bytesN" (fixed size)
    elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes':
        assert len(value) == int(typ[5:]), (value, int(typ[5:]))
        return coerce_to_bytes(value)
    # containers
    elif hasattr(typ, 'fields'):
        values = [getattr(value, field) for field in typ.fields.keys()]
        types = list(typ.fields.values())
        return encode_series(values, types)
    else:
        raise Exception(f"Type not recognized: {typ!r}")


def get_zero_value(typ: Any) -> Any:
    """Return the default ("zero") value for SSZ type ``typ``.

    Vectors yield a plain list of element zero values, lists yield ``[]`` and
    containers are instantiated with zero-valued fields.

    :raises ValueError: for an unrecognized string type.
    :raises Exception: for any other unrecognized type.
    """
    if isinstance(typ, str):
        # Bytes array
        if typ == 'bytes':
            return b''
        # bytesN
        elif typ[:5] == 'bytes' and len(typ) > 5:
            length = int(typ[5:])
            return b'\x00' * length
        # Basic types
        elif typ == 'bool':
            return False
        elif typ[:4] == 'uint':
            return 0
        elif typ == 'byte':
            return 0x00
        else:
            raise ValueError("Type not recognized")
    # Vector:
    elif isinstance(typ, list) and len(typ) == 2:
        return [get_zero_value(typ[0]) for _ in range(typ[1])]
    # List:
    elif isinstance(typ, list) and len(typ) == 1:
        return []
    # Container:
    elif hasattr(typ, 'fields'):
        return typ(**{field: get_zero_value(subtype) for field, subtype in typ.fields.items()})
    else:
        raise Exception(f"Type not recognized: {typ!r}")


def chunkify(bytez):
    """Zero-pad ``bytez`` to a multiple of BYTES_PER_CHUNK and split it into chunks."""
    bytez += b'\x00' * (-len(bytez) % BYTES_PER_CHUNK)
    return [bytez[i:i + BYTES_PER_CHUNK] for i in range(0, len(bytez), BYTES_PER_CHUNK)]


def pack(values, subtype):
    """Serialize each of ``values`` as ``subtype``, concatenate, and chunkify the result."""
    return chunkify(b''.join([serialize_value(value, subtype) for value in values]))


def is_power_of_two(x):
    """Return True if ``x`` is a positive power of two."""
    return x > 0 and x & (x - 1) == 0


def merkleize(chunks):
    """Return the binary Merkle root of ``chunks``.

    The chunk list is padded with ZERO_CHUNK up to a power of two (at least one
    chunk). The tree is stored in an array where node ``i`` has children
    ``2i`` and ``2i + 1``, and the root is node 1. An empty input yields ZERO_CHUNK.
    """
    tree = list(chunks)
    while not is_power_of_two(len(tree)):
        tree.append(ZERO_CHUNK)
    tree = [ZERO_CHUNK] * len(tree) + tree
    for i in range(len(tree) // 2 - 1, 0, -1):
        tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
    return tree[1]


def mix_in_length(root, length):
    """Hash ``root`` together with ``length`` encoded as 32 little-endian bytes."""
    return hash(root + length.to_bytes(32, 'little'))


def infer_type(value):
    """
    Note: defaults to uint64 for integer type inference due to lack of information.
    Other integer sizes are still supported, see spec.
    :param value: The value to infer a SSZ type for.
    :return: The SSZ type.
    :raises Exception: if no type can be inferred.
    """
    if hasattr(value.__class__, 'fields'):
        return value.__class__
    elif isinstance(value, Vector):
        if len(value) > 0:
            return [infer_type(value[0]), len(value)]
        else:
            # Element type does not matter too much,
            # assumed to be a basic type for size-encoding purposes, vector is empty.
            # Still a *vector* type (length 0), consistent with the non-empty branch.
            return ['uint64', 0]
    elif isinstance(value, list):
        if len(value) > 0:
            return [infer_type(value[0])]
        else:
            # Element type does not matter, list-content size will be encoded regardless, list is empty.
            return ['uint64']
    elif isinstance(value, (bytes, str)):
        return 'bytes'
    elif isinstance(value, int):
        return 'uint64'
    else:
        raise Exception("Failed to infer type")


def hash_tree_root(value, typ=None):
    """Return the SSZ hash tree root of ``value`` as type ``typ`` (inferred if None).

    :raises Exception: if ``typ`` is not a recognized SSZ type.
    """
    if typ is None:
        typ = infer_type(value)
    # -------------------------------------
    # merkleize(pack(value))
    # basic object: merkleize packed version (merkleization pads it to 32 bytes if it is not already)
    if is_basic(typ):
        return merkleize(pack([value], typ))
    # or a vector of basic objects
    elif isinstance(typ, list) and len(typ) == 2 and is_basic(typ[0]):
        assert len(value) == typ[1]
        return merkleize(pack(value, typ[0]))
    # -------------------------------------
    # mix_in_length(merkleize(pack(value)), len(value))
    # if value is a list of basic objects
    elif isinstance(typ, list) and len(typ) == 1 and is_basic(typ[0]):
        return mix_in_length(merkleize(pack(value, typ[0])), len(value))
    # (needs some extra work for non-fixed-sized bytes array)
    elif typ == 'bytes':
        return mix_in_length(merkleize(chunkify(coerce_to_bytes(value))), len(value))
    # -------------------------------------
    # merkleize([hash_tree_root(element) for element in value])
    # if value is a vector of composite objects
    elif isinstance(typ, list) and len(typ) == 2 and not is_basic(typ[0]):
        return merkleize([hash_tree_root(element, typ[0]) for element in value])
    # (needs some extra work for fixed-sized bytes array)
    elif isinstance(typ, str) and typ[:5] == 'bytes' and len(typ) > 5:
        assert len(value) == int(typ[5:])
        return merkleize(chunkify(coerce_to_bytes(value)))
    # or a container
    elif hasattr(typ, 'fields'):
        return merkleize([hash_tree_root(getattr(value, field), subtype) for field, subtype in typ.fields.items()])
    # -------------------------------------
    # mix_in_length(merkleize([hash_tree_root(element) for element in value]), len(value))
    # if value is a list of composite objects
    elif isinstance(typ, list) and len(typ) == 1 and not is_basic(typ[0]):
        return mix_in_length(merkleize([hash_tree_root(element, typ[0]) for element in value]), len(value))
    # -------------------------------------
    else:
        raise Exception("Type not recognized")


def truncate(container):
    """Return a copy of ``container`` as a new SSZType instance without its last field.

    Field values are copied by reference, not deep-copied.
    """
    field_keys = list(container.fields.keys())
    truncated_fields = {
        key: container.fields[key]
        for key in field_keys[:-1]
    }
    truncated_class = SSZType(truncated_fields)
    kwargs = {
        field: getattr(container, field)
        for field in field_keys[:-1]
    }
    return truncated_class(**kwargs)


def signing_root(container):
    """Return the hash tree root of ``container`` with its last field removed."""
    return hash_tree_root(truncate(container))


def serialize(ssz_object):
    """Serialize an SSZType instance by delegating to its ``serialize`` method."""
    return ssz_object.serialize()