# asmap.py
# Copyright (c) 2022 Pieter Wuille
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.

"""
This module provides the ASNEntry type and the ASMap class.
"""

import copy
import ipaddress
import random
import unittest
from collections.abc import Callable, Iterable
from enum import Enum
from functools import total_ordering
from typing import Optional, Union, overload


def net_to_prefix(net: Union[ipaddress.IPv4Network, ipaddress.IPv6Network]) -> list[bool]:
    """
    Convert an IPv4 or IPv6 network to a prefix represented as a list of bits.

    IPv4 ranges are remapped to their IPv4-mapped IPv6 range (::ffff:0:0/96).
    The result has one entry per prefix bit, most significant bit first, so its
    length is the network's prefix length (plus 96 for IPv4 networks).
    """
    num_bits = net.prefixlen
    netrange = int.from_bytes(net.network_address.packed, 'big')

    # Map an IPv4 prefix into IPv6 space.
    if isinstance(net, ipaddress.IPv4Network):
        num_bits += 96
        netrange += 0xffff00000000

    # The bits below the prefix length (the host bits) must all be zero; only the
    # top num_bits bits are emitted.
    assert (netrange & ((1 << (128 - num_bits)) - 1)) == 0
    return [((netrange >> (127 - i)) & 1) != 0 for i in range(num_bits)]


def prefix_to_net(prefix: list[bool]) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]:
    """
    The reverse operation of net_to_prefix.

    Prefixes inside ::ffff:0:0/96 are returned as IPv4 networks; all others as IPv6.
    """
    # Convert to a 128-bit number, first bit most significant.
    netrange = sum(b << (127 - i) for i, b in enumerate(prefix))
    num_bits = len(prefix)
    assert num_bits <= 128

    # Return IPv4 range if in ::ffff:0:0/96
    if num_bits >= 96 and (netrange >> 32) == 0xffff:
        return ipaddress.IPv4Network((netrange & 0xffffffff, num_bits - 96), True)

    # Return IPv6 range otherwise.
    return ipaddress.IPv6Network((netrange, num_bits), True)


# Shortcut for (prefix, ASN) entries.
ASNEntry = tuple[list[bool], int]

# Shortcut for (prefix, old ASN, new ASN) entries.
ASNDiff = tuple[list[bool], int, int]


class _VarLenCoder:
    """
    A class representing a custom variable-length binary encoder/decoder for
    integers. Each object represents a different coder, with different parameters
    minval and clsbits.

    The encoding is easiest to describe using an example. Let's say minval=100 and
    clsbits=[4,2,2,3]. In that case:
    - x in [100..115]: encoded as [0] + [4-bit BE encoding of (x-100)].
    - x in [116..119]: encoded as [1,0] + [2-bit BE encoding of (x-116)].
    - x in [120..123]: encoded as [1,1,0] + [2-bit BE encoding of (x-120)].
    - x in [124..131]: encoded as [1,1,1] + [3-bit BE encoding of (x-124)].

    In general, every number is encoded as:
    - First, k "1"-bits, where k is the class the number falls in (there is one class
      per element of clsbits).
    - Then, a "0"-bit, unless k is the highest class, in which case there is nothing.
    - Lastly, clsbits[k] bits encoding in big endian the position in its class that
      number falls into.
    - Every class k consists of 2^clsbits[k] consecutive integers. k=0 starts at minval,
      other classes start one past the last element of the class before it.
    """

    def __init__(self, minval: int, clsbits: list[int]):
        """Construct a new _VarLenCoder."""
        self._minval = minval
        self._clsbits = clsbits
        self._maxval = minval + sum(1 << b for b in clsbits) - 1

    def can_encode(self, val: int) -> bool:
        """Check whether value val is in the range this coder supports."""
        return self._minval <= val <= self._maxval

    def encode(self, val: int, ret: list[int]) -> None:
        """Append encoding of val onto integer list ret."""
        assert self._minval <= val <= self._maxval
        val -= self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                # If the value will not fit in class k, subtract its range from val,
                # emit a "1" bit and continue with the next class.
                val -= 1 << bits
                ret.append(1)
            else:
                if k + 1 < len(self._clsbits):
                    # Unless we're in the last class, emit a "0" bit.
                    ret.append(0)
                break
        # And then encode val (now the position within the class) in big endian.
        ret.extend((val >> (bits - 1 - b)) & 1 for b in range(bits))

    def encode_size(self, val: int) -> int:
        """Compute how many bits are needed to encode val."""
        assert self._minval <= val <= self._maxval
        val -= self._minval
        ret = 0
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                val -= 1 << bits
                ret += 1
            else:
                # The terminating "0" bit is only present below the last class.
                ret += k + 1 < len(self._clsbits)
                break
        return ret + bits

    def decode(self, stream: list[int], bitpos: int) -> tuple[int, int]:
        """
        Decode a number starting at bitpos in stream, returning value and new bitpos.

        Raises IndexError if stream ends before the number is complete.
        """
        val = self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            bit = 0
            if k + 1 < len(self._clsbits):
                bit = stream[bitpos]
                bitpos += 1
            if not bit:
                break
            val += 1 << bits
        for i in range(bits):
            bit = stream[bitpos]
            bitpos += 1
            val += bit << (bits - 1 - i)
        return val, bitpos


