/ contrib / asmap / asmap.py
asmap.py
  1  # Copyright (c) 2022 Pieter Wuille
  2  # Distributed under the MIT software license, see the accompanying
  3  # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
  4  
"""
This module provides the ASMap class, which represents a mapping from IP subnets to AS
numbers and converts it from/to lists of (prefix, ASN) entries and the binary asmap format.
It also defines the ASNEntry and ASNDiff type aliases, the net_to_prefix/prefix_to_net
helpers for converting between IP networks and bit prefixes, and unit tests (TestASMap).
"""
  8  
  9  import copy
 10  import ipaddress
 11  import random
 12  import unittest
 13  from collections.abc import Callable, Iterable
 14  from enum import Enum
 15  from functools import total_ordering
 16  from typing import Optional, Union, overload
 17  
 18  def net_to_prefix(net: Union[ipaddress.IPv4Network,ipaddress.IPv6Network]) -> list[bool]:
 19      """
 20      Convert an IPv4 or IPv6 network to a prefix represented as a list of bits.
 21  
 22      IPv4 ranges are remapped to their IPv4-mapped IPv6 range (::ffff:0:0/96).
 23      """
 24      num_bits = net.prefixlen
 25      netrange = int.from_bytes(net.network_address.packed, 'big')
 26  
 27      # Map an IPv4 prefix into IPv6 space.
 28      if isinstance(net, ipaddress.IPv4Network):
 29          num_bits += 96
 30          netrange += 0xffff00000000
 31  
 32      # Strip unused bottom bits.
 33      assert (netrange & ((1 << (128 - num_bits)) - 1)) == 0
 34      return [((netrange >> (127 - i)) & 1) != 0 for i in range(num_bits)]
 35  
 36  def prefix_to_net(prefix: list[bool]) -> Union[ipaddress.IPv4Network,ipaddress.IPv6Network]:
 37      """The reverse operation of net_to_prefix."""
 38      # Convert to number
 39      netrange = sum(b << (127 - i) for i, b in enumerate(prefix))
 40      num_bits = len(prefix)
 41      assert num_bits <= 128
 42  
 43      # Return IPv4 range if in ::ffff:0:0/96
 44      if num_bits >= 96 and (netrange >> 32) == 0xffff:
 45          return ipaddress.IPv4Network((netrange & 0xffffffff, num_bits - 96), True)
 46  
 47      # Return IPv6 range otherwise.
 48      return ipaddress.IPv6Network((netrange, num_bits), True)
 49  
# Shortcut for (prefix, ASN) entries: a prefix as produced by net_to_prefix (a list of
# bits, most significant first) together with the ASN it maps to.
ASNEntry = tuple[list[bool], int]

# Shortcut for (prefix, old ASN, new ASN) entries, as produced by ASMap.diff; an ASN of 0
# means the range is unassigned on that side.
ASNDiff = tuple[list[bool], int, int]
 55  
 56  class _VarLenCoder:
 57      """
 58      A class representing a custom variable-length binary encoder/decoder for
 59      integers. Each object represents a different coder, with different parameters
 60      minval and clsbits.
 61  
 62      The encoding is easiest to describe using an example. Let's say minval=100 and
 63      clsbits=[4,2,2,3]. In that case:
 64      - x in [100..115]: encoded as [0] + [4-bit BE encoding of (x-100)].
 65      - x in [116..119]: encoded as [1,0] + [2-bit BE encoding of (x-116)].
 66      - x in [120..123]: encoded as [1,1,0] + [2-bit BE encoding of (x-120)].
 67      - x in [124..131]: encoded as [1,1,1] + [3-bit BE encoding of (x-124)].
 68  
 69      In general, every number is encoded as:
 70      - First, k "1"-bits, where k is the class the number falls in (there is one class
 71        per element of clsbits).
 72      - Then, a "0"-bit, unless k is the highest class, in which case there is nothing.
 73      - Lastly, clsbits[k] bits encoding in big endian the position in its class that
 74        number falls into.
 75      - Every class k consists of 2^clsbits[k] consecutive integers. k=0 starts at minval,
 76        other classes start one past the last element of the class before it.
 77      """
 78  
 79      def __init__(self, minval: int, clsbits: list[int]):
 80          """Construct a new _VarLenCoder."""
 81          self._minval = minval
 82          self._clsbits = clsbits
 83          self._maxval = minval + sum(1 << b for b in clsbits) - 1
 84  
 85      def can_encode(self, val: int) -> bool:
 86          """Check whether value val is in the range this coder supports."""
 87          return self._minval <= val <= self._maxval
 88  
 89      def encode(self, val: int, ret: list[int]) -> None:
 90          """Append encoding of val onto integer list ret."""
 91  
 92          assert self._minval <= val <= self._maxval
 93          val -= self._minval
 94          bits = 0
 95          for k, bits in enumerate(self._clsbits):
 96              if val >> bits:
 97                  # If the value will not fit in class k, subtract its range from v,
 98                  # emit a "1" bit and continue with the next class.
 99                  val -= 1 << bits
