16.py
from lib import *

# Advent of Code 2022, day 16: choose the order in which to open valves so
# that the total pressure released before the clock runs out is maximal.
#
# NOTE(review): `read_input`, `re` and `cache` all arrive through the star
# import above; `re` / `cache` are presumably re-exports of the stdlib `re`
# module and `functools.cache` -- confirm against lib.

puzzle_input = read_input(2022, 16)

# One description per valve; blank lines carry no valve and are skipped.
lines = [line for line in puzzle_input.splitlines() if line.strip()]

n = len(lines)  # number of valves (one per description line)
graph = [set() for _ in range(n)]  # graph[v]: indices of valves adjacent to v
rates = [0] * n  # rates[v]: flow rate of valve v
names = {"AA": 0}  # valve label -> index; the start valve "AA" is pinned to 0

VALVE_RE = re.compile(r"^Valve (.+) has flow rate=(\d+); tunnels? leads? to valves? (.+)$")


def name(n: str) -> int:
    """Return the index of valve label *n*, allocating the next free index on first use."""
    return names.setdefault(n, len(names))


for line in lines:
    parsed = VALVE_RE.match(line)
    if parsed is None:
        # Fail with the offending text instead of an AttributeError on None.
        raise ValueError(f"unrecognised valve description: {line!r}")
    label, rate, tunnels = parsed.groups()
    v = name(label)
    rates[v] = int(rate)
    graph[v].update(map(name, tunnels.split(", ")))


# All-pairs shortest walking times (Floyd-Warshall); every tunnel costs 1.
INF = float("inf")
dist = [[0 if i == j else 1 if j in graph[i] else INF for j in range(n)] for i in range(n)]
for k in range(n):
    via = dist[k]
    for row in dist:
        to_k = row[k]
        for j in range(n):
            if to_k + via[j] < row[j]:
                row[j] = to_k + via[j]


@cache
def solve(p, time, closed):
    """Return the most pressure that can still be released.

    p      -- index of the valve we are standing at
    time   -- minutes remaining
    closed -- bitmask (bit q <=> valve q) of the valves that may still be opened
    """
    if time <= 0:
        return 0

    out = 0
    for q in range(n):
        if not closed >> q & 1:
            continue
        # Walk to q, then spend one minute opening it; it flows for t minutes.
        t = time - dist[p][q] - 1
        if t <= 0:
            # Too late (or unreachable) to gain anything from q; it could
            # never beat the current best, so do not recurse or memoise it.
            continue
        out = max(out, solve(q, t, closed & ~(1 << q)) + rates[q] * t)
    return out


# Only valves with a non-zero rate are worth opening.
valves = [i for i in range(n) if rates[i]]
bits = [1 << v for v in valves]
all_valves = sum(bits)  # distinct powers of two, so the sum equals their bitwise OR

# Part 1: a single agent with 30 minutes.
print(solve(0, 30, all_valves))


# Part 2: two agents with 26 minutes each share the valves between them.
# The memo stays valid between parts (solve depends only on its arguments and
# the tables above); clearing it merely releases the part-1 entries.
solve.cache_clear()
out = 0
# Swapping the two agents yields the same total, so the last useful valve is
# always left to the second agent: this visits each split once, not twice.
for s in range(1 << max(len(valves) - 1, 0)):
    mine = sum(bit for i, bit in enumerate(bits) if s >> i & 1)
    out = max(out, solve(0, 26, mine) + solve(0, 26, all_valves ^ mine))
print(out)