/ Python / 2021 / 12.py
12.py
 1  from lib import *
 2  
 3  input = read_input(2021, 12)
 4  
 5  lines = input.splitlines()
 6  
 7  edges = {}
 8  for line in lines:
 9      a, b = line.split("-")
10      edges.setdefault(a, []).append(b)
11      edges.setdefault(b, []).append(a)
12  
13  
def search1(node, visited):
    """Count paths from ``node`` to "end" where lowercase nodes appear at most once.

    Reaching "end" counts as one complete path. A lowercase node that is
    already on the current path is a dead end. Uppercase nodes may be
    re-entered freely. Nothing prevents bouncing between two adjacent
    uppercase nodes, so this assumes the input has none; verify against
    the puzzle data.

    Args:
        node: name of the node being entered.
        visited: set of nodes entered on the current path. It is mutated in
            place during the search.

    Returns:
        Number of distinct paths from ``node`` that reach "end".
    """
    if node == "end":
        return 1

    small = node.islower()
    if small and node in visited:
        return 0

    visited.add(node)
    paths = 0
    for nxt in edges.get(node, []):
        paths += search1(nxt, visited)
    # Backtrack so sibling branches see the path without this node.
    visited.discard(node)
    return paths


print(search1("start", set()))
27  
28  
# Part 2 reuses the `edges` adjacency map built from `lines` at the top of the
# file. search1 only reads it, so rebuilding an identical dict here is redundant.
34  
35  
def search2(node, visited, twice):
    """Count paths from ``node`` to "end" allowing a single lowercase revisit.

    This is the same walk as ``search1``, except that along one path a single
    lowercase node may be entered a second time. "start" may never be
    re-entered.

    Args:
        node: name of the node being entered.
        visited: set of nodes entered on the current path. It is mutated in
            place; only the membership of lowercase names affects the result.
        twice: True once the current path has already spent its one
            lowercase revisit.

    Returns:
        Number of distinct paths from ``node`` that reach "end".
    """
    if node == "end":
        return 1

    # Re-entering a lowercase node that is already on the current path.
    revisit = node.islower() and node in visited
    if revisit and (twice or node == "start"):
        return 0

    visited.add(node)
    total = 0
    for neighbour in edges.get(node, []):
        total += search2(neighbour, visited, twice or revisit)
    # On a revisit the node is already in `visited` from its earlier
    # occurrence on this path. That outer call removes it, so keep it here.
    if not revisit:
        visited.discard(node)
    return total


print(search2("start", set(), False))