# mnist.py
"""Load the MNIST handwritten-digit dataset from gzipped IDX files.

The loaders yield ``TrainingData(image, label)`` objects where ``image`` is a
flat, row-major list of ``rows * cols`` pixel intensities divided by 255 and
``label`` is the integer value of the corresponding byte in the label file.
"""
import gzip
import struct
from typing import BinaryIO, Iterator, Sequence

from trainingdata import TrainingData

TEST_IMAGES = "data/test_images.gz"
TEST_LABELS = "data/test_labels.gz"
TRAIN_IMAGES = "data/train_images.gz"
TRAIN_LABELS = "data/train_labels.gz"

# Expected header magic numbers (0x00000803 for image files, 0x00000801 for
# label files), as described by the IDX format on the MNIST dataset page.
_IMAGES_MAGIC = 2051
_LABELS_MAGIC = 2049

# Two big-endian unsigned 32-bit integers; every IDX header field is one.
_UINT32_PAIR = struct.Struct(">II")


def _read_exact(stream: BinaryIO, size: int, what: str) -> bytes:
    """Read exactly *size* bytes from *stream*.

    :raises ValueError: if the stream ends early. Without this check a short
        read would be silently decoded as zeros.
    """
    data = stream.read(size)
    if len(data) != size:
        raise ValueError(
            f"truncated {what}: expected {size} bytes, got {len(data)}")
    return data


def _iter_samples(image_file: BinaryIO,
                  label_file: BinaryIO) -> Iterator[TrainingData]:
    """Parse both IDX headers, then yield one TrainingData per sample."""
    image_magic, size = _UINT32_PAIR.unpack(
        _read_exact(image_file, _UINT32_PAIR.size, "image header"))
    if image_magic != _IMAGES_MAGIC:
        raise ValueError(f"invalid magic number in image file: {image_magic}")

    label_magic, label_count = _UINT32_PAIR.unpack(
        _read_exact(label_file, _UINT32_PAIR.size, "label header"))
    if label_magic != _LABELS_MAGIC:
        raise ValueError(f"invalid magic number in label file: {label_magic}")

    if label_count != size:
        raise ValueError(
            f"different size: {size} images but {label_count} labels")

    rows, cols = _UINT32_PAIR.unpack(
        _read_exact(image_file, _UINT32_PAIR.size, "image dimensions"))
    pixels = rows * cols

    for index in range(size):
        label = _read_exact(label_file, 1, f"label {index}")[0]
        # Read the whole image in one call; iterating over bytes yields ints.
        raw = _read_exact(image_file, pixels, f"image {index}")
        yield TrainingData([value / 255 for value in raw], label)


def _load(images: str, labels: str) -> Iterator[TrainingData]:
    """Lazily yield TrainingData samples from a gzipped IDX image/label pair.

    Nothing is opened until the first item is requested. Both files are
    closed when the generator is exhausted, explicitly closed, or collected.

    :param images: path to the gzipped IDX image file.
    :param labels: path to the gzipped IDX label file.
    :raises ValueError: on a bad magic number, a mismatch between image and
        label counts, or a truncated file.
    """
    with gzip.open(images, "rb") as image_file:
        with gzip.open(labels, "rb") as label_file:
            yield from _iter_samples(image_file, label_file)


def load_train() -> Iterator[TrainingData]:
    """Return a lazy iterator over the training set (TRAIN_* paths)."""
    return _load(TRAIN_IMAGES, TRAIN_LABELS)


def load_test() -> Iterator[TrainingData]:
    """Return a lazy iterator over the test set (TEST_* paths)."""
    return _load(TEST_IMAGES, TEST_LABELS)


def print_image(image: Sequence[float], width: int = 28) -> None:
    """Print a square grayscale image to the terminal with 24-bit ANSI colour.

    Each pixel is drawn as two full-block characters, presumably so pixels
    look roughly square in typical terminal fonts. Each block's foreground
    colour is the intensity scaled from ``[0, 1]`` to ``0..255``, clamped so
    that out-of-range values still produce valid escape codes. The colour is
    reset after printing.

    :param image: flat, row-major sequence of ``width * width`` intensities.
    :param width: side length of the square image (28 for MNIST).
    :raises ValueError: if ``len(image) != width ** 2``.
    """
    if len(image) != width ** 2:
        raise ValueError('invalid image size')
    lines = []
    for row in range(width):
        cells = []
        for col in range(width):
            color = max(0, min(255, int(image[row * width + col] * 255)))
            cells.append(f"\033[38;2;{color};{color};{color}m" + "\u2588" * 2)
        lines.append("".join(cells) + "\n")
    print("".join(lines), end='\033[0m')