# digitclassifier.py
from typing import List

import mnist
from matrix import Matrix
from network import Network
from trainingdata import TrainingData
from util import sigmoid, sigmoid_prime

# Size of the input layer: one neuron per pixel of a 28x28 image.
_INPUT_SIZE = 28 ** 2
# Size of the output layer: one neuron per digit class (0-9).
_NUM_CLASSES = 10


class DigitClassifier(Network):
    """Network specialisation with 28**2 inputs and 10 outputs.

    Uses the sigmoid function (and its derivative) from ``util`` as the
    activation, encodes expected labels as one-hot vectors of length 10 and
    decodes an output as the index of its largest entry.
    """

    def __init__(self, weights: List[Matrix], biases: List[Matrix]):
        """Forward ``weights`` and ``biases`` unchanged to ``Network``."""
        super().__init__(weights, biases)

    @classmethod
    def random_network(cls, layer_sizes: List[int]):
        """Build a network via ``Network.random_network``.

        Args:
            layer_sizes: Sizes of the hidden layers only. The input layer
                (28**2) and the output layer (10) are added automatically.

        Returns:
            Whatever the parent ``random_network`` returns for the full
            layer list.
        """
        # Unpacking (instead of list concatenation) also accepts tuples and
        # other iterables; for a list the resulting layer list is identical.
        return super().random_network(
            [_INPUT_SIZE, *layer_sizes, _NUM_CLASSES]
        )

    @staticmethod
    def activation(x: float) -> float:
        """Return ``sigmoid(x)``."""
        return sigmoid(x)

    @staticmethod
    def activation_derivative(x: float) -> float:
        """Return ``sigmoid_prime(x)``."""
        return sigmoid_prime(x)

    def prepare_expected_output(self, expected) -> Matrix:
        """Encode a label as a one-hot vector of length 10.

        Entry ``i`` is ``1.0`` when ``i == expected`` and ``0.0`` otherwise;
        the list is wrapped with ``Matrix.from_vector``.
        """
        one_hot = [float(digit == expected) for digit in range(_NUM_CLASSES)]
        return Matrix.from_vector(one_hot)

    def prepare_input(self, inp) -> Matrix:
        """Wrap the raw input with ``Matrix.from_vector``."""
        return Matrix.from_vector(inp)

    def interpret_output(self, out: Matrix):
        """Return the index of the largest entry of ``out.to_vector()``."""
        # Assumes to_vector() returns a list-like with .index(); on ties the
        # first maximal entry wins.
        values = out.to_vector()
        return values.index(max(values))

    def evaluate(self, test_data: List[TrainingData]) -> float:
        """Return the fraction of samples classified correctly.

        A sample counts as correct when ``self.feedforward(td.input_data)``
        compares equal to ``td.expected_output``.

        Args:
            test_data: Samples to score.

        Returns:
            Accuracy in ``[0, 1]``; ``0.0`` when ``test_data`` is empty
            (previously this raised ``ZeroDivisionError``).
        """
        if not test_data:
            return 0.0
        # NOTE(review): this compares the feedforward result directly with
        # the expected label; presumably Network.feedforward already applies
        # prepare_input / interpret_output -- confirm against Network.
        correct = sum(
            1
            for td in test_data
            if self.feedforward(td.input_data) == td.expected_output
        )
        return correct / len(test_data)


if __name__ == '__main__':
    # Two hidden layers of 10 neurons each between the fixed input/output.
    network = DigitClassifier.random_network([10, 10])
    print("Loading training data ...")
    # zip with a range caps how many samples are taken from the loader.
    training_data = [x for _, x in zip(range(2000), mnist.load_train())]
    print("Loading validation data ...")
    validation_data = [x for _, x in zip(range(1000), mnist.load_test())]
    print("Start training")
    # NOTE(review): 30, 10, 3 are presumably epochs, mini-batch size and
    # learning rate -- confirm against Network.train.
    network.train(training_data, 30, 10, 3, validation_data, "network.json")
    print(f"Accuracy: {network.evaluate(validation_data)}")