/ Python / 2022 / 16.py
16.py
 1  from lib import *
 2  
 3  input = read_input(2022, 16)
 4  
 5  n = len(input.splitlines())
 6  graph = [set() for _ in range(n)]
 7  rates = [0] * n
 8  names = {"AA": 0}
 9  
10  
def name(n: str) -> int:
    """Return the integer index of valve ``n``, assigning the next free index on first sight.

    Indices are handed out densely in first-seen order via the module-level ``names`` map.
    """
    if n not in names:
        names[n] = len(names)
    return names[n]
13  
14  
# Parse each valve description into `rates` and the adjacency sets in `graph`.
# Loop-local names are deliberately distinct from the global node count `n`,
# so parsing no longer overwrites it with a neighbour list.
for line in input.splitlines():
    # The optional "s" suffixes accept both "tunnel leads to valve X" and
    # "tunnels lead to valves X, Y".
    match = re.match(r"^Valve (.+) has flow rate=(\d+); tunnels? leads? to valves? (.+)$", line)
    if match is None:
        # Fail with the offending line instead of AttributeError on None.groups().
        raise ValueError(f"unrecognised valve description: {line!r}")
    valve, rate, neighbours = match.groups()
    # Index the valve itself before its neighbours, so indices are assigned
    # in the same first-seen order as before.
    v = name(valve)
    rates[v] = int(rate)
    graph[v].update(map(name, neighbours.split(", ")))
22  
23  
n = len(rates)

# All-pairs shortest travel times via Floyd-Warshall. Each tunnel costs one
# minute; dist[i][j] is 0 on the diagonal and +inf where no path exists.
dist = {}
for i in range(n):
    row = {}
    for j in range(n):
        if i == j:
            row[j] = 0
        elif j in graph[i]:
            row[j] = 1
        else:
            row[j] = float("inf")
    dist[i] = row
for k in range(n):
    for i in range(n):
        for j in range(n):
            through_k = dist[i][k] + dist[k][j]
            if through_k < dist[i][j]:
                dist[i][j] = through_k
30  
31  
@cache
def solve(p, time, closed):
    """Return the most pressure releasable from valve ``p`` with ``time`` minutes left.

    ``closed`` is a bitmask of valve indices that may still be opened (bit ``q``
    set means valve ``q`` is a candidate). Choosing valve ``q`` costs
    ``dist[p][q]`` minutes of travel plus one more minute to open it. The valve
    then contributes ``rates[q]`` for every minute that remains.

    Reads the module-level ``n``, ``dist`` and ``rates``. Memoised with
    ``@cache``, so call ``solve.cache_clear()`` if those change.
    """
    if time <= 0:
        return 0

    best = 0
    for q in range(n):
        if not closed & (1 << q):
            continue
        remaining = time - dist[p][q] - 1
        # A valve reached with no time to spare adds nothing (or a negative
        # amount). An unreachable one gives -inf, and 0 * -inf would be nan.
        # Skipping both avoids useless recursion and cache entries without
        # changing the result.
        if remaining <= 0:
            continue
        best = max(best, solve(q, remaining, closed & ~(1 << q)) + rates[q] * remaining)
    return best
44  
45  
# Part 1: start at valve 0 with 30 minutes; every positive-rate valve is openable.
# The bits are distinct, so summing the powers of two equals OR-ing them.
print(solve(0, 30, sum(1 << i for i in range(n) if rates[i])))
47  
48  
# Part 2: two agents, each with 26 minutes from valve 0, split the positive-rate
# valves between them. Try every split and keep the best combined total.
solve.cache_clear()
valves = [i for i in range(n) if rates[i]]
every_valve = sum(1 << v for v in valves)
out = 0
for s in range(1 << len(valves)):
    # Bit i of s decides whether valves[i] goes to the first agent.
    mine = sum(1 << v for i, v in enumerate(valves) if s >> i & 1)
    out = max(out, solve(0, 26, mine) + solve(0, 26, every_valve ^ mine))
print(out)