/ NEAT / pool.py
pool.py
 1  from typing import List, Callable
 2  
 3  from genome import Genome
 4  from species import Species
 5  
 6  
 7  class Pool:
 8      def __init__(self, population: int, input_nodes: int, output_nodes: int):
 9          self.population: int = population
10          self.input_nodes: int = input_nodes
11          self.output_nodes = output_nodes
12          self.species: List[Species] = []
13          for _ in range(population):
14              self.add_to_species(Genome(input_nodes, output_nodes))
15  
16      def add_to_species(self, genome: Genome):
17          for species in self.species:
18              if Genome.is_same_species(species.genomes[0], genome):
19                  species.genomes.append(genome)
20                  species.previous_top_fitness = max(species.previous_top_fitness, genome.fitness)
21                  return
22          self.species.append(Species(genome))
23  
24      def evaluate_fitness(self, evaluator: Callable[[Genome], float]):
25          for species in self.species:
26              for genome in species.genomes:
27                  genome.fitness = evaluator(genome)
28  
29      def get_top_genome(self) -> Genome:
30          return max((species.get_top_genome() for species in self.species), key=lambda genome: genome.fitness)
31  
32      def calculate_total_adjusted_fitness(self) -> float:
33          return sum(species.get_total_adjusted_fitness() for species in self.species)
34  
35      def remove_weak_genomes_from_species(self):
36          for species in self.species:
37              species.remove_weak_genomes()
38  
39      def calculate_genome_adjusted_fitness(self):
40          for species in self.species:
41              species.calculate_genome_adjusted_fitness()
42  
43      def remove_stale_species(self):
44          top_top_fitness: float = self.get_top_genome().fitness
45          for i in range(len(self.species))[::-1]:
46              species: Species = self.species[i]
47              top_fitness: float = species.get_top_genome().fitness
48              if top_fitness <= species.previous_top_fitness:
49                  species.staleness += 1
50              else:
51                  species.staleness = 0
52              species.previous_top_fitness = top_fitness
53              if species.staleness >= 15 and top_fitness < top_top_fitness:
54                  del self.species[i]
55  
56      def breed_new_generation(self):
57          self.calculate_genome_adjusted_fitness()
58  
59          self.remove_weak_genomes_from_species()
60          self.remove_stale_species()
61  
62          survived_species: List[Species] = []
63          children: List[Genome] = []
64  
65          total_adjusted_fitness: float = self.calculate_total_adjusted_fitness()
66  
67          carry_over: float = 0
68          for species in self.species:
69              fchild: float = self.population * (species.get_total_adjusted_fitness() / total_adjusted_fitness)
70              nchild: int = int(fchild)
71              carry_over += fchild - nchild
72              if carry_over >= 1:
73                  carry_over -= 1
74                  nchild += 1
75  
76              if nchild < 1:
77                  continue
78  
79              new_species: Species = Species(species.get_top_genome())
80              new_species.previous_top_fitness = species.previous_top_fitness
81              new_species.staleness = species.staleness
82              survived_species.append(new_species)
83              for _ in range(1, nchild):
84                  children.append(species.breed_child())
85  
86          self.species: List[Species] = survived_species
87          for child in children:
88              self.add_to_species(child)