/ server.py
server.py
 1  import flwr as fl
 2  import torch
 3  import os
 4  from flwr.common import parameters_to_ndarrays
 5  import torch.nn as nn
 6  import torch.nn.functional as F
 7  from collections import OrderedDict
 8  from typing import Dict, List, Tuple
 9  import numpy as np
10  import time
11  
12  # Simple CNN Model
class Net(nn.Module):
    """Small LeNet-style CNN for 3-channel 32x32 images.

    Two 5x5 conv + 2x2 max-pool stages reduce a 3x32x32 input to 16x5x5 feature
    maps, which three fully connected layers map to class logits.

    The attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``, ``fc3``) define
    the ``state_dict`` keys whose order is used to map Flower parameter arrays
    back onto layers, so they must not be renamed.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(N, 3, 32, 32)`` to logits ``(N, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))  # (N, 6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 16, 5, 5)
        # Flatten per sample so the batch dimension is preserved. With
        # view(-1, 400) a wrongly sized input could be silently re-batched;
        # here it fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
31  
def get_initial_parameters() -> List[np.ndarray]:
    """Return freshly initialised ``Net`` weights as a list of NumPy arrays.

    The arrays follow ``Net().state_dict()`` order. That is the order in which
    the server-side evaluate function zips arrays back onto layer names;
    presumably the clients use the same order (confirm against client code).

    Real ndarrays are returned, not raw ``bytes``, because
    ``fl.common.ndarrays_to_parameters`` serialises each item with ``np.save``.
    A ``bytes`` object would become a 0-d byte-string array that has lost its
    shape and dtype and cannot be averaged by FedAvg. Code that decodes arrays
    via ``np.frombuffer(arr.tobytes(), dtype=np.float32)`` keeps working,
    because a float32 array's ``tobytes()`` is exactly its raw buffer.
    """
    net = Net()
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]
36  
def get_eval_fn(net: torch.nn.Module, save_path: str, num_rounds: int = 10):
    """Return a Flower ``evaluate_fn`` that loads the global weights into ``net``
    and checkpoints them after the final round.

    Args:
        net: Model whose ``state_dict()`` order and shapes match the arrays
            passed to the returned function. It is updated in place on every
            call.
        save_path: Directory for the final checkpoint; created immediately if
            it does not exist.
        num_rounds: Round number after which the checkpoint is written. Keep it
            equal to ``ServerConfig(num_rounds=...)``. Defaults to 10.

    Returns:
        ``evaluate(server_round, parameters, config)``, which returns the
        placeholder ``(0.0, {})`` because no central test set is used yet.
    """
    os.makedirs(save_path, exist_ok=True)

    def _as_tensor(array: np.ndarray, reference: torch.Tensor) -> torch.Tensor:
        """Convert one received array to a tensor with ``reference``'s shape and dtype."""
        array = np.asarray(array)
        if array.dtype.kind == "S":
            # A raw byte blob (a tensor serialised with ``tobytes()``):
            # reinterpret the buffer using the reference tensor's dtype.
            np_dtype = reference.detach().cpu().numpy().dtype
            array = np.frombuffer(array.tobytes(), dtype=np_dtype).copy()
        # Numeric arrays are converted by value, so e.g. float64 arrays are
        # cast rather than misread byte-for-byte as float32.
        return torch.tensor(array, dtype=reference.dtype).reshape(reference.shape)

    def evaluate(
        server_round: int, parameters: List[np.ndarray], config: dict
    ) -> Tuple[float, Dict]:
        """Load ``parameters`` (in ``net.state_dict()`` order) into ``net``."""
        reference = net.state_dict()  # built once per call, not once per layer
        if len(parameters) != len(reference):
            # zip() would silently drop surplus arrays, so check explicitly.
            raise ValueError(
                f"Expected {len(reference)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (name, _as_tensor(array, reference[name]))
            for name, array in zip(reference.keys(), parameters)
        )
        net.load_state_dict(state_dict, strict=True)

        # Save the model after the final round
        if server_round == num_rounds:
            save_file = os.path.join(
                save_path, f"global_model_final_round_{server_round}_pi.pt"
            )
            torch.save(net.state_dict(), save_file)
            print(f"✅ FINAL GLOBAL MODEL SAVED TO: {save_file}")

        # No central test set yet: report a placeholder loss and no metrics.
        return 0.0, {}

    return evaluate
65  
initial_net = Net()

# All minimums are 1 so a single connected client suffices for training AND
# federated evaluation. FedAvg's min_evaluate_clients defaults to 2; with only
# one client, evaluation sampling would fail and evaluation would be skipped.
strategy = fl.server.strategy.FedAvg(
    min_available_clients=1,
    min_fit_clients=1,
    min_evaluate_clients=1,
    initial_parameters=fl.common.ndarrays_to_parameters(get_initial_parameters()),
    evaluate_fn=get_eval_fn(initial_net, save_path="./checkpoints"),
)

print("Starting Flower Server...")
# perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
start_time = time.perf_counter()

fl.server.start_server(
    server_address="0.0.0.0:8080",
    # get_eval_fn writes the final checkpoint when server_round reaches 10;
    # keep the two values in sync.
    config=fl.server.ServerConfig(num_rounds=10),
    strategy=strategy,
)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"\n⏱️  Total training time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")