import torch
from flamby.utils import evaluate_model_on_tests

# 2 lines of code to change to switch to another dataset
from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    metric,
    NUM_CLIENTS,
    Optimizer,
)
from flamby.datasets.fed_tcga_brca import FedTcgaBrca as FedDataset
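Import `torch`, the evaluation helper `evaluate_model_on_tests`, and the dataset-specific objects: the constants (batch size, learning rate, number of pooled epochs, number of clients), the baseline model, the loss, the metric, the optimizer, and the dataset class.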
# Instantiation of the local train set (and data loader), baseline loss function,
# baseline model and default optimizer.
# Only center 0's training split is loaded: this is local training on a single client.
train_dataset = FedDataset(center=0, train=True, pooled=False)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
lossfunc = BaselineLoss()
model = Baseline()
optimizer = Optimizer(model.parameters(), lr=LR)
In this script, `pooled=False` is passed when the `FedDataset` instance is created, so the data is not pooled: each client (center) keeps its own local dataset. Here only the training split of center 0 is used, which means the model is trained on a single client's data. This is the usual setup in federated learning. It simulates real-world scenarios in which data is spread across different locations or devices.
# Traditional PyTorch training loop: NUM_EPOCHS_POOLED passes over the local
# training loader, one optimizer step per mini-batch.
model.train()  # explicit training mode (matters if the baseline uses dropout/batch-norm)
for _ in range(NUM_EPOCHS_POOLED):
    for X, y in train_dataloader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(X)
        loss = lossfunc(outputs, y)
        loss.backward()
        optimizer.step()
A standard PyTorch training loop (正常的训练流程).
# Evaluation
# Instantiation of a list of the local test sets, one per client (center)
test_dataloaders = [
    torch.utils.data.DataLoader(
        FedDataset(center=i, train=False, pooled=False),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )
    for i in range(NUM_CLIENTS)
]
# Evaluate the locally trained model on every client's test set with `metric`
# and print whatever evaluate_model_on_tests returns (presumably a dict of
# per-test-set metric values, going by the variable name — verify).
dict_cindex = evaluate_model_on_tests(model, test_dataloaders, metric)
print(dict_cindex)
import torch
from flamby.utils import evaluate_model_on_tests

# 2 lines of code to change to switch to another dataset
from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    metric,
    NUM_CLIENTS,
    get_nb_max_rounds,
)
from flamby.datasets.fed_tcga_brca import FedTcgaBrca as FedDataset

# 1st line of code to change to switch to another strategy
from flamby.strategies.fed_avg import FedAvg as strat
Use `FedAvg` (federated averaging) as the strategy.
# We loop on all the clients of the distributed dataset and instantiate the
# associated (shuffled) training data loaders, one per client.
train_dataloaders = [
    torch.utils.data.DataLoader(
        FedDataset(center=i, train=True, pooled=False),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )
    for i in range(NUM_CLIENTS)
]
# Baseline loss and initial global model handed to the strategy.
# (Both were missing from the federated script, so `m` and `lossfunc`
# below raised NameError.)
lossfunc = BaselineLoss()
m = Baseline()

# Number of local updates per round. It is used both by the strategy and by
# get_nb_max_rounds, so the two values must stay identical.
num_updates = 100

# Federated Learning loop
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs)
args = {
    "training_dataloaders": train_dataloaders,
    "model": m,
    "loss": lossfunc,
    "optimizer_class": torch.optim.SGD,
    "learning_rate": LR / 10.0,
    "num_updates": num_updates,
    # This helper function returns the number of rounds necessary to perform approximately as many
    # epochs on each local dataset as with the pooled training
    "nrounds": get_nb_max_rounds(num_updates),
}
s = strat(**args)
# Keep the first element of what run() returns as the trained global model.
m = s.run()[0]
# Evaluation
# We only instantiate one test set in this particular case: the pooled one
test_dataloaders = [
    torch.utils.data.DataLoader(
        FedDataset(train=False, pooled=True),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )
]
# Evaluate the model `m` produced by the federated strategy with `metric`
# and print the result.
dict_cindex = evaluate_model_on_tests(m, test_dataloaders, metric)
print(dict_cindex)