graph.py
class UnionFind:
    """Disjoint-set (union-find) structure over the integers ``0 .. n-1``.

    ``parents[i]`` holds the parent of element ``i``; an element whose parent
    is itself is the root (representative) of its set.  ``find`` applies full
    path compression.  ``merge`` does not balance by rank or size: the root
    of ``y`` always becomes the root of the merged set, so the representative
    chosen is deterministic and depends only on the call order.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets, one for each element ``0 .. n-1``."""
        self.parents = [*range(n)]

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every element on the path from ``x`` to the root is re-pointed
        directly at the root (path compression).  This is done iteratively
        rather than recursively: because ``merge`` does no balancing, parent
        chains can grow as long as ``n``, and a recursive walk would hit
        Python's recursion limit and raise ``RecursionError``.
        """
        parents = self.parents
        # First pass: walk up to the root.
        root = x
        while parents[root] != root:
            root = parents[root]
        # Second pass: point every node on the walked path straight at the root.
        while x != root:
            next_x = parents[x]
            parents[x] = root
            x = next_x
        return root

    def merge(self, x, y):
        """Union the sets containing ``x`` and ``y``.

        The root of ``y``'s set becomes the root of the combined set.
        Merging two elements already in the same set is a no-op.
        """
        self.parents[self.find(x)] = self.find(y)

    def connected(self, x, y):
        """Return ``True`` if ``x`` and ``y`` are in the same set."""
        return self.find(x) == self.find(y)

    def count(self):
        """Return the number of disjoint sets."""
        return len(set(map(self.find, range(len(self.parents)))))

    def components(self):
        """Group elements by set.

        Returns a dict mapping each root to the ascending list of elements in
        its set.  Keys appear in order of each set's smallest element.
        """
        out = {}
        for i in range(len(self.parents)):
            out.setdefault(self.find(i), []).append(i)
        return out