# Variable-length encoders used in the binary asmap format.
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))


class _Instruction(Enum):
    """One instruction in the binary asmap format."""
    # A return instruction, encoded as [0], returns a constant ASN. It is followed by
    # an integer using the ASN encoding.
    RETURN = 0
    # A jump instruction, encoded as [1,0], inspects the next unused bit in the input
    # and either continues execution (if 0), or skips a specified number of bits (if 1).
    # It is followed by an integer, and then two subprograms. The integer uses jump
    # encoding and corresponds to the length of the first subprogram (so it can be
    # skipped).
    JUMP = 1
    # A match instruction, encoded as [1,1,0], inspects 1 or more of the next unused bits
    # in the input with its argument. If they all match, execution continues. If they do
    # not, failure (represented by 0) is returned. If a default instruction has been
    # executed before, instead of failure the default instruction's argument is returned.
    # It is followed by an integer in match encoding, and a subprogram. That value is at
    # least 2 bits and at most 9 bits. An n-bit value signifies matching (n-1) bits in the
    # input with the lower (n-1) bits in the match value, most significant bit first.
    MATCH = 2
    # A default instruction, encoded as [1,1,1], sets the default variable to its
    # argument, and continues execution. It is followed by an integer in ASN encoding,
    # and a subprogram.
    DEFAULT = 3
    # Not an actual instruction, but a way to encode the empty program that fails. In the
    # encoder, it is used more generally to represent the failure case inside MATCH
    # instructions, which may (if used inside the context of a DEFAULT instruction)
    # actually correspond to a successful return. In this usage, they're always converted
    # to an actual MATCH or RETURN before the top level is reached (see make_default
    # below).
    END = 4


