/ test_libs / pyspec / eth2spec / utils / minimal_ssz.py
minimal_ssz.py
  1  from typing import Any
  2  
  3  from .hash_function import hash
  4  
  5  BYTES_PER_CHUNK = 32
  6  BYTES_PER_LENGTH_OFFSET = 4
  7  ZERO_CHUNK = b'\x00' * BYTES_PER_CHUNK
  8  
  9  
 10  def SSZType(fields):
 11      class SSZObject():
 12          def __init__(self, **kwargs):
 13              for f, t in fields.items():
 14                  if f not in kwargs:
 15                      setattr(self, f, get_zero_value(t))
 16                  else:
 17                      setattr(self, f, kwargs[f])
 18  
 19          def __eq__(self, other):
 20              return self.fields == other.fields and self.serialize() == other.serialize()
 21  
 22          def __hash__(self):
 23              return int.from_bytes(self.hash_tree_root(), byteorder="little")
 24  
 25          def __str__(self):
 26              output = []
 27              for field in self.fields:
 28                  output.append(f'{field}: {getattr(self, field)}')
 29              return "\n".join(output)
 30  
 31          def serialize(self):
 32              return serialize_value(self, self.__class__)
 33  
 34          def hash_tree_root(self):
 35              return hash_tree_root(self, self.__class__)
 36  
 37      SSZObject.fields = fields
 38      return SSZObject
 39  
 40  
 41  class Vector():
 42      def __init__(self, items):
 43          self.items = items
 44          self.length = len(items)
 45  
 46      def __getitem__(self, key):
 47          return self.items[key]
 48  
 49      def __setitem__(self, key, value):
 50          self.items[key] = value
 51  
 52      def __iter__(self):
 53          return iter(self.items)
 54  
 55      def __len__(self):
 56          return self.length
 57  
 58  
 59  def is_basic(typ):
 60      # if not a string, it is a complex, and cannot be basic
 61      if not isinstance(typ, str):
 62          return False
 63      # "uintN": N-bit unsigned integer (where N in [8, 16, 32, 64, 128, 256])
 64      elif typ[:4] == 'uint' and typ[4:] in ['8', '16', '32', '64', '128', '256']:
 65          return True
 66      # "bool": True or False
 67      elif typ == 'bool':
 68          return True
 69      # alias: "byte" -> "uint8"
 70      elif typ == 'byte':
 71          return True
 72      # default
 73      else:
 74          return False
 75  
 76  
 77  def is_constant_sized(typ):
 78      # basic objects are fixed size by definition
 79      if is_basic(typ):
 80          return True
 81      # dynamic size array type, "list": [elem_type].
 82      # Not constant size by definition.
 83      elif isinstance(typ, list) and len(typ) == 1:
 84          return False
 85      # fixed size array type, "vector": [elem_type, length]
 86      # Constant size, but only if the elements are.
 87      elif isinstance(typ, list) and len(typ) == 2:
 88          return is_constant_sized(typ[0])
 89      # bytes array (fixed or dynamic size)
 90      elif isinstance(typ, str) and typ[:5] == 'bytes':
 91          # if no length suffix, it has a dynamic size
 92          return typ != 'bytes'
 93      # containers are only constant-size if all of the fields are constant size.
 94      elif hasattr(typ, 'fields'):
 95          for subtype in typ.fields.values():
 96              if not is_constant_sized(subtype):
 97                  return False
 98          return True
 99      else:
100          raise Exception("Type not recognized")
101  
102  
def coerce_to_bytes(x):
    """Return ``x`` as ``bytes``.

    ``bytes`` are returned unchanged. A ``str`` is UTF-8 encoded, and the
    encoding must have one byte per character (i.e. the text is ASCII); this
    is checked with an assert.

    :raises Exception: "Expecting bytes" for any other type.
    """
    if isinstance(x, bytes):
        return x
    if isinstance(x, str):
        encoded = x.encode('utf-8')
        # A longer encoding means some character needed multiple UTF-8 bytes.
        assert len(encoded) == len(x)
        return encoded
    raise Exception("Expecting bytes")