100                  ret.append(1)
101              else:
102                  if k + 1 < len(self._clsbits):
103                      # Unless we're in the last class, emit a "0" bit.
104                      ret.append(0)
105                  break
106          # And then encode v (now the position within the class) in big endian.
107          ret.extend((val >> (bits - 1 - b)) & 1 for b in range(bits))
108  
109      def encode_size(self, val: int) -> int:
110          """Compute how many bits are needed to encode val."""
111          assert self._minval <= val <= self._maxval
112          val -= self._minval
113          ret = 0
114          bits = 0
115          for k, bits in enumerate(self._clsbits):
116              if val >> bits:
117                  val -= 1 << bits
118                  ret += 1
119              else:
120                  ret += k + 1 < len(self._clsbits)
121                  break
122          return ret + bits
123  
124      def decode(self, stream, bitpos) -> tuple[int,int]:
125          """Decode a number starting at bitpos in stream, returning value and new bitpos."""
126          val = self._minval
127          bits = 0
128          for k, bits in enumerate(self._clsbits):
129              bit = 0
130              if k + 1 < len(self._clsbits):
131                  bit = stream[bitpos]
132                  bitpos += 1
133              if not bit:
134                  break
135              val += 1 << bits
136          for i in range(bits):
137              bit = stream[bitpos]
138              bitpos += 1
139              val += bit << (bits - 1 - i)
140          return val, bitpos
141  
# Variable-length encoders used in the binary asmap format.
# Instruction opcodes 0..3: RETURN=[0], JUMP=[1,0], MATCH=[1,1,0], DEFAULT=[1,1,1].
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
# ASN arguments 1..33521664; 0 (unassigned) cannot be encoded.
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
# MATCH arguments 2..511: a leading marker "1" bit followed by 1 to 8 bits to compare.
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
# JUMP arguments (bit size of the first subprogram) 17..2147483632. 17 is the size of the
# smallest subprogram, a RETURN: 1 opcode bit plus a 16-bit encoding of a small ASN.
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))
147  
class _Instruction(Enum):
    """
    One instruction in the binary asmap format.

    The member value is the opcode, which is written with _CODER_INS; the resulting bit
    patterns are listed with each member below.
    """
    # A return instruction, encoded as [0], returns a constant ASN. It is followed by
    # an integer using the ASN encoding.
    RETURN = 0
    # A jump instruction, encoded as [1,0], inspects the next unused bit in the input
    # and either continues execution (if 0), or skips a specified number of bits (if 1).
    # It is followed by an integer, and then two subprograms. The integer uses jump encoding
    # and corresponds to the length of the first subprogram (so it can be skipped).
    JUMP = 1
    # A match instruction, encoded as [1,1,0], compares 1 or more of the next unused bits
    # in the input with its argument. If they all match, execution continues. If they do
    # not, failure (represented by 0) is returned; if a default instruction has been
    # executed before, that default instruction's argument is returned instead of failure.
    # It is followed by an integer in match encoding, and a subprogram. That value is at
    # least 2 bits and at most 9 bits long. An n-bit value signifies matching (n-1) bits
    # in the input with the lower (n-1) bits of the match value: the top "1" bit only
    # marks the length, and the bit just below it is compared with the first input bit.
    MATCH = 2
    # A default instruction, encoded as [1,1,1], sets the default variable to its argument,
    # and continues execution. It is followed by an integer in ASN encoding, and a subprogram.
    DEFAULT = 3
    # Not an actual instruction, but a way to encode the empty program that fails (it
    # occupies no bits). In the encoder, it is used more generally to represent the failure
    # case inside MATCH instructions, which may (if used inside the context of a DEFAULT
    # instruction) actually correspond to a successful return. In this usage, they're always
    # converted to an actual MATCH or RETURN before the top level is reached (see
    # _BinNode.make_default).
    END = 4
