# inference.py
  1  import torch
  2  import torch.nn as nn
  3  import torch.nn.functional as F
  4  from torchvision import datasets, transforms
  5  
  6  # 1. Define the exact same model class (Crucial!)
  7  # You must copy the exact Net class definition from server.py and client.py
  8  class Net(nn.Module):
  9      def __init__(self) -> None:
 10          super(Net, self).__init__()
 11          self.conv1 = nn.Conv2d(3, 6, 5)
 12          self.pool = nn.MaxPool2d(2, 2)
 13          self.conv2 = nn.Conv2d(6, 16, 5)
 14          self.fc1 = nn.Linear(16 * 5 * 5, 120)
 15          self.fc2 = nn.Linear(120, 84)
 16          self.fc3 = nn.Linear(84, 10)
 17  
 18      def forward(self, x: torch.Tensor) -> torch.Tensor:
 19          x = self.pool(F.relu(self.conv1(x)))
 20          x = self.pool(F.relu(self.conv2(x)))
 21          x = x.view(-1, 16 * 5 * 5)
 22          x = F.relu(self.fc1(x))
 23          x = F.relu(self.fc2(x))
 24          x = self.fc3(x)
 25          return x
 26  
 27  # 2. Load the model for inference
 28  MODEL_PATH = "./checkpoints/global_model_final_round_100_pi_dell.pt" # Update path if needed
 29  DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 30  
 31  transform = transforms.Compose([
 32      transforms.ToTensor(),
 33      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 34  ])
 35  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
 36  
 37  def load_inference_model(path: str) -> Net:
 38      """Loads the model from the saved state dictionary."""
 39      model = Net()
 40      # Load state dict and map to the appropriate device (CPU for now, likely)
 41      model.load_state_dict(torch.load(path, map_location=DEVICE))
 42      model.to(DEVICE)
 43      # Set model to evaluation mode (turns off dropout, uses moving averages for BatchNorm)
 44      model.eval() 
 45      print(f"Model loaded successfully to {DEVICE}. Ready for inference.")
 46      return model
 47  
# 3. Example of running inference on the full test set
 49  if __name__ == "__main__":
 50      try:
 51          inference_model = load_inference_model(MODEL_PATH)
 52      except FileNotFoundError:
 53          print(f"šŸ›‘ ERROR: Model file not found at {MODEL_PATH}")
 54          exit()
 55  
 56      # Load the entire CIFAR-10 test set and DataLoader
 57      testset = datasets.CIFAR10(root='./data', train=False,
 58                                 download=True, transform=transform)
 59      testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
 60      
 61      # Initialize tracking arrays for 10 classes
 62      class_correct = list(0. for i in range(10))
 63      class_total = list(0. for i in range(10))
 64      
 65      print(f"\nšŸš€ Starting verbose evaluation on {len(testset)} test images...")
 66      
 67      # Iterate through the entire test set
 68      with torch.no_grad():
 69          for data in testloader:
 70              images, labels = data
 71              images, labels = images.to(DEVICE), labels.to(DEVICE)
 72              
 73              outputs = inference_model(images)
 74              _, predicted = torch.max(outputs, 1)
 75              
 76              c = (predicted == labels).squeeze() # Boolean tensor of correct/incorrect predictions
 77              
 78              # Loop through the batch to track individual class performance
 79              for i in range(len(labels)):
 80                  label = labels[i].item() # The true label (0-9)
 81                  class_correct[label] += c[i].item() # Add 1 if correct
 82                  class_total[label] += 1
 83                  
 84      # --- 1. Overall Accuracy Report ---
 85      overall_correct = sum(class_correct)
 86      overall_total = sum(class_total)
 87      overall_accuracy = 100 * overall_correct / overall_total
 88      
 89      print("\n--- Overall Model Performance ---")
 90      print(f"Total test images: {overall_total}")
 91      print(f"Overall Accuracy: {overall_accuracy:.2f}%")
 92  
 93      # --- 2. Class-by-Class Analysis ---
 94      print("\n--- Detailed Class Accuracy ---")
 95      
 96      total_random_guess = 0 # Track how many classes are near 10%
 97      
 98      for i in range(10):
 99          accuracy = 100 * class_correct[i] / class_total[i]
100          
101          # Check if the accuracy is close to random guessing (10%)
102          if accuracy < 11:
103              total_random_guess += 1
104          
105          print(f'Accuracy of {classes[i]:<5}: {accuracy:.2f}% ({int(class_correct[i])}/{int(class_total[i])})')
106          
107      print(f"\nSummary: {total_random_guess} out of 10 classes are performing at or near random chance (10%).")