# client.py
"""Flower federated-learning client that trains a small CNN on a CIFAR-10 shard."""

import os
from collections import OrderedDict
from typing import Dict, List, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor


# --- 1. Model definition (must match the server's architecture exactly) ---
class Net(nn.Module):
    """LeNet-style CNN producing logits for 10 classes.

    Layer names, shapes and declaration order must be identical to the server's
    ``Net``. Parameters are exchanged as a positional list of arrays ordered by
    ``state_dict()``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits of shape ``(batch, 10)``.

        Expects ``x`` of shape ``(batch, 3, 32, 32)``. Two conv(5x5) + pool(2)
        stages reduce 32x32 to 5x5 with 16 channels, which is what ``fc1``
        consumes.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten keeps the batch dimension intact. The previous
        # view(-1, 16 * 5 * 5) would silently fold a wrongly sized input into
        # extra "batch" rows instead of raising.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# --- 2. Data loading and partitioning ---
def load_data(
    client_id: int,
    samples_per_client: int = 5000,
    batch_size: int = 32,
    data_root: str = "./data",
) -> Tuple[DataLoader, DataLoader]:
    """Load CIFAR-10 and return this client's training shard and the test set.

    The training split is cut into contiguous, non-overlapping shards of
    ``samples_per_client`` images. Client ``k`` gets indices
    ``[k * samples_per_client, (k + 1) * samples_per_client)``. This simulates
    per-client data. In a real deployment each client would own its own
    local dataset.

    Args:
        client_id: Zero-based index of this client's shard.
        samples_per_client: Number of training images per shard.
        batch_size: Batch size for both loaders.
        data_root: Directory where CIFAR-10 is downloaded/cached.

    Returns:
        ``(trainloader, testloader)``. The train loader is shuffled. The test
        loader covers the full CIFAR-10 test split.

    Raises:
        ValueError: If ``client_id`` is negative, ``samples_per_client`` is not
            positive, or the requested shard extends past the training set.
    """
    if client_id < 0:
        raise ValueError(f"client_id must be >= 0, got {client_id}")
    if samples_per_client <= 0:
        raise ValueError(
            f"samples_per_client must be positive, got {samples_per_client}"
        )

    # Map pixels from [0, 1] to [-1, 1] per channel.
    transform = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = CIFAR10(data_root, train=True, download=True, transform=transform)

    start = client_id * samples_per_client
    stop = start + samples_per_client
    # Fail here with a clear message. An out-of-range Subset would otherwise
    # only raise IndexError midway through the first training epoch.
    if stop > len(trainset):
        raise ValueError(
            f"client_id={client_id} needs training indices [{start}, {stop}) "
            f"but the training set has only {len(trainset)} samples"
        )
    train_subset = Subset(trainset, list(range(start, stop)))
    trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    # The test split is used only for local evaluation.
    testset = CIFAR10(data_root, train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size)

    return trainloader, testloader


# --- 3. Flower client implementation ---
class CifarClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates ``Net`` on local CIFAR-10 data."""

    def __init__(self, cid: int) -> None:
        self.cid = cid
        # CPU is forced so the client also runs on machines without a CUDA GPU.
        self.device = torch.device("cpu")
        # Keep the model on the same device the batches are moved to, so that
        # changing ``self.device`` cannot cause a device mismatch.
        self.net = Net().to(self.device)
        self.trainloader, self.testloader = load_data(cid)

    def get_parameters(self, config: Dict[str, fl.common.Scalar]) -> List[np.ndarray]:
        """Return the local model's tensors as NumPy arrays, in ``state_dict`` order."""
        return [tensor.cpu().numpy() for tensor in self.net.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load the server's global weights into the local model.

        Arrays are matched to ``state_dict`` entries by position. Each is
        converted to the existing tensor's dtype and shape, rather than
        reinterpreting raw bytes as float32.

        Raises:
            ValueError: If the number of arrays differs from the number of
                model tensors.
        """
        current = self.net.state_dict()
        if len(parameters) != len(current):
            # zip() would silently drop the surplus; report the mismatch instead.
            raise ValueError(
                f"expected {len(current)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (
                name,
                # torch.tensor copies, so read-only arrays from the wire are fine.
                torch.tensor(np.asarray(array), dtype=reference.dtype).reshape(
                    reference.shape
                ),
            )
            for (name, reference), array in zip(current.items(), parameters)
        )
        self.net.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, fl.common.Scalar]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """Train locally starting from the global weights.

        ``config["local_epochs"]`` (optional, default 1) sets the number of
        passes over the local shard.

        Returns:
            ``(updated parameters, number of local training examples, {})``.
        """
        self.set_parameters(parameters)
        local_epochs = int(config.get("local_epochs", 1))

        criterion = nn.CrossEntropyLoss()
        # A fresh optimizer each round: momentum state is not carried between rounds.
        optimizer = torch.optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9)
        self.net.train()

        for _ in range(local_epochs):
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.net(images), labels)
                loss.backward()
                optimizer.step()

        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, fl.common.Scalar]
    ) -> Tuple[float, int, Dict]:
        """Evaluate the global weights on the local test set.

        Returns:
            ``(mean cross-entropy per example, number of test examples,
            {"accuracy": fraction correct})``. The loss is averaged per example
            rather than summed over batches. This keeps it independent of the
            batch count and comparable across clients when weighted by
            ``num_examples``.
        """
        self.set_parameters(parameters)

        # Sum per-example losses so the final division yields a true mean.
        criterion = nn.CrossEntropyLoss(reduction="sum")
        self.net.eval()
        total_loss, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.net(images)
                total_loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                seen += labels.size(0)

        if seen == 0:
            # Avoid division by zero. With zero examples this client carries no
            # weight in an example-weighted aggregate.
            return 0.0, 0, {"accuracy": 0.0}
        return total_loss / seen, seen, {"accuracy": correct / seen}


# --- 4. Start the client ---
if __name__ == "__main__":
    # LAN address of the machine running the Flower server. Both values can be
    # overridden without editing the file via FL_SERVER_ADDRESS / FL_CLIENT_ID.
    MAC_IP_ADDRESS = "192.168.2.86"  # e.g., "192.168.1.100"
    server_address = os.environ.get("FL_SERVER_ADDRESS", f"{MAC_IP_ADDRESS}:8080")
    client_id = int(os.environ.get("FL_CLIENT_ID", "0"))

    fl.client.start_client(
        server_address=server_address,
        client=CifarClient(cid=client_id).to_client(),
    )