175  
class _BinNode:
    """
    One node of a binary asmap program, either parsed from or about to be written to the
    binary format.

    A node stores its instruction (ins), up to two arguments (arg1, arg2) and size: the
    number of bits the node, including all of its subprograms, occupies when encoded.
    """

    @overload
    def __init__(self, ins: _Instruction): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...

    def __init__(self, ins: _Instruction, arg1=None, arg2=None):
        """
        Build a node and compute its encoded size. The accepted forms are:
        - _BinNode(_Instruction.RETURN, asn)
        - _BinNode(_Instruction.JUMP, node_0, node_1)
        - _BinNode(_Instruction.MATCH, val, node)
        - _BinNode(_Instruction.DEFAULT, asn, node)
        - _BinNode(_Instruction.END)
        """
        self.ins = ins
        self.arg1 = arg1
        self.arg2 = arg2
        if ins == _Instruction.END:
            # The empty program is not encoded at all.
            assert arg1 is None
            assert arg2 is None
            self.size = 0
            return
        if ins == _Instruction.RETURN:
            assert isinstance(arg1, int)
            assert arg2 is None
            operands = _CODER_ASN.encode_size(arg1)
        elif ins == _Instruction.JUMP:
            assert isinstance(arg1, _BinNode)
            assert isinstance(arg2, _BinNode)
            # The jump distance (size of the first branch) precedes both branches.
            operands = _CODER_JUMP.encode_size(arg1.size) + arg1.size + arg2.size
        elif ins == _Instruction.MATCH:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            operands = _CODER_MATCH.encode_size(arg1) + arg2.size
        elif ins == _Instruction.DEFAULT:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            operands = _CODER_ASN.encode_size(arg1) + arg2.size
        else:
            assert False
            # Unknown instruction with asserts disabled: the node is left without a size.
            return
        self.size = _CODER_INS.encode_size(ins.value) + operands

    @staticmethod
    def make_end() -> "_BinNode":
        """Return a node holding just the END (empty, failing) program."""
        return _BinNode(_Instruction.END)

    @staticmethod
    def make_leaf(val: int) -> "_BinNode":
        """Return a node consisting of a single RETURN of the positive ASN val."""
        assert val is not None and val > 0
        return _BinNode(_Instruction.RETURN, val)

    @staticmethod
    def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
        """
        Return the smallest node that runs node0 when the next input bit is 0 and node1
        when it is 1, using an END, MATCH or JUMP instruction as appropriate.
        """
        end = _Instruction.END
        if node0.ins == end and node1.ins == end:
            return node0
        if node0.ins == end or node1.ins == end:
            # Exactly one side fails, so only one bit value can succeed: use a MATCH.
            taken, want_one = (node1, True) if node0.ins == end else (node0, False)
            if taken.ins == _Instruction.MATCH and taken.arg1 <= 0xFF:
                # Extend the existing MATCH by one leading bit. Adding 1 << width turns the
                # old marker bit into a "1" to match; adding 1 << (width - 1) clears it
                # into a "0". Either way, a new marker bit appears on top.
                width = taken.arg1.bit_length()
                grow = 1 << width if want_one else 1 << (width - 1)
                return _BinNode(_Instruction.MATCH, taken.arg1 + grow, taken.arg2)
            # Start a fresh one-bit MATCH: 0b11 matches a 1 bit, 0b10 matches a 0 bit.
            return _BinNode(_Instruction.MATCH, 3 if want_one else 2, taken)
        return _BinNode(_Instruction.JUMP, node0, node1)

    @staticmethod
    def make_default(val: int, sub: "_BinNode") -> "_BinNode":
        """
        Return the smallest node equivalent to running sub with default value val, using
        either a DEFAULT or a RETURN instruction (or sub itself when the default is unused).
        """
        assert val is not None and val > 0
        kind = sub.ins
        if kind == _Instruction.END:
            # Failing immediately under this default simply returns it.
            return _BinNode(_Instruction.RETURN, val)
        if kind == _Instruction.RETURN or kind == _Instruction.DEFAULT:
            # sub never falls back to this default, so it can be dropped.
            return sub
        return _BinNode(_Instruction.DEFAULT, val, sub)
