# inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


# 1. Define the exact same model class (Crucial!)
# You must copy the exact Net class definition from server.py and client.py
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The attribute names and layer shapes below determine the state-dict keys
    and tensor sizes, so they must match the class that produced the
    checkpoint being loaded.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels of 5x5 is what a 3x32x32 input yields after the two
        # conv(5) + pool(2) stages: 32 -> 28 -> 14 -> 10 -> 5.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw 10-way logits for a batch of images.

        Args:
            x: Batch of 3-channel images; ``fc1`` expects the spatial size to
                reduce to 5x5 (i.e. 32x32 inputs).

        Returns:
            Tensor with one row of 10 unnormalised scores per input sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 16 * 5 * 5)`` this never folds samples into each other
        # when the input size is wrong; fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# 2. Load the model for inference
MODEL_PATH = "./checkpoints/global_model_final_round_100_pi_dell.pt"  # Update path if needed
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Same preprocessing for every test image: convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def load_inference_model(path: str) -> Net:
    """Build a ``Net`` and load its weights from the state dict saved at *path*.

    Args:
        path: Location of a checkpoint holding a ``Net`` state dict.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If no file exists at *path* (propagated from
            ``torch.load``).
        RuntimeError: If the checkpoint keys or shapes do not match ``Net``
            (propagated from ``load_state_dict``).
    """
    model = Net()
    # SECURITY: torch.load unpickles the file. weights_only=True restricts the
    # unpickler to tensors and plain containers, so a tampered checkpoint
    # cannot execute arbitrary code. Requires torch >= 1.13.
    state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    # Net has no dropout or BatchNorm layers, so eval() does not change its
    # outputs today; it is kept so inference stays correct if such layers
    # are ever added to the shared model definition.
    model.eval()
    print(f"Model loaded successfully to {DEVICE}. Ready for inference.")
    return model
# 3. Example of running inference
def evaluate_per_class(model, loader, device: torch.device, num_classes: int) -> tuple:
    """Count correct and total predictions per class over *loader*.

    Args:
        model: Callable mapping a batch of inputs to per-class scores with the
            class dimension at index 1.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs, where
            ``labels`` is a 1-D integer tensor of class indices.
        device: Device the batches are moved to before calling *model*.
        num_classes: Number of classes; labels must be below this value.

    Returns:
        ``(correct, total)``: two lists of ``int`` of length *num_classes*.
        ``total[k]`` is the number of samples whose true label is ``k`` and
        ``correct[k]`` is how many of those were predicted as ``k``.
    """
    correct = [0] * num_classes
    total = [0] * num_classes

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            # No squeeze(): squeezing a one-sample batch produces a 0-d tensor
            # that cannot be indexed per sample.
            hits = predicted == labels
            # One transfer per batch instead of one .item() sync per sample.
            for label, hit in zip(labels.tolist(), hits.tolist()):
                total[label] += 1
                correct[label] += hit  # bool adds as 0 or 1

    return correct, total


if __name__ == "__main__":
    try:
        inference_model = load_inference_model(MODEL_PATH)
    except FileNotFoundError:
        # NOTE(review): the leading character in this message (and in the
        # "Starting" message below) looks like a mis-encoded emoji; left as-is.
        print(f"š ERROR: Model file not found at {MODEL_PATH}")
        # Non-zero status so callers and shell scripts can detect the failure;
        # bare exit() reported success and relies on the site module.
        raise SystemExit(1)

    # Load the entire CIFAR-10 test set and DataLoader
    testset = datasets.CIFAR10(root='./data', train=False,
                               download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

    print(f"\nš Starting verbose evaluation on {len(testset)} test images...")

    # Iterate through the entire test set, tracking per-class performance
    class_correct, class_total = evaluate_per_class(
        inference_model, testloader, DEVICE, len(classes)
    )

    # --- 1. Overall Accuracy Report ---
    overall_correct = sum(class_correct)
    overall_total = sum(class_total)
    # Guard against an empty test set instead of dividing by zero.
    overall_accuracy = 100 * overall_correct / overall_total if overall_total else 0.0

    print("\n--- Overall Model Performance ---")
    print(f"Total test images: {overall_total}")
    print(f"Overall Accuracy: {overall_accuracy:.2f}%")

    # --- 2. Class-by-Class Analysis ---
    print("\n--- Detailed Class Accuracy ---")

    total_random_guess = 0  # Track how many classes are near 10%

    for name, n_correct, n_total in zip(classes, class_correct, class_total):
        if not n_total:
            # A class with no samples has no accuracy; do not count it as
            # "random chance" and do not divide by zero.
            print(f'Accuracy of {name:<5}: n/a (0/0)')
            continue

        accuracy = 100 * n_correct / n_total

        # Check if the accuracy is close to random guessing (10%)
        if accuracy < 11:
            total_random_guess += 1

        print(f'Accuracy of {name:<5}: {accuracy:.2f}% ({n_correct}/{n_total})')

    print(f"\nSummary: {total_random_guess} out of 10 classes are performing at or near random chance (10%).")