class _BinNode:
    """A class representing a (node of) the parsed binary asmap format."""

    @overload
    def __init__(self, ins: _Instruction): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...

    def __init__(self, ins: _Instruction, arg1=None, arg2=None):
        """
        Construct a new asmap node. Possibilities are:
        - _BinNode(_Instruction.RETURN, asn)
        - _BinNode(_Instruction.JUMP, node_0, node_1)
        - _BinNode(_Instruction.MATCH, val, node)
        - _BinNode(_Instruction.DEFAULT, asn, node)
        - _BinNode(_Instruction.END)

        The size attribute is the encoded length in bits of this node including
        all of its subprograms (0 for END).
        """
        self.ins = ins
        self.arg1 = arg1
        self.arg2 = arg2
        if ins == _Instruction.RETURN:
            assert isinstance(arg1, int)
            assert arg2 is None
            self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1)
        elif ins == _Instruction.JUMP:
            assert isinstance(arg1, _BinNode)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_JUMP.encode_size(arg1.size)
                         + arg1.size + arg2.size)
        elif ins == _Instruction.DEFAULT:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1)
                         + arg2.size)
        elif ins == _Instruction.MATCH:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_MATCH.encode_size(arg1)
                         + arg2.size)
        elif ins == _Instruction.END:
            assert arg1 is None
            assert arg2 is None
            self.size = 0
        else:
            assert False

    @staticmethod
    def make_end() -> "_BinNode":
        """Constructor for a _BinNode with just an END instruction."""
        return _BinNode(_Instruction.END)

    @staticmethod
    def make_leaf(val: int) -> "_BinNode":
        """Constructor for a _BinNode of just a RETURN instruction."""
        assert val is not None and val > 0
        return _BinNode(_Instruction.RETURN, val)

    @staticmethod
    def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode corresponding to running either the node0 or node1 subprogram,
        based on the next input bit. It exploits shortcuts that are possible in the encoding,
        and uses either a JUMP, MATCH, or END instruction.
        """
        if node0.ins == _Instruction.END and node1.ins == _Instruction.END:
            return node0
        if node0.ins == _Instruction.END:
            # Only a 1 bit can succeed. Prepend a "1" to an existing MATCH when its
            # value still has room (<= 8 bits), else start a new 1-bit MATCH (0b11).
            if node1.ins == _Instruction.MATCH and node1.arg1 <= 0xFF:
                return _BinNode(node1.ins, node1.arg1 + (1 << node1.arg1.bit_length()),
                                node1.arg2)
            return _BinNode(_Instruction.MATCH, 3, node1)
        if node1.ins == _Instruction.END:
            # Only a 0 bit can succeed: same as above, but prepend a "0" (0b10).
            if node0.ins == _Instruction.MATCH and node0.arg1 <= 0xFF:
                return _BinNode(node0.ins, node0.arg1 + (1 << (node0.arg1.bit_length() - 1)),
                                node0.arg2)
            return _BinNode(_Instruction.MATCH, 2, node0)
        return _BinNode(_Instruction.JUMP, node0, node1)

    @staticmethod
    def make_default(val: int, sub: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode that corresponds to the specified subprogram, with the specified
        default value. It exploits shortcuts that are possible in the encoding, and will use
        either a DEFAULT or a RETURN instruction."""
        assert val is not None and val > 0
        if sub.ins == _Instruction.END:
            return _BinNode(_Instruction.RETURN, val)
        if sub.ins in (_Instruction.RETURN, _Instruction.DEFAULT):
            # These never consult the default, so setting it would be wasted bits.
            return sub
        return _BinNode(_Instruction.DEFAULT, val, sub)