267          return _BinNode(_Instruction.DEFAULT, val, sub)
268  
@total_ordering
class ASMap:
    """
    A class whose objects represent a mapping from subnets to ASNs.

    Internally the mapping is stored as a binary trie, but can be converted
    from/to a list of ASNEntry objects, and from/to the binary asmap file format.

    In the trie representation, nodes are represented as bare lists for efficiency
    and ease of manipulation:
    - [0] means an unassigned subnet (no ASN mapping for it is present)
    - [int] means a subnet mapped entirely to the specified ASN.
    - [node,node] means a subnet whose lower half and upper half have different
      mappings, represented by new trie nodes.

    Prefixes are lists of bits (as produced by net_to_prefix); while walking the trie,
    each successive prefix bit selects child 0 (False) or child 1 (True).
    """

    def update(self, prefix: list[bool], asn: int) -> None:
        """
        Update this ASMap object to map prefix to the specified asn.

        An asn of 0 marks the prefix as unassigned. Everything previously mapped inside
        prefix is replaced, and sibling leaves that end up equal are merged again.
        """
        assert asn == 0 or _CODER_ASN.can_encode(asn)

        def recurse(node: list, offset: int) -> None:
            if offset == len(prefix):
                # Reached the end of prefix; overwrite this node.
                node.clear()
                node.append(asn)
                return
            if len(node) == 1:
                # Need to descend into a leaf node; split it up.
                oldasn = node[0]
                node.clear()
                node.append([oldasn])
                node.append([oldasn])
            # Descend into the node.
            recurse(node[prefix[offset]], offset + 1)
            # If the result is two identical leaf children, merge them.
            if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
                oldasn = node[0][0]
                node.clear()
                node.append(oldasn)
        recurse(self._trie, 0)

    def update_multi(self, entries: Iterable[tuple[list[bool], int]]) -> None:
        """
        Apply multiple update operations, where longer prefixes take precedence.

        Entries are applied in order of increasing prefix length, so a more specific prefix
        is never overwritten by a shorter one; entries of equal length are applied in their
        original order. The caller's entries are left unmodified.
        """
        # sorted() rather than entries.sort(): don't reorder the caller's list as a side effect.
        for prefix, asn in sorted(entries, key=lambda entry: len(entry[0])):
            self.update(prefix, asn)

    def _set_trie(self, trie) -> None:
        """
        Set trie directly. Internal use only.

        The trie is first normalized in place, bottom-up: a node whose two children are
        identical leaves is collapsed into that leaf, and a node whose two children are both
        empty lists becomes empty. The given list object is then stored without copying.
        """
        def recurse(node: list) -> None:
            if len(node) < 2:
                return
            recurse(node[0])
            recurse(node[1])
            if len(node[0]) == 2:
                # An internal left child is never merged into a leaf.
                return
            if node[0] == node[1]:
                if len(node[0]) == 0:
                    node.clear()
                else:
                    asn = node[0][0]
                    node.clear()
                    node.append(asn)
        recurse(trie)
        self._trie = trie

    def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
        """
        Construct an ASMap object from an optional iterable of (prefix, asn) entries.

        Without entries the whole address space is unassigned. Entries are applied shortest
        prefix first, so more specific entries override the ranges that contain them.
        """
        self._trie = [0]
        if entries is not None:
            def entry_key(entry):
                """Sort function that places shorter prefixes first."""
                prefix, asn = entry
                return len(prefix), prefix, asn
            for prefix, asn in sorted(entries, key=entry_key):
                self.update(prefix, asn)

    def lookup(self, prefix: list[bool]) -> Optional[int]:
        """
        Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate.

        "Indeterminate" means prefix ends at an internal trie node, i.e. different parts of
        the range it denotes have different mappings.
        """
        node = self._trie
        for bit in prefix:
            if len(node) == 1:
                break
            node = node[bit]
        if len(node) == 1:
            return node[0]
        return None

    def _to_entries_flat(self, fill: bool = False) -> list[ASNEntry]:
        """
        Convert an ASMap object to a list of non-overlapping (prefix, asn) objects.

        With fill, a subtree whose assigned parts all share a single ASN is emitted as one
        entry covering the whole subtree, including its unassigned parts.
        """
        prefix : list[bool] = []

        def recurse(node: list) -> list[ASNEntry]:
            ret = []
            if len(node) == 1:
                if node[0] > 0:
                    ret = [(list(prefix), node[0])]
            elif len(node) == 2:
                prefix.append(False)
                ret = recurse(node[0])
                prefix[-1] = True
                ret += recurse(node[1])
                prefix.pop()
                if fill and len(ret) > 1:
                    asns = {asn for _, asn in ret}
                    if len(asns) == 1:
                        ret = [(list(prefix), asns.pop())]
            return ret
        return recurse(self._trie)

    def _to_entries_minimal(self, fill: bool = False) -> list[ASNEntry]:
        """
        Convert a trie to a minimal list of ASNEntry objects, exploiting overlap.

        For every trie node, recurse() returns (results, hole):
        - results maps a context to the shortest entry list found for that subtree. An int
          context c means the list is only correct if the surrounding (shorter) entries
          already map this range to c (0: nothing covers it); None means the list is
          correct regardless of surrounding entries.
        - hole is True when the subtree contains unassigned space that must stay unassigned
          (internal nodes never report a hole with fill). No covering entry is added above
          such space.
        Contexts are visited in sorted order so that ties are broken the same way on every
        run (set iteration order may vary between interpreter runs).
        """
        prefix : list[bool] = []

        def recurse(node: list) -> tuple[dict[Optional[int], list[ASNEntry]], bool]:
            if len(node) == 1 and node[0] == 0:
                return {None if fill else 0: []}, True
            if len(node) == 1:
                return {node[0]: [], None: [(list(prefix), node[0])]}, False
            ret: dict[Optional[int], list[ASNEntry]] = {}
            prefix.append(False)
            left, lhole = recurse(node[0])
            prefix[-1] = True
            right, rhole = recurse(node[1])
            prefix.pop()
            hole = not fill and (lhole or rhole)
            def candidate(ctx: Optional[int], res0: Optional[list[ASNEntry]],
                    res1: Optional[list[ASNEntry]]):
                if res0 is not None and res1 is not None:
                    if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
                        ret[ctx] = res0 + res1
            # Same deterministic order as _to_binnode: ints ascending, then None.
            for ctx in sorted(set(left) | set(right), key=lambda x: (x is None, x)):
                candidate(ctx, left.get(ctx), right.get(ctx))
                candidate(ctx, left.get(None), right.get(ctx))
                candidate(ctx, left.get(ctx), right.get(None))
            if not hole:
                # Try covering the whole subtree with one entry for ctx, which lets the
                # children rely on context ctx.
                for ctx in sorted(c for c in ret if c is not None):
                    candidate(None, [(list(prefix), ctx)], ret[ctx])
            if None in ret:
                # A context-specific result is only worth keeping if strictly shorter.
                ret = {ctx:entries for ctx, entries in ret.items()
                       if ctx is None or len(entries) < len(ret[None])}
            if hole:
                ret = {ctx:entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    def __str__(self) -> str:
        """
        Return a string showing this ASMap's internal trie, e.g. "ASMap([[1], [2]])".

        Note that this is a debugging representation: the trie is not valid input for the
        ASMap constructor, which expects (prefix, asn) entries.
        """
        return f"ASMap({self._trie})"

    def to_entries(self, overlapping: bool = True, fill: bool = False) -> list[ASNEntry]:
        """
        Convert the mappings in this ASMap object to a list of ASNEntry objects.

        Arguments:
            overlapping: Permit the subnets in the resulting ASNEntry to overlap.
                         Setting this can result in a shorter list.
            fill:        Permit the resulting ASNEntry objects to cover subnets that
                         are unassigned in this ASMap object. Setting this can
                         result in a shorter list.
        """
        if overlapping:
            return self._to_entries_minimal(fill)
        return self._to_entries_flat(fill)

    @staticmethod
    def from_random(num_leaves: int = 10, max_asn: int = 6,
                    unassigned_prob: float = 0.5) -> "ASMap":
        """
        Construct a random ASMap object, with specified:
         - Number of leaves in its trie (at least 1)
         - Maximum ASN value (at least 1 and encodable, unless unassigned_prob is 1, in
           which case no ASN is ever drawn)
         - Probability for leaf nodes to be unassigned

        The number of leaves in the resulting object may be less than what is
        requested. This method is mostly intended for testing.
        """
        assert num_leaves >= 1
        assert 0.0 <= unassigned_prob <= 1.0
        # ASNs are only drawn when a leaf may be assigned; then max_asn must be a valid
        # ASN (can_encode implies max_asn >= 1).
        assert unassigned_prob == 1 or _CODER_ASN.can_encode(max_asn)
        trie: list = []
        leaves = [trie]
        ret = ASMap()
        for i in range(1, num_leaves):
            # Pick a random leaf, remove it from the list (swap with the last element),
            # and split it into two new leaves.
            idx = random.randrange(i)
            leaf = leaves[idx]
            lastleaf = leaves.pop()
            if idx + 1 < i:
                leaves[idx] = lastleaf
            leaf.append([])
            leaf.append([])
            leaves.append(leaf[0])
            leaves.append(leaf[1])
        for leaf in leaves:
            if random.random() >= unassigned_prob:
                leaf.append(random.randrange(1, max_asn + 1))
            else:
                leaf.append(0)
        #pylint: disable=protected-access
        ret._set_trie(trie)
        return ret

    def _to_binnode(self, fill: bool = False) -> _BinNode:
        """
        Convert a trie to a _BinNode object, searching for the smallest encoding.

        For every trie node, recurse() returns (programs, hole):
        - programs maps a context to the smallest program found for that subtree. An int
          context c means the program is only correct if failure evaluates to c (the
          argument of an enclosing DEFAULT, or 0 if there is none); None means the program
          is correct in any context.
        - hole is True when the subtree contains unassigned space that must stay unassigned
          (internal nodes never report a hole with fill). No DEFAULT is introduced for such
          a subtree.
        """
        def recurse(node: list) -> tuple[dict[Optional[int], _BinNode], bool]:
            if len(node) == 1 and node[0] == 0:
                return {(None if fill else 0): _BinNode.make_end()}, True
            if len(node) == 1:
                return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
            ret: dict[Optional[int], _BinNode] = {}
            left, lhole = recurse(node[0])
            right, rhole = recurse(node[1])
            hole = (lhole or rhole) and not fill

            def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
                if arg1 is not None and arg2 is not None:
                    cand = func(arg1, arg2)
                    if ctx not in ret or cand.size < ret[ctx].size:
                        ret[ctx] = cand

            # Sorted (ints ascending, then None) so ties are broken deterministically.
            union = set(left) | set(right)
            sorted_union = sorted(union, key=lambda x: (x is None, x))
            for ctx in sorted_union:
                candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
            if not hole:
                for ctx in sorted(set(ret) - set([None])):
                    candidate(None, ctx, ret[ctx], _BinNode.make_default)
            if None in ret:
                # A context-specific program is only worth keeping if strictly smaller.
                ret = {ctx:enc for ctx, enc in ret.items()
                       if ctx is None or enc.size < ret[None].size}
            if hole:
                ret = {ctx:enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    @staticmethod
    def _from_binnode(binnode: _BinNode) -> "ASMap":
        """
        Construct an ASMap object from a _BinNode. Internal use only.

        Each MATCH is expanded back into a chain of trie nodes whose non-matching sides
        map to the default in effect (0 when no DEFAULT encloses it).
        """
        def recurse(node: _BinNode, default: int) -> list:
            if node.ins == _Instruction.RETURN:
                return [node.arg1]
            if node.ins == _Instruction.JUMP:
                return [recurse(node.arg1, default), recurse(node.arg2, default)]
            if node.ins == _Instruction.MATCH:
                val = node.arg1
                sub = recurse(node.arg2, default)
                # Consume the match bits from the lowest (deepest) up to the marker bit.
                while val >= 2:
                    bit = val & 1
                    val >>= 1
                    if bit:
                        sub = [[default], sub]
                    else:
                        sub = [sub, [default]]
                return sub
            assert node.ins == _Instruction.DEFAULT
            return recurse(node.arg2, node.arg1)
        ret = ASMap()
        if binnode.ins != _Instruction.END:
            #pylint: disable=protected-access
            ret._set_trie(recurse(binnode, 0))
        return ret

    def to_binary(self, fill: bool = False) -> bytes:
        """
        Convert this ASMap object to binary.

        Argument:
            fill: permit the resulting binary encoder to contain mappers for
                  unassigned subnets in this ASMap object. Doing so may
                  reduce the size of the encoding.
        Returns:
            A bytes object with the encoding of this ASMap object. Bits are packed least
            significant bit first, and the last byte is zero-padded.
        """
        bits: list[int] = []

        def recurse(node: _BinNode) -> None:
            _CODER_INS.encode(node.ins.value, bits)
            if node.ins == _Instruction.RETURN:
                _CODER_ASN.encode(node.arg1, bits)
            elif node.ins == _Instruction.JUMP:
                _CODER_JUMP.encode(node.arg1.size, bits)
                recurse(node.arg1)
                recurse(node.arg2)
            elif node.ins == _Instruction.DEFAULT:
                _CODER_ASN.encode(node.arg1, bits)
                recurse(node.arg2)
            else:
                assert node.ins == _Instruction.MATCH
                _CODER_MATCH.encode(node.arg1, bits)
                recurse(node.arg2)

        binnode = self._to_binnode(fill)
        if binnode.ins != _Instruction.END:
            recurse(binnode)

        return bytes(sum(bit << i for i, bit in enumerate(bits[pos:pos + 8]))
                     for pos in range(0, len(bits), 8))

    @staticmethod
    def from_binary(bindata: bytes) -> Optional["ASMap"]:
        """
        Decode an ASMap object from the provided binary encoding.

        Returns None if bindata is not a valid encoding: an inconsistent jump, truncated
        data, more than 7 bits of trailing padding, nonzero padding bits, or pathologically
        deep nesting.
        """

        bits: list[int] = []
        for byte in bindata:
            bits.extend((byte >> i) & 1 for i in range(8))

        def recurse(bitpos: int) -> tuple[_BinNode, int]:
            insval, bitpos = _CODER_INS.decode(bits, bitpos)
            ins = _Instruction(insval)
            if ins == _Instruction.RETURN:
                asn, bitpos = _CODER_ASN.decode(bits, bitpos)
                return _BinNode(ins, asn), bitpos
            if ins == _Instruction.JUMP:
                jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
                left, bitpos1 = recurse(bitpos)
                if bitpos1 != bitpos + jump:
                    raise ValueError("Inconsistent jump")
                right, bitpos = recurse(bitpos1)
                return _BinNode(ins, left, right), bitpos
            if ins == _Instruction.MATCH:
                match, bitpos = _CODER_MATCH.decode(bits, bitpos)
                sub, bitpos = recurse(bitpos)
                return _BinNode(ins, match, sub), bitpos
            assert ins == _Instruction.DEFAULT
            asn, bitpos = _CODER_ASN.decode(bits, bitpos)
            sub, bitpos = recurse(bitpos)
            return _BinNode(ins, asn, sub), bitpos

        if len(bits) == 0:
            binnode = _BinNode(_Instruction.END)
        else:
            try:
                binnode, bitpos = recurse(0)
            except (ValueError, IndexError, RecursionError):
                # ValueError: inconsistent jump; IndexError: ran past the end of the data;
                # RecursionError: absurdly deep nesting, treated as malformed input.
                return None
            # At most one partial byte of padding is allowed, and it must be all zero.
            if bitpos < len(bits) - 7:
                return None
            if not all(bit == 0 for bit in bits[bitpos:]):
                return None

        return ASMap._from_binnode(binnode)

    def __lt__(self, other: object) -> bool:
        """
        Order ASMap objects by their tries; total_ordering derives the other comparisons.

        Leaves are ordered by ASN and before internal nodes; internal nodes are ordered by
        their lower half, then their upper half. Comparing the raw nested lists instead
        would raise TypeError whenever a leaf (int) met an internal node (list).
        """
        if not isinstance(other, ASMap):
            return NotImplemented

        def key(node: list) -> tuple:
            if len(node) < 2:
                # A leaf [asn] (or an empty node, which _set_trie can produce).
                return (0, *node)
            return (1, key(node[0]), key(node[1]))
        #pylint: disable=protected-access
        return key(self._trie) < key(other._trie)

    def __eq__(self, other: object) -> bool:
        """Two ASMap objects are equal when their (normalized) tries are equal."""
        if not isinstance(other, ASMap):
            return NotImplemented
        #pylint: disable=protected-access
        return self._trie == other._trie

    def extends(self, req: "ASMap") -> bool:
        """
        Determine whether this matches req for all subranges where req is assigned.

        Ranges that are unassigned in req impose no constraint; every range req maps to an
        ASN must be mapped to that same ASN here.
        """
        def recurse(actual: list, require: list) -> bool:
            if len(require) == 1 and require[0] == 0:
                return True
            if len(require) == 1:
                if len(actual) == 1:
                    return bool(require[0] == actual[0])
                return recurse(actual[0], require) and recurse(actual[1], require)
            if len(actual) == 2:
                return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
            return recurse(actual, require[0]) and recurse(actual, require[1])
        assert isinstance(req, ASMap)
        #pylint: disable=protected-access
        return recurse(self._trie, req._trie)

    def diff(self, other: "ASMap") -> list[ASNDiff]:
        """
        Compute the diff from self to other.

        Returns (prefix, old ASN, new ASN) tuples for every range whose mapping differs,
        where 0 means unassigned; ranges are split as finely as either trie requires.
        """
        prefix: list[bool] = []
        ret: list[ASNDiff] = []

        def recurse(old_node: list, new_node: list):
            if len(old_node) == 1 and len(new_node) == 1:
                if old_node[0] != new_node[0]:
                    ret.append((list(prefix), old_node[0], new_node[0]))
            else:
                # A leaf on one side applies to both halves of the other side's split.
                old_left: list = old_node if len(old_node) == 1 else old_node[0]
                old_right: list = old_node if len(old_node) == 1 else old_node[1]
                new_left: list = new_node if len(new_node) == 1 else new_node[0]
                new_right: list = new_node if len(new_node) == 1 else new_node[1]
                prefix.append(False)
                recurse(old_left, new_left)
                prefix[-1] = True
                recurse(old_right, new_right)
                prefix.pop()
        assert isinstance(other, ASMap)
        #pylint: disable=protected-access
        recurse(self._trie, other._trie)
        return ret

    def __copy__(self) -> "ASMap":
        """Construct a copy of this ASMap object. Its state will not be shared."""
        ret = ASMap()
        #pylint: disable=protected-access
        ret._set_trie(copy.deepcopy(self._trie))
        return ret

    def __deepcopy__(self, _) -> "ASMap":
        # ASMap objects do not allow sharing of the _trie member, so we don't need the memoization.
        return self.__copy__()
685          return self.__copy__()
686  
687  
class TestASMap(unittest.TestCase):
    """Unit tests for this module."""

    def test_ipv6_prefix_roundtrips(self) -> None:
        """Test that random IPv6 network ranges roundtrip through prefix encoding."""
        for _ in range(20):
            net_bits = random.getrandbits(128)
            for prefix_len in range(0, 129):
                masked_bits = (net_bits >> (128 - prefix_len)) << (128 - prefix_len)
                net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                # An IPv6 network encodes to exactly one bit per prefix bit.
                self.assertEqual(len(prefix), prefix_len)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_ipv4_prefix_roundtrips(self) -> None:
        """Test that random IPv4 network ranges roundtrip through prefix encoding."""
        for _ in range(100):
            net_bits = random.getrandbits(32)
            for prefix_len in range(0, 33):
                masked_bits = (net_bits >> (32 - prefix_len)) << (32 - prefix_len)
                net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                # IPv4 networks are remapped into ::ffff:0:0/96, so the encoded
                # prefix is always the 96 mapping bits followed by the IPv4 prefix.
                self.assertEqual(len(prefix), 96 + prefix_len)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_asmap_roundtrips(self) -> None:
        """Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 24):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Run tests for to_entries and construction from those entries, both
                    # for overlapping and non-overlapping ones.
                    for overlapping in [False, True]:
                        entries = asmap.to_entries(overlapping=overlapping, fill=False)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        # Plain asserts (not assertIsNotNone) also narrow the type for mypy.
                        assert asmap2 is not None
                        self.assertEqual(asmap2, asmap)
                        entries = asmap.to_entries(overlapping=overlapping, fill=True)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertTrue(asmap2.extends(asmap))

                    # Run tests for to_binary and construction from binary.
                    enc = asmap.to_binary(fill=False)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertEqual(asmap3, asmap)
                    enc = asmap.to_binary(fill=True)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertTrue(asmap3.extends(asmap))

    def test_patching(self) -> None:
        """Test behavior of update, lookup, extends, and diff."""
        #pylint: disable=too-many-locals,too-many-nested-blocks
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 10):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Make a copy of that asmap object to which patches will be applied.
                    # It starts off being equal to asmap.
                    patched = copy.copy(asmap)
                    # Keep a list of patches performed, most recent first.
                    patches: list[ASNEntry] = []
                    # Initially there cannot be any difference.
                    self.assertEqual(asmap.diff(patched), [])
                    # Make 5 patches, each building on top of the previous ones.
                    for _ in range(0, 5):
                        # Construct a random path and new ASN to assign it to, apply it to patched,
                        # and remember it in patches.
                        pathlen = random.randrange(5)
                        path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
                        newasn = random.randrange(1 + (1 << asnbits))
                        patched.update(path, newasn)
                        patches = [(path, newasn)] + patches

                        # Compute the diff, and whether asmap extends patched, and the other way
                        # around.
                        diff = asmap.diff(patched)
                        self.assertEqual(asmap == patched, len(diff) == 0)
                        extends = asmap.extends(patched)
                        back_extends = patched.extends(asmap)
                        # Determine whether those extends results are consistent with the diff
                        # result.
                        self.assertEqual(extends, all(d[2] == 0 for d in diff))
                        self.assertEqual(back_extends, all(d[1] == 0 for d in diff))
                        # For every diff found (use a distinct name so the patch `path`
                        # above is not silently rebound):
                        for diff_path, old_asn, new_asn in diff:
                            # Verify asmap and patched actually differ there.
                            self.assertNotEqual(old_asn, new_asn)
                            self.assertEqual(asmap.lookup(diff_path), old_asn)
                            self.assertEqual(patched.lookup(diff_path), new_asn)
                            for _ in range(2):
                                # Extend the path far enough that it's smaller than any mapped
                                # range, and check the lookup holds there too.
                                spec_path = list(diff_path)
                                while len(spec_path) < 32:
                                    spec_path.append(random.getrandbits(1) != 0)
                                self.assertEqual(asmap.lookup(spec_path), old_asn)
                                self.assertEqual(patched.lookup(spec_path), new_asn)
                                # patches is in reverse order, so the first patch whose path
                                # is a prefix of spec_path is the last one applied to it.
                                last_patch_asn = next(
                                    (patch_asn for patch_path, patch_asn in patches
                                     if spec_path[:len(patch_path)] == patch_path),
                                    None)
                                # Such a patch must exist, and patched must hold its ASN.
                                self.assertIsNotNone(last_patch_asn)
                                self.assertEqual(new_asn, last_patch_asn)
if __name__ == '__main__':
    # Executing this file directly runs the unittest test cases in this module.
    unittest.main()