112  
113  
def encode_series(values, types):
    """Serialize a series of values, each with its own SSZ type.

    Output layout: first the fixed part, in which each constant-size element
    appears as its serialization and each variable-size element is replaced by
    a BYTES_PER_LENGTH_OFFSET-byte little-endian offset. The serializations of
    the variable-size elements follow, in order. Offsets are measured from the
    start of the output, so the first one equals the length of the fixed part.

    :param values: the elements, read as ``values[i]`` for i < len(values).
    :param types: the SSZ type of each element (``types[i]`` for ``values[i]``).
    :return: the encoded bytes.
    :raises AssertionError: if the total size does not fit in an offset.
    """
    # One (is_constant_size, serialized_bytes) pair per element, in order.
    parts = []
    for index in range(len(values)):
        elem_type = types[index]
        parts.append((is_constant_sized(elem_type), serialize_value(values[index], elem_type)))

    fixed_size = sum(len(data) if fixed else BYTES_PER_LENGTH_OFFSET for fixed, data in parts)
    variable_size = sum(0 if fixed else len(data) for fixed, data in parts)
    # The whole encoding must be addressable by a BYTES_PER_LENGTH_OFFSET-byte offset.
    assert fixed_size + variable_size < 2 ** (BYTES_PER_LENGTH_OFFSET * 8)

    # Single pass with a running offset (avoids re-summing lengths per element).
    head, tail = [], []
    next_offset = fixed_size
    for fixed, data in parts:
        if fixed:
            head.append(data)
            continue
        head.append(next_offset.to_bytes(BYTES_PER_LENGTH_OFFSET, 'little'))
        tail.append(data)
        next_offset += len(data)

    return b"".join(head) + b"".join(tail)
