/ client.py
client.py
  1  # client.py
  2  
  3  import flwr as fl
  4  import torch
  5  import torch.nn as nn
  6  import torch.nn.functional as F
  7  from torchvision.datasets import CIFAR10
  8  from torchvision.transforms import Compose, Normalize, ToTensor
  9  from torch.utils.data import DataLoader
 10  from collections import OrderedDict
 11  from typing import Dict, List, Tuple
 12  import numpy as np
 13  
 14  # --- 1. Model Definition (Must be the same as the server) ---
 15  class Net(nn.Module):
 16      # ... (Copy the Net class definition from server.py here) ...
 17      def __init__(self) -> None:
 18          super(Net, self).__init__()
 19          self.conv1 = nn.Conv2d(3, 6, 5)
 20          self.pool = nn.MaxPool2d(2, 2)
 21          self.conv2 = nn.Conv2d(6, 16, 5)
 22          self.fc1 = nn.Linear(16 * 5 * 5, 120)
 23          self.fc2 = nn.Linear(120, 84)
 24          self.fc3 = nn.Linear(84, 10)
 25  
 26      def forward(self, x: torch.Tensor) -> torch.Tensor:
 27          x = self.pool(F.relu(self.conv1(x)))
 28          x = self.pool(F.relu(self.conv2(x)))
 29          x = x.view(-1, 16 * 5 * 5)
 30          x = F.relu(self.fc1(x))
 31          x = F.relu(self.fc2(x))
 32          x = self.fc3(x)
 33          return x
 34  
 35  # --- 2. Data Loading and Partitioning ---
 36  # Use a simple IID (Independent and Identically Distributed) partition for verification.
 37  def load_data(client_id: int):
 38      """Load CIFAR-10 and simulate partitioning."""
 39      transform = Compose([
 40          ToTensor(),
 41          Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 42      ])
 43      
 44      # Download and load the full training set
 45      trainset = CIFAR10("./data", train=True, download=True, transform=transform)
 46      
 47      # Since you only have one client, we'll use a subset of the data
 48      # In a real FL scenario, each client would have its own, unique local dataset.
 49      
 50      # Simple subset creation (e.g., first 5000 samples for the client)
 51      subset_indices = list(range(client_id * 5000, (client_id + 1) * 5000))
 52      train_subset = torch.utils.data.Subset(trainset, subset_indices)
 53      
 54      trainloader = DataLoader(train_subset, batch_size=32, shuffle=True)
 55      
 56      # We will use the test set only for local evaluation (optional)
 57      testset = CIFAR10("./data", train=False, download=True, transform=transform)
 58      testloader = DataLoader(testset, batch_size=32)
 59      
 60      return trainloader, testloader
 61  
 62  # --- 3. Flower Client Implementation ---
 63  class CifarClient(fl.client.NumPyClient):
 64      def __init__(self, cid: int) -> None:
 65          self.cid = cid
 66          self.net = Net()
 67          self.trainloader, self.testloader = load_data(cid)
 68          self.device = torch.device("cpu") # Force CPU to ensure it runs on your Dell Inspiron
 69  
 70      def get_parameters(self, config: Dict[str, str]) -> List[np.ndarray]:
 71          """Return the current local model parameters."""
 72          return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
 73  
 74      def set_parameters(self, parameters: List[np.ndarray]) -> None:
 75          """Set the local model parameters from the server's global model."""
 76          params_dict = zip(self.net.state_dict().keys(), parameters)
 77          state_dict = OrderedDict()
 78          for k, v in params_dict:
 79              original_tensor = self.net.state_dict()[k]
 80              np_array = np.frombuffer(v.tobytes(), dtype=np.float32).reshape(original_tensor.shape)
 81              state_dict[k] = torch.tensor(np_array)
 82  
 83          self.net.load_state_dict(state_dict, strict=True)
 84  
 85      def fit(self, parameters: List[np.ndarray], config: Dict[str, str]) -> Tuple[List[np.ndarray], int, Dict]:
 86          """Train the model locally for one round."""
 87          self.set_parameters(parameters)
 88          
 89          # Local training loop (1 epoch for this quick test)
 90          criterion = nn.CrossEntropyLoss()
 91          optimizer = torch.optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9)
 92          self.net.train()
 93          
 94          for _ in range(1): # train for 1 local epoch
 95              for images, labels in self.trainloader:
 96                  images, labels = images.to(self.device), labels.to(self.device)
 97                  optimizer.zero_grad()
 98                  outputs = self.net(images)
 99                  loss = criterion(outputs, labels)
100                  loss.backward()
101                  optimizer.step()
102          
103          # Return updated local parameters and the number of training examples used
104          return self.get_parameters({}), len(self.trainloader.dataset), {}
105  
106      def evaluate(self, parameters: List[np.ndarray], config: Dict[str, str]) -> Tuple[float, int, Dict]:
107          """Evaluate the model locally."""
108          self.set_parameters(parameters)
109          
110          # Evaluation code (optional but good practice)
111          criterion = nn.CrossEntropyLoss()
112          self.net.eval()
113          loss, correct = 0.0, 0
114          with torch.no_grad():
115              for images, labels in self.testloader:
116                  images, labels = images.to(self.device), labels.to(self.device)
117                  outputs = self.net(images)
118                  loss += criterion(outputs, labels).item()
119                  _, predicted = torch.max(outputs.data, 1)
120                  correct += (predicted == labels).sum().item()
121          
122          accuracy = correct / len(self.testloader.dataset)
123          return loss, len(self.testloader.dataset), {"accuracy": accuracy}
124  
125  # --- 4. Start the Client ---
if __name__ == "__main__":
    import argparse

    # LAN address of the machine running the Flower server (the MacBook Air
    # in the original setup). Overridable with --server-address.
    MAC_IP_ADDRESS = "192.168.2.86"

    parser = argparse.ArgumentParser(description="Start a Flower CIFAR-10 client.")
    parser.add_argument(
        "--server-address",
        default=f"{MAC_IP_ADDRESS}:8080",
        help="host:port of the Flower server (default: %(default)s)",
    )
    parser.add_argument(
        "--cid",
        type=int,
        default=0,
        help="client id; selects this client's training shard (default: %(default)s)",
    )
    args = parser.parse_args()

    fl.client.start_client(
        server_address=args.server_address,
        client=CifarClient(cid=args.cid).to_client(),
    )