@total_ordering
class ASMap:
    """
    A class whose objects represent a mapping from subnets to ASNs.

    Internally the mapping is stored as a binary trie, but can be converted
    from/to a list of ASNEntry objects, and from/to the binary asmap file format.

    In the trie representation, nodes are represented as bare lists for efficiency
    and ease of manipulation:
    - [0] means an unassigned subnet (no ASN mapping for it is present)
    - [int] means a subnet mapped entirely to the specified ASN.
    - [node,node] means a subnet whose lower half and upper half have different
      mappings, represented by new trie nodes.
    """

    def update(self, prefix: list[bool], asn: int) -> None:
        """Update this ASMap object to map prefix to the specified asn (0 = unassigned)."""
        assert asn == 0 or _CODER_ASN.can_encode(asn)

        def recurse(node: list, offset: int) -> None:
            if offset == len(prefix):
                # Reached the end of prefix; overwrite this node.
                node.clear()
                node.append(asn)
                return
            if len(node) == 1:
                # Need to descend into a leaf node; split it up.
                oldasn = node[0]
                node.clear()
                node.append([oldasn])
                node.append([oldasn])
            # Descend into the node.
            recurse(node[prefix[offset]], offset + 1)
            # If the result is two identical leaf children, merge them.
            if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
                oldasn = node[0][0]
                node.clear()
                node.append(oldasn)
        recurse(self._trie, 0)

    def update_multi(self, entries: list[tuple[list[bool], int]]) -> None:
        """
        Apply multiple update operations, where longer prefixes take precedence.

        Entries with equal prefix length are applied in their given order. The
        caller's list is not modified.
        """
        for prefix, asn in sorted(entries, key=lambda entry: len(entry[0])):
            self.update(prefix, asn)

    def _set_trie(self, trie) -> None:
        """
        Set trie directly. Internal use only.

        Sibling leaves with identical values are merged bottom-up first, so no
        internal node is left with two identical leaf children.
        """
        def recurse(node: list) -> None:
            if len(node) < 2:
                return
            recurse(node[0])
            recurse(node[1])
            if len(node[0]) == 2:
                return
            if node[0] == node[1]:
                if len(node[0]) == 0:
                    node.clear()
                else:
                    asn = node[0][0]
                    node.clear()
                    node.append(asn)
        recurse(trie)
        self._trie = trie

    def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
        """Construct an ASMap object from an optional list of entries."""
        self._trie = [0]
        if entries is not None:
            def entry_key(entry):
                """Sort function that places shorter prefixes first."""
                prefix, asn = entry
                return len(prefix), prefix, asn
            for prefix, asn in sorted(entries, key=entry_key):
                self.update(prefix, asn)

    def lookup(self, prefix: list[bool]) -> Optional[int]:
        """Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate."""
        node = self._trie
        for bit in prefix:
            if len(node) == 1:
                break
            node = node[bit]
        if len(node) == 1:
            return node[0]
        # The prefix ended above a split: different parts of it map differently.
        return None

    def _to_entries_flat(self, fill: bool = False) -> list[ASNEntry]:
        """Convert an ASMap object to a list of non-overlapping (prefix, asn) objects."""
        prefix: list[bool] = []

        def recurse(node: list) -> list[ASNEntry]:
            ret = []
            if len(node) == 1:
                if node[0] > 0:
                    ret = [(list(prefix), node[0])]
            elif len(node) == 2:
                prefix.append(False)
                ret = recurse(node[0])
                prefix[-1] = True
                ret += recurse(node[1])
                prefix.pop()
                if fill and len(ret) > 1:
                    # With fill, a subtree whose assigned parts all share one ASN
                    # collapses into a single entry covering the whole subtree.
                    asns = set(x[1] for x in ret)
                    if len(asns) == 1:
                        ret = [(list(prefix), list(asns)[0])]
            return ret
        return recurse(self._trie)

    def _to_entries_minimal(self, fill: bool = False) -> list[ASNEntry]:
        """Convert a trie to a minimal list of ASNEntry objects, exploiting overlap."""
        prefix: list[bool] = []

        def recurse(node: list) -> tuple[dict[Optional[int], list[ASNEntry]], bool]:
            # Returns (ctx -> entries, hole). ctx is the ASN assumed to already cover
            # this subtree from an enclosing entry, or None when nothing covers it.
            if len(node) == 1 and node[0] == 0:
                return {None if fill else 0: []}, True
            if len(node) == 1:
                return {node[0]: [], None: [(list(prefix), node[0])]}, False
            ret: dict[Optional[int], list[ASNEntry]] = {}
            prefix.append(False)
            left, lhole = recurse(node[0])
            prefix[-1] = True
            right, rhole = recurse(node[1])
            prefix.pop()
            hole = not fill and (lhole or rhole)

            def candidate(ctx: Optional[int], res0: Optional[list[ASNEntry]],
                          res1: Optional[list[ASNEntry]]):
                if res0 is not None and res1 is not None:
                    if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
                        ret[ctx] = res0 + res1

            # Iterate contexts in a fixed order (as _to_binnode does) so ties are
            # broken deterministically.
            for ctx in sorted(set(left) | set(right), key=lambda x: (x is None, x)):
                candidate(ctx, left.get(ctx), right.get(ctx))
                candidate(ctx, left.get(None), right.get(ctx))
                candidate(ctx, left.get(ctx), right.get(None))
            if not hole:
                for ctx in sorted(set(ret) - {None}):
                    candidate(None, [(list(prefix), ctx)], ret[ctx])
            if None in ret:
                ret = {ctx: entries for ctx, entries in ret.items()
                       if ctx is None or len(entries) < len(ret[None])}
            if hole:
                ret = {ctx: entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    def __str__(self) -> str:
        """Convert this ASMap object to a string containing Python code constructing it."""
        return f"ASMap({self._trie})"

    def to_entries(self, overlapping: bool = True, fill: bool = False) -> list[ASNEntry]:
        """
        Convert the mappings in this ASMap object to a list of ASNEntry objects.

        Arguments:
            overlapping: Permit the subnets in the resulting ASNEntry to overlap.
                         Setting this can result in a shorter list.
            fill: Permit the resulting ASNEntry objects to cover subnets that
                  are unassigned in this ASMap object. Setting this can
                  result in a shorter list.
        """
        if overlapping:
            return self._to_entries_minimal(fill)
        return self._to_entries_flat(fill)

    @staticmethod
    def from_random(num_leaves: int = 10, max_asn: int = 6,
                    unassigned_prob: float = 0.5) -> "ASMap":
        """
        Construct a random ASMap object, with specified:
        - Number of leaves in its trie (at least 1)
        - Maximum ASN value (at least 1)
        - Probability for leaf nodes to be unassigned

        The number of leaves in the resulting object may be less than what is
        requested. This method is mostly intended for testing.
        """
        assert num_leaves >= 1
        assert max_asn >= 1 or unassigned_prob == 1
        assert _CODER_ASN.can_encode(max_asn)
        assert 0.0 <= unassigned_prob <= 1.0
        trie: list = []
        leaves = [trie]
        ret = ASMap()
        for i in range(1, num_leaves):
            # Pick a random leaf, remove it from the list (swap with the last one),
            # and split it into two new leaves.
            idx = random.randrange(i)
            leaf = leaves[idx]
            lastleaf = leaves.pop()
            if idx + 1 < i:
                leaves[idx] = lastleaf
            leaf.append([])
            leaf.append([])
            leaves.append(leaf[0])
            leaves.append(leaf[1])
        for leaf in leaves:
            if random.random() >= unassigned_prob:
                leaf.append(random.randrange(1, max_asn + 1))
            else:
                leaf.append(0)
        #pylint: disable=protected-access
        ret._set_trie(trie)
        return ret

    def _to_binnode(self, fill: bool = False) -> _BinNode:
        """Convert a trie to a _BinNode object."""
        def recurse(node: list) -> tuple[dict[Optional[int], _BinNode], bool]:
            # Returns (ctx -> program, hole). ctx is the default ASN assumed to be in
            # effect when the program runs, or None when no default is set.
            if len(node) == 1 and node[0] == 0:
                return {(None if fill else 0): _BinNode.make_end()}, True
            if len(node) == 1:
                return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
            ret: dict[Optional[int], _BinNode] = {}
            left, lhole = recurse(node[0])
            right, rhole = recurse(node[1])
            hole = (lhole or rhole) and not fill

            def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
                if arg1 is not None and arg2 is not None:
                    cand = func(arg1, arg2)
                    if ctx not in ret or cand.size < ret[ctx].size:
                        ret[ctx] = cand

            union = set(left) | set(right)
            sorted_union = sorted(union, key=lambda x: (x is None, x))
            for ctx in sorted_union:
                candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
            if not hole:
                for ctx in sorted(set(ret) - {None}):
                    candidate(None, ctx, ret[ctx], _BinNode.make_default)
            if None in ret:
                ret = {ctx: enc for ctx, enc in ret.items()
                       if ctx is None or enc.size < ret[None].size}
            if hole:
                ret = {ctx: enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    @staticmethod
    def _from_binnode(binnode: _BinNode) -> "ASMap":
        """Construct an ASMap object from a _BinNode. Internal use only."""
        def recurse(node: _BinNode, default: int) -> list:
            if node.ins == _Instruction.RETURN:
                return [node.arg1]
            if node.ins == _Instruction.JUMP:
                return [recurse(node.arg1, default), recurse(node.arg2, default)]
            if node.ins == _Instruction.MATCH:
                val = node.arg1
                sub = recurse(node.arg2, default)
                # Wrap from the last-matched bit (LSB) outwards; a mismatching bit
                # leads to the current default.
                while val >= 2:
                    bit = val & 1
                    val >>= 1
                    if bit:
                        sub = [[default], sub]
                    else:
                        sub = [sub, [default]]
                return sub
            assert node.ins == _Instruction.DEFAULT
            return recurse(node.arg2, node.arg1)
        ret = ASMap()
        if binnode.ins != _Instruction.END:
            #pylint: disable=protected-access
            ret._set_trie(recurse(binnode, 0))
        return ret

    def to_binary(self, fill: bool = False) -> bytes:
        """
        Convert this ASMap object to binary.

        Argument:
            fill: permit the resulting binary encoding to contain mappings for
                  unassigned subnets in this ASMap object. Doing so may
                  reduce the size of the encoding.
        Returns:
            A bytes object with the encoding of this ASMap object.
        """
        bits: list[int] = []

        def recurse(node: _BinNode) -> None:
            _CODER_INS.encode(node.ins.value, bits)
            if node.ins == _Instruction.RETURN:
                _CODER_ASN.encode(node.arg1, bits)
            elif node.ins == _Instruction.JUMP:
                _CODER_JUMP.encode(node.arg1.size, bits)
                recurse(node.arg1)
                recurse(node.arg2)
            elif node.ins == _Instruction.DEFAULT:
                _CODER_ASN.encode(node.arg1, bits)
                recurse(node.arg2)
            else:
                assert node.ins == _Instruction.MATCH
                _CODER_MATCH.encode(node.arg1, bits)
                recurse(node.arg2)

        binnode = self._to_binnode(fill)
        if binnode.ins != _Instruction.END:
            recurse(binnode)

        # Pack bits into bytes, least significant bit first; the last byte is
        # zero-padded.
        return bytes(
            sum(bit << i for i, bit in enumerate(bits[pos:pos + 8]))
            for pos in range(0, len(bits), 8)
        )

    @staticmethod
    def from_binary(bindata: bytes) -> Optional["ASMap"]:
        """
        Decode an ASMap object from the provided binary encoding.

        Returns None if bindata is not a valid encoding: a truncated or inconsistent
        program, more than 7 bits of trailing padding, nonzero padding bits, or
        nesting too deep to process.
        """
        bits: list[int] = []
        for byte in bindata:
            bits.extend((byte >> i) & 1 for i in range(8))

        def recurse(bitpos: int) -> tuple[_BinNode, int]:
            insval, bitpos = _CODER_INS.decode(bits, bitpos)
            ins = _Instruction(insval)
            if ins == _Instruction.RETURN:
                asn, bitpos = _CODER_ASN.decode(bits, bitpos)
                return _BinNode(ins, asn), bitpos
            if ins == _Instruction.JUMP:
                jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
                left, bitpos1 = recurse(bitpos)
                if bitpos1 != bitpos + jump:
                    raise ValueError("Inconsistent jump")
                right, bitpos = recurse(bitpos1)
                return _BinNode(ins, left, right), bitpos
            if ins == _Instruction.MATCH:
                match, bitpos = _CODER_MATCH.decode(bits, bitpos)
                sub, bitpos = recurse(bitpos)
                return _BinNode(ins, match, sub), bitpos
            assert ins == _Instruction.DEFAULT
            asn, bitpos = _CODER_ASN.decode(bits, bitpos)
            sub, bitpos = recurse(bitpos)
            return _BinNode(ins, asn, sub), bitpos

        if not bits:
            binnode = _BinNode(_Instruction.END)
        else:
            try:
                binnode, bitpos = recurse(0)
            except (ValueError, IndexError, RecursionError):
                # Inconsistent jump, truncated input, or crafted pathological nesting.
                return None
            # At most 7 bits of padding may follow the program...
            if bitpos < len(bits) - 7:
                return None
            # ...and they must all be zero.
            if any(bits[bitpos:]):
                return None

        try:
            return ASMap._from_binnode(binnode)
        except RecursionError:
            # MATCH expansion can make the trie far deeper than the program.
            return None

    @staticmethod
    def _trie_key(node: list) -> tuple:
        """Map a trie node to a tuple that orders consistently across node shapes."""
        if len(node) == 1:
            return (0, node[0])
        return (1, ASMap._trie_key(node[0]), ASMap._trie_key(node[1]))

    def __lt__(self, other: "ASMap") -> bool:
        """
        Order ASMap objects by trie structure.

        Leaves sort before split nodes; otherwise nodes compare by value or by
        children left to right. Comparing the raw nested lists would raise
        TypeError as soon as a leaf meets a split node.
        """
        if not isinstance(other, ASMap):
            return NotImplemented
        return self._trie_key(self._trie) < self._trie_key(other._trie)

    def __eq__(self, other: object) -> bool:
        """Two ASMap objects are equal when their tries are identical."""
        if isinstance(other, ASMap):
            return self._trie == other._trie
        return NotImplemented

    def extends(self, req: "ASMap") -> bool:
        """Determine whether this matches req for all subranges where req is assigned."""
        def recurse(actual: list, require: list) -> bool:
            if len(require) == 1 and require[0] == 0:
                return True
            if len(require) == 1:
                if len(actual) == 1:
                    return bool(require[0] == actual[0])
                return recurse(actual[0], require) and recurse(actual[1], require)
            if len(actual) == 2:
                return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
            return recurse(actual, require[0]) and recurse(actual, require[1])
        assert isinstance(req, ASMap)
        #pylint: disable=protected-access
        return recurse(self._trie, req._trie)

    def diff(self, other: "ASMap") -> list[ASNDiff]:
        """Compute the diff from self to other."""
        prefix: list[bool] = []
        ret: list[ASNDiff] = []

        def recurse(old_node: list, new_node: list):
            if len(old_node) == 1 and len(new_node) == 1:
                if old_node[0] != new_node[0]:
                    ret.append((list(prefix), old_node[0], new_node[0]))
            else:
                # A leaf facing a split node is treated as covering both halves.
                old_left: list = old_node if len(old_node) == 1 else old_node[0]
                old_right: list = old_node if len(old_node) == 1 else old_node[1]
                new_left: list = new_node if len(new_node) == 1 else new_node[0]
                new_right: list = new_node if len(new_node) == 1 else new_node[1]
                prefix.append(False)
                recurse(old_left, new_left)
                prefix[-1] = True
                recurse(old_right, new_right)
                prefix.pop()
        assert isinstance(other, ASMap)
        #pylint: disable=protected-access
        recurse(self._trie, other._trie)
        return ret

    def __copy__(self) -> "ASMap":
        """Construct a copy of this ASMap object. Its state will not be shared."""
        ret = ASMap()
        #pylint: disable=protected-access
        ret._set_trie(copy.deepcopy(self._trie))
        return ret

    def __deepcopy__(self, _) -> "ASMap":
        """Construct an independent copy of this ASMap object."""
        # ASMap objects do not allow sharing of the _trie member, so we don't need the
        # memoization.
        return self.__copy__()


class TestASMap(unittest.TestCase):
    """Unit tests for this module."""

    def test_ipv6_prefix_roundtrips(self) -> None:
        """Test that random IPv6 network ranges roundtrip through prefix encoding."""
        for _ in range(20):
            net_bits = random.getrandbits(128)
            for prefix_len in range(0, 129):
                masked_bits = (net_bits >> (128 - prefix_len)) << (128 - prefix_len)
                net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertTrue(len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_ipv4_prefix_roundtrips(self) -> None:
        """Test that random IPv4 network ranges roundtrip through prefix encoding."""
        for _ in range(100):
            net_bits = random.getrandbits(32)
            for prefix_len in range(0, 33):
                masked_bits = (net_bits >> (32 - prefix_len)) << (32 - prefix_len)
                net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertTrue(32 <= len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_asmap_roundtrips(self) -> None:
        """Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 24):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Run tests for to_entries and construction from those entries, both
                    # for overlapping and non-overlapping ones.
                    for overlapping in [False, True]:
                        entries = asmap.to_entries(overlapping=overlapping, fill=False)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertEqual(asmap2, asmap)
                        entries = asmap.to_entries(overlapping=overlapping, fill=True)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertTrue(asmap2.extends(asmap))

                    # Run tests for to_binary and construction from binary.
                    enc = asmap.to_binary(fill=False)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertEqual(asmap3, asmap)
                    enc = asmap.to_binary(fill=True)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertTrue(asmap3.extends(asmap))

    def test_ordering(self) -> None:
        """Test that ASMap objects with differently-shaped tries can be compared and sorted."""
        leaf = ASMap([([], 1)])
        split = ASMap([([False], 1), ([True], 2)])
        self.assertTrue(leaf < split)
        self.assertFalse(split < leaf)
        self.assertTrue(split > leaf)
        maps = [ASMap.from_random(num_leaves=random.randrange(1, 8), max_asn=4)
                for _ in range(50)]
        maps.sort()
        for first, second in zip(maps, maps[1:]):
            self.assertFalse(second < first)

    def test_update_multi_preserves_input(self) -> None:
        """Test that update_multi applies longer prefixes last without mutating its input."""
        entries = [([True, False], 3), ([], 1)]
        snapshot = list(entries)
        asmap = ASMap()
        asmap.update_multi(entries)
        self.assertEqual(entries, snapshot)
        self.assertEqual(asmap.lookup([True, False]), 3)
        self.assertEqual(asmap.lookup([False]), 1)

    def test_patching(self) -> None:
        """Test behavior of update, lookup, extends, and diff."""
        #pylint: disable=too-many-locals,too-many-nested-blocks
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 10):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Make a copy of that asmap object to which patches will be applied.
                    # It starts off being equal to asmap.
                    patched = copy.copy(asmap)
                    # Keep a list of patches performed.
                    patches: list[ASNEntry] = []
                    # Initially there cannot be any difference.
                    self.assertEqual(asmap.diff(patched), [])
                    # Make 5 patches, each building on top of the previous ones.
                    for _ in range(0, 5):
                        # Construct a random path and new ASN to assign it to, apply it to
                        # patched, and remember it in patches.
                        pathlen = random.randrange(5)
                        path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
                        newasn = random.randrange(1 + (1 << asnbits))
                        patched.update(path, newasn)
                        patches = [(path, newasn)] + patches

                        # Compute the diff, and whether asmap extends patched, and the
                        # other way around.
                        diff = asmap.diff(patched)
                        self.assertEqual(asmap == patched, len(diff) == 0)
                        extends = asmap.extends(patched)
                        back_extends = patched.extends(asmap)
                        # Determine whether those extends results are consistent with the
                        # diff result.
                        self.assertEqual(extends, all(d[2] == 0 for d in diff))
                        self.assertEqual(back_extends, all(d[1] == 0 for d in diff))
                        # For every diff found:
                        for path, old_asn, new_asn in diff:
                            # Verify asmap and patched actually differ there.
                            self.assertTrue(old_asn != new_asn)
                            self.assertEqual(asmap.lookup(path), old_asn)
                            self.assertEqual(patched.lookup(path), new_asn)
                            for _ in range(2):
                                # Extend the path far enough that it's smaller than any
                                # mapped range, and check the lookup holds there too.
                                spec_path = list(path)
                                while len(spec_path) < 32:
                                    spec_path.append(random.getrandbits(1) != 0)
                                self.assertEqual(asmap.lookup(spec_path), old_asn)
                                self.assertEqual(patched.lookup(spec_path), new_asn)
                                # Search through the list of performed patches to find the
                                # last one applying to the extended path (note that patches
                                # is in reverse order, so the first match should work).
                                found = False
                                for patch_path, patch_asn in patches:
                                    if spec_path[:len(patch_path)] == patch_path:
                                        # When found, it must match whatever the result
                                        # was patched to.
                                        self.assertEqual(new_asn, patch_asn)
                                        found = True
                                        break
                                # And such a patch must exist.
                                self.assertTrue(found)


if __name__ == '__main__':
    unittest.main()