From e4b5998dbb0a08527fd4a02afc73e4b6397abc07 Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Sat, 30 Nov 2024 20:28:25 -0700 Subject: [PATCH] Wres: dp training --- wresnet-pytorch/src/train.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py index 0977f72..2d591ca 100644 --- a/wresnet-pytorch/src/train.py +++ b/wresnet-pytorch/src/train.py @@ -8,6 +8,8 @@ import random from utils import json_file_to_pyobj, get_loaders from WideResNet import WideResNet from tqdm import tqdm +import opacus +from opacus.validators import ModuleValidator def set_seed(seed=42): @@ -21,16 +23,39 @@ def set_seed(seed=42): def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile='', epochs=200): train_loader, test_loader = loaders + dp_epsilon = 8 + if dp_epsilon is not None: + print(f"DP epsilon: {dp_epsilon}") + #net = ModuleValidator.fix(net, replace_bn_with_in=True) + net = ModuleValidator.fix(net) + print(net) + ModuleValidator.validate(net, strict=True) + criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) best_test_set_accuracy = 0 + privacy_engine = opacus.PrivacyEngine() + net, optimizer, train_loader = privacy_engine.make_private_with_epsilon( + module=net, + optimizer=optimizer, + data_loader=train_loader, + epochs=epochs, + target_epsilon=8, + target_delta=1e-5, + max_grad_norm=3.0, + ) + + print(f"Using sigma={optimizer.noise_multiplier} and C={1.0}") + print(f"Training with {epochs} epochs") - for epoch in tqdm(range(epochs)): + #for epoch in tqdm(range(epochs)): + for epoch in range(epochs): net.train() - for i, data in tqdm(enumerate(train_loader, 0), leave=False): + #for i, data in tqdm(enumerate(train_loader, 0), leave=False): + for i, data in enumerate(train_loader, 0): inputs, labels = data inputs = inputs.to(device) labels = labels.to(device) @@ -66,6 +91,7 @@ def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logf epoch_accuracy = round(100 * epoch_accuracy, 2) if log: + print('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy)) with open(logfile, 'a') as temp: temp.write('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy)) @@ -85,7 +111,7 @@ def train(args): wrn_width = training_configurations.wrn_width dataset = training_configurations.dataset.lower() #seeds = [int(seed) for seed in training_configurations.seeds] - seeds = [int.from_bytes(os.urandom(8), byteorder='big')] + seeds = [int.from_bytes(os.urandom(4), byteorder='big')] log = True if training_configurations.log.lower() == 'true' else False if log: