# pool.py
from typing import List, Callable

from genome import Genome
from species import Species


class Pool:
    """A population of genomes partitioned into species.

    The pool seeds ``population`` fresh genomes and groups them into species.
    Each call to :meth:`breed_new_generation` culls weak genomes and stale
    species. It then refills the pool by giving each remaining species a
    number of offspring slots proportional to its total adjusted fitness.
    """

    # Tolerance used when accumulating fractional offspring counts. In exact
    # arithmetic the fractional parts add up to a whole number. Float round-off
    # can leave the running total just below 1, though: three equal species
    # with a population of 10 accumulate to 0.999999999999999. Without the
    # tolerance, one child was silently dropped each generation.
    _CARRY_OVER_EPSILON: float = 1e-9

    def __init__(self, population: int, input_nodes: int, output_nodes: int):
        """Create ``population`` new genomes and sort them into species.

        Args:
            population: Target number of genomes per generation.
            input_nodes: Input-node count passed to every new ``Genome``.
            output_nodes: Output-node count passed to every new ``Genome``.
        """
        self.population: int = population
        self.input_nodes: int = input_nodes
        self.output_nodes: int = output_nodes
        self.species: List[Species] = []
        for _ in range(population):
            self.add_to_species(Genome(input_nodes, output_nodes))

    def add_to_species(self, genome: Genome) -> None:
        """Put ``genome`` into the first compatible species, or start a new one.

        Compatibility is decided by ``Genome.is_same_species`` against the
        first genome of each existing species. Joining a species also raises
        that species' ``previous_top_fitness`` to ``genome.fitness`` if the
        latter is higher.

        Args:
            genome: The genome to place.
        """
        for species in self.species:
            if Genome.is_same_species(species.genomes[0], genome):
                species.genomes.append(genome)
                # NOTE(review): children are added here before they are
                # evaluated. If ``genome.fitness`` can hold a stale value at
                # this point, this update can mask staleness. Confirm against
                # Genome.
                species.previous_top_fitness = max(species.previous_top_fitness, genome.fitness)
                return
        self.species.append(Species(genome))

    def evaluate_fitness(self, evaluator: Callable[[Genome], float]) -> None:
        """Set ``fitness`` on every genome in the pool to ``evaluator(genome)``.

        Args:
            evaluator: Callable that scores one genome.
        """
        for species in self.species:
            for genome in species.genomes:
                genome.fitness = evaluator(genome)

    def get_top_genome(self) -> Genome:
        """Return the highest-fitness genome across all species.

        Each species' candidate comes from ``Species.get_top_genome``.

        Raises:
            ValueError: If the pool has no species (``max`` of an empty
                sequence).
        """
        return max((species.get_top_genome() for species in self.species), key=lambda genome: genome.fitness)

    def calculate_total_adjusted_fitness(self) -> float:
        """Return the sum of ``get_total_adjusted_fitness()`` over all species."""
        return sum(species.get_total_adjusted_fitness() for species in self.species)

    def remove_weak_genomes_from_species(self) -> None:
        """Call ``remove_weak_genomes()`` on every species."""
        for species in self.species:
            species.remove_weak_genomes()

    def calculate_genome_adjusted_fitness(self) -> None:
        """Call ``calculate_genome_adjusted_fitness()`` on every species."""
        for species in self.species:
            species.calculate_genome_adjusted_fitness()

    def remove_stale_species(self, stale_threshold: int = 15) -> None:
        """Update each species' staleness and drop species that stopped improving.

        A species' staleness counter is updated as follows:

        * If its current best fitness exceeds ``previous_top_fitness``, the
          counter resets to 0 and ``previous_top_fitness`` is updated.
        * Otherwise the counter is incremented.

        Species whose staleness reaches ``stale_threshold`` are removed.
        Species whose best fitness equals the pool-wide best are kept, so the
        pool never loses its champion.

        Args:
            stale_threshold: Staleness count at which a non-champion species
                is removed. Defaults to 15.
        """
        if not self.species:
            # Nothing to judge. Without this, get_top_genome() would raise
            # ValueError on the empty pool.
            return
        best_fitness: float = self.get_top_genome().fitness
        survivors: List[Species] = []
        for species in self.species:
            top_fitness: float = species.get_top_genome().fitness
            if top_fitness > species.previous_top_fitness:
                species.staleness = 0
                species.previous_top_fitness = top_fitness
            else:
                species.staleness += 1
            if species.staleness >= stale_threshold and top_fitness < best_fitness:
                continue
            survivors.append(species)
        # Slice-assign so the same list object is mutated in place, matching
        # the earlier index-based ``del`` behavior.
        self.species[:] = survivors

    def breed_new_generation(self) -> None:
        """Replace the current species with the next generation.

        Steps:

        1. Compute adjusted fitness, then cull weak genomes and stale species.
        2. Give each species a share of ``self.population`` slots proportional
           to its total adjusted fitness. If the total is 0, split the slots
           evenly instead.
        3. Carry fractional shares forward from species to species, so the
           slot counts add up to the population.
        4. Handle each species by the number of slots (``nchild``) it earns:

           * At least one slot: it survives as a new ``Species`` seeded with
             its top genome (keeping its ``previous_top_fitness`` and
             ``staleness``) and breeds ``nchild - 1`` children.
           * No slots: it is dropped.

        5. Re-speciate all children via :meth:`add_to_species`.
        """
        self.calculate_genome_adjusted_fitness()

        self.remove_weak_genomes_from_species()
        self.remove_stale_species()

        survived_species: List[Species] = []
        children: List[Genome] = []

        total_adjusted_fitness: float = self.calculate_total_adjusted_fitness()
        species_count: int = len(self.species)

        carry_over: float = 0
        for species in self.species:
            if total_adjusted_fitness == 0:
                # Every species has zero adjusted fitness (e.g. the evaluator
                # scored every genome 0). Proportional allocation would divide
                # by zero, so share the population evenly instead.
                share: float = 1 / species_count
            else:
                # NOTE(review): proportional allocation assumes non-negative
                # adjusted fitness. Mixed signs give meaningless shares.
                share = species.get_total_adjusted_fitness() / total_adjusted_fitness
            fchild: float = self.population * share
            nchild: int = int(fchild)
            carry_over += fchild - nchild
            # Use a tolerance so float round-off cannot keep the carry-over
            # just under 1 and lose a slot (see _CARRY_OVER_EPSILON).
            if carry_over >= 1 - self._CARRY_OVER_EPSILON:
                carry_over -= 1
                nchild += 1

            if nchild < 1:
                continue

            new_species: Species = Species(species.get_top_genome())
            new_species.previous_top_fitness = species.previous_top_fitness
            new_species.staleness = species.staleness
            survived_species.append(new_species)
            # The top genome fills one slot; breed the remaining nchild - 1.
            for _ in range(1, nchild):
                children.append(species.breed_child())

        self.species = survived_species
        for child in children:
            self.add_to_species(child)