142  
143  
def serialize_value(value, typ=None):
    """Serialize ``value`` to SSZ bytes.

    :param value: the value to serialize.
    :param typ: SSZ type descriptor: "uintN", "byte" (alias of "uint8"), "bool",
        "bytes", "bytesN", ``[elem]`` (list), ``[elem, length]`` (vector), or an
        SSZType container class. Inferred with ``infer_type`` when omitted.
    :return: the serialized bytes.
    :raises AssertionError: (when asserts are enabled) on an unsupported uint
        width, a non-boolean "bool" value, or a vector / "bytesN" length mismatch.
    :raises Exception: if ``typ`` is not recognized.
    """
    if typ is None:
        typ = infer_type(value)
    # is_basic() and get_zero_value() accept "byte" as an alias of "uint8", but
    # no branch below handled it, so serializing (and hashing) a "byte" failed.
    if typ == 'byte':
        typ = 'uint8'
    # "uintN": little-endian, N // 8 bytes
    if isinstance(typ, str) and typ[:4] == 'uint':
        length = int(typ[4:])
        assert length in (8, 16, 32, 64, 128, 256)
        return value.to_bytes(length // 8, 'little')
    # "bool"
    elif isinstance(typ, str) and typ == 'bool':
        assert value in (True, False)
        # Encode by truthiness, not identity: 1 passes the check above
        # (1 == True), but `1 is True` is False and used to encode as 0x00.
        return b'\x01' if value else b'\x00'
    # Vector
    elif isinstance(typ, list) and len(typ) == 2:
        # (regardless of element type, sanity-check if the length reported in the vector type matches the value length)
        assert len(value) == typ[1]
        return encode_series(value, [typ[0]] * len(value))
    # List
    elif isinstance(typ, list) and len(typ) == 1:
        return encode_series(value, [typ[0]] * len(value))
    # "bytes" (variable size)
    elif isinstance(typ, str) and typ == 'bytes':
        return coerce_to_bytes(value)
    # "bytesN" (fixed size)
    elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes':
        assert len(value) == int(typ[5:]), (value, int(typ[5:]))
        return coerce_to_bytes(value)
    # containers: fields are encoded in the order of typ.fields
    elif hasattr(typ, 'fields'):
        values = [getattr(value, field) for field in typ.fields.keys()]
        types = list(typ.fields.values())
        return encode_series(values, types)
    else:
        # Report the offending type/value in the exception instead of printing to stdout.
        raise Exception(f"Type not recognized: {typ!r} (value: {value!r})")
179  
180  
def get_zero_value(typ: Any) -> Any:
    """Return the default (zero) value for the SSZ type ``typ``.

    - "bytes" -> b''; "bytesN" -> N zero bytes
    - "bool" -> False; "uintN" / "byte" -> 0
    - vector ``[elem, length]`` -> list of ``length`` zero elements
    - list ``[elem]`` -> []
    - container class -> instance with every field zeroed

    :raises ValueError: for an unrecognized type string.
    :raises Exception: for any other unrecognized descriptor.
    """
    if isinstance(typ, str):
        if typ == 'bytes':
            return b''
        if typ.startswith('bytes') and len(typ) > 5:
            return b'\x00' * int(typ[5:])
        if typ == 'bool':
            return False
        if typ.startswith('uint') or typ == 'byte':
            return 0
        raise ValueError("Type not recognized")
    if isinstance(typ, list) and len(typ) == 2:
        elem_type, length = typ
        return [get_zero_value(elem_type) for _ in range(length)]
    if isinstance(typ, list) and len(typ) == 1:
        return []
    if hasattr(typ, 'fields'):
        zeroed = {name: get_zero_value(subtype) for name, subtype in typ.fields.items()}
        return typ(**zeroed)
    print(typ)
    raise Exception("Type not recognized")
211  
212  
def chunkify(bytez):
    """Split ``bytez`` into BYTES_PER_CHUNK-sized chunks, zero-padding the last one.

    An empty input yields an empty list. The argument is not modified: padding
    is applied to a new object. Previously ``bytez += ...`` extended a caller's
    ``bytearray`` in place.
    """
    padded = bytez + b'\x00' * (-len(bytez) % BYTES_PER_CHUNK)
    return [padded[i:i + BYTES_PER_CHUNK] for i in range(0, len(padded), BYTES_PER_CHUNK)]
216  
217  
def pack(values, subtype):
    """Serialize each of ``values`` as ``subtype``, concatenate, and split into chunks."""
    serialized = b''.join(serialize_value(item, subtype) for item in values)
    return chunkify(serialized)
220  
221  
def is_power_of_two(x):
    """Return True iff ``x`` is a positive power of two (1, 2, 4, 8, ...)."""
    if not x > 0:
        return False
    # A power of two has exactly one set bit, which x - 1 clears.
    return (x & (x - 1)) == 0
224  
225  
def merkleize(chunks):
    """Return the Merkle root of ``chunks``.

    The leaf layer is padded with ZERO_CHUNK up to the next power of two; an
    empty input is padded to a single ZERO_CHUNK, which is then returned. The
    tree is a flat list in which node ``i`` has children ``2i`` and ``2i + 1``,
    leaves occupy the upper half, and index 1 is the root.
    """
    leaves = chunks[:]
    while not is_power_of_two(len(leaves)):
        leaves.append(ZERO_CHUNK)
    width = len(leaves)
    nodes = [ZERO_CHUNK] * width + leaves
    # Fill internal nodes bottom-up, from the last parent down to the root.
    for index in reversed(range(1, width)):
        nodes[index] = hash(nodes[2 * index] + nodes[2 * index + 1])
    return nodes[1]
234  
235  
def mix_in_length(root, length):
    """Hash ``root`` concatenated with ``length`` as a 32-byte little-endian integer."""
    length_chunk = length.to_bytes(32, byteorder='little')
    return hash(root + length_chunk)
238  
239  
def infer_type(value):
    """
    Infer an SSZ type descriptor for ``value``.

    Note: defaults to uint64 for integer type inference due to lack of information.
    Other integer sizes are still supported, see spec.
    :param value: The value to infer a SSZ type for.
    :return: The SSZ type.
    :raises Exception: if no type can be inferred.
    """
    if hasattr(value.__class__, 'fields'):
        return value.__class__
    # bool is a subclass of int, so it must be tested before the int fallback;
    # otherwise True/False would be inferred (and serialized) as 8-byte uint64.
    elif isinstance(value, bool):
        return 'bool'
    elif isinstance(value, int):
        return 'uint64'
    elif isinstance(value, (bytes, str)):
        return 'bytes'
    elif isinstance(value, Vector):
        if len(value) > 0:
            return [infer_type(value[0]), len(value)]
        else:
            # Element type does not matter too much,
            # assumed to be a basic type for size-encoding purposes, vector is empty.
            # NOTE(review): this is the *list* descriptor ['uint64'], so
            # hash_tree_root mixes in the length as for a list -- confirm intended.
            return ['uint64']
    elif isinstance(value, list):
        if len(value) > 0:
            return [infer_type(value[0])]
        else:
            # Element type does not matter, list-content size will be encoded regardless, list is empty.
            return ['uint64']
    else:
        raise Exception("Failed to infer type")
268  
269  
def hash_tree_root(value, typ=None):
    """Return the SSZ hash tree root of ``value``.

    :param value: the value to hash.
    :param typ: SSZ type descriptor (basic type string, "bytes"/"bytesN",
        ``[elem]`` list, ``[elem, length]`` vector, or an SSZType container
        class). Inferred with ``infer_type`` when omitted.
    :return: the root produced by ``merkleize``, mixed with the length via
        ``mix_in_length`` for lists and variable-size "bytes".
    :raises AssertionError: if a vector or "bytesN" value has the wrong length.
    :raises Exception: if ``typ`` is not recognized.
    """
    if typ is None:
        typ = infer_type(value)
    # -------------------------------------
    # merkleize(pack(value))
    # basic object: merkleize packed version (merkleization pads it to 32 bytes if it is not already)
    if is_basic(typ):
        return merkleize(pack([value], typ))
    # or a vector of basic objects
    elif isinstance(typ, list) and len(typ) == 2 and is_basic(typ[0]):
        assert len(value) == typ[1]
        return merkleize(pack(value, typ[0]))
    # -------------------------------------
    # mix_in_length(merkleize(pack(value)), len(value))
    # if value is a list of basic objects
    elif isinstance(typ, list) and len(typ) == 1 and is_basic(typ[0]):
        return mix_in_length(merkleize(pack(value, typ[0])), len(value))
    # (needs some extra work for non-fixed-sized bytes array)
    elif typ == 'bytes':
        return mix_in_length(merkleize(chunkify(coerce_to_bytes(value))), len(value))
    # -------------------------------------
    # merkleize([hash_tree_root(element) for element in value])
    # if value is a vector of composite objects
    elif isinstance(typ, list) and len(typ) == 2 and not is_basic(typ[0]):
        # Check the length as the basic-vector branch and serialize_value do;
        # previously a wrong-length composite vector was hashed without complaint.
        assert len(value) == typ[1]
        return merkleize([hash_tree_root(element, typ[0]) for element in value])
    # (needs some extra work for fixed-sized bytes array)
    elif isinstance(typ, str) and typ[:5] == 'bytes' and len(typ) > 5:
        assert len(value) == int(typ[5:])
        return merkleize(chunkify(coerce_to_bytes(value)))
    # or a container: one leaf per field root, in typ.fields order
    elif hasattr(typ, 'fields'):
        return merkleize([hash_tree_root(getattr(value, field), subtype) for field, subtype in typ.fields.items()])
    # -------------------------------------
    # mix_in_length(merkleize([hash_tree_root(element) for element in value]), len(value))
    # if value is a list of composite objects
    elif isinstance(typ, list) and len(typ) == 1 and not is_basic(typ[0]):
        return mix_in_length(merkleize([hash_tree_root(element, typ[0]) for element in value]), len(value))
    # -------------------------------------
    else:
        raise Exception("Type not recognized")
310  
311  
def truncate(container):
    """Return a copy of ``container`` without its last field.

    The result is an instance of a fresh SSZType built from all but the last
    entry of ``container.fields``. Retained field values are shared with
    ``container``, not copied.
    """
    kept_names = list(container.fields)[:-1]
    truncated_type = SSZType({name: container.fields[name] for name in kept_names})
    return truncated_type(**{name: getattr(container, name) for name in kept_names})
324  
325  
def signing_root(container):
    """Return the hash tree root of ``container`` with its last field dropped.

    NOTE(review): presumably the last field is the signature being produced or
    verified -- confirm against the container definitions.
    """
    without_last_field = truncate(container)
    return hash_tree_root(without_last_field)
328  
329  
def serialize(ssz_object):
    """Return ``ssz_object.serialize()`` (the object's own SSZ encoding)."""
    return ssz_object.serialize()