/ Python / 2022 / 13.py
13.py
import json

from lib import *
 2  
 3  input = read_input(2022, 13)
 4  
 5  pairs = [tuple(map(eval, pair.splitlines())) for pair in input.split("\n\n")]
 6  
 7  
 8  def compare(a, b):
 9      if isinstance(a, int) and isinstance(b, int):
10          if a < b:
11              return True
12          elif a > b:
13              return False
14          else:
15              return None
16      if isinstance(a, int):
17          a = [a]
18      if isinstance(b, int):
19          b = [b]
20      for x, y in zip(a, b):
21          if (z := compare(x, y)) is not None:
22              return z
23      if len(a) < len(b):
24          return True
25      elif len(a) > len(b):
26          return False
27      else:
28          return None
29  
30  
# Part 1: sum of the 1-based indices of the pairs whose packets are already in order.
print(sum(index for index, (left, right) in enumerate(pairs, start=1) if compare(left, right) is True))
32  
33  
# Part 2: the two divider packets followed by every packet from every pair.
packets = [[[2]], [[6]]] + [packet for first, second in pairs for packet in (first, second)]

# A divider's 1-based position in sorted order is one plus the number of packets
# strictly ordered before it. A divider never counts against itself (compare()
# returns None for equal packets), so no special-casing is needed.
out = 1
for divider in packets[:2]:
    out *= 1 + sum(compare(packet, divider) is True for packet in packets)
print(out)