# server.py
"""Flower federated-learning server for a small CNN.

Starts a FedAvg server, lets Flower call a server-side evaluation hook after
every round, and uses that hook to write the final global model to disk.
"""

import os
import time
from collections import OrderedDict
from typing import Dict, List, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.common import parameters_to_ndarrays

# Number of federated rounds. The global model is saved after this round, so
# the server config and the checkpoint trigger must share the same value.
NUM_ROUNDS = 10


# Simple CNN Model
class Net(nn.Module):
    """Two convolution + pooling stages followed by three linear layers.

    Takes 3-channel images and produces 10 output scores. The flatten size of
    ``16 * 5 * 5`` is what two (5x5 conv, 2x2 pool) stages yield for 32x32
    inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 10 unnormalised output scores for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_initial_parameters() -> List[np.ndarray]:
    """Return the weights of a freshly initialised ``Net`` as NumPy arrays.

    The arrays follow ``state_dict()`` order. They are returned as real
    ndarrays (not raw ``bytes``) so that ``ndarrays_to_parameters`` can
    serialise each one together with its shape and dtype.
    """
    net = Net()
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def _ndarray_to_tensor(value, reference: torch.Tensor) -> torch.Tensor:
    """Convert one received parameter into a tensor shaped like *reference*.

    Accepts a regular NumPy array as well as a raw byte buffer (a ``bytes``
    object or a NumPy byte-string scalar), which is reinterpreted using the
    dtype of *reference*.
    """
    target_dtype = reference.detach().cpu().numpy().dtype
    array = np.asarray(value)
    if array.dtype.kind in "SV":
        # Raw buffer without shape/dtype information: reinterpret the bytes.
        array = np.frombuffer(array.tobytes(), dtype=target_dtype)
    else:
        array = array.astype(target_dtype, copy=False)
    # torch.tensor copies, so a read-only frombuffer view is safe here.
    return torch.tensor(array.reshape(tuple(reference.shape)))


def get_eval_fn(net: torch.nn.Module, save_path: str, final_round: int = NUM_ROUNDS):
    """Return a server-side evaluation hook that saves the global model.

    Args:
        net: Model instance whose ``state_dict`` supplies the parameter names,
            shapes and dtypes. It is overwritten with the global weights on
            every call of the returned function.
        save_path: Directory for the checkpoint; created if it does not exist.
        final_round: Round number after which the model is written to disk.

    Returns:
        A function with the ``(server_round, parameters, config)`` signature
        that is passed to the strategy as ``evaluate_fn``.
    """
    os.makedirs(save_path, exist_ok=True)

    def evaluate(
        server_round: int, parameters: List[np.ndarray], config: dict
    ) -> Tuple[float, Dict]:
        """Load *parameters* into ``net`` and checkpoint it on the final round.

        Raises:
            ValueError: If the number of received arrays differs from the
                number of entries in the model's ``state_dict``.
        """
        # Fetch the reference state once instead of once per parameter.
        reference_state = net.state_dict()
        if len(parameters) != len(reference_state):
            raise ValueError(
                f"Expected {len(reference_state)} parameter arrays, "
                f"got {len(parameters)}"
            )

        # Convert parameters to PyTorch state dict
        state_dict = OrderedDict(
            (key, _ndarray_to_tensor(value, reference_state[key]))
            for key, value in zip(reference_state.keys(), parameters)
        )
        net.load_state_dict(state_dict, strict=True)

        # Save the model after the final round
        if server_round == final_round:
            save_file = os.path.join(
                save_path, f"global_model_final_round_{server_round}_pi.pt"
            )
            torch.save(net.state_dict(), save_file)
            print(f"✅ FINAL GLOBAL MODEL SAVED TO: {save_file}")

        # You can also run central evaluation here if you have a test set
        # For now, we return empty results
        return 0.0, {}  # Return loss and metrics

    return evaluate


initial_net = Net()

strategy = fl.server.strategy.FedAvg(
    min_available_clients=1,
    min_fit_clients=1,
    initial_parameters=fl.common.ndarrays_to_parameters(get_initial_parameters()),
    evaluate_fn=get_eval_fn(initial_net, save_path="./checkpoints"),
)

print("Starting Flower Server...")
start_time = time.time()

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"\n⏱️ Total training time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")