From 7208c16efce07c410a37a8f3cad8b7ba8f1642f3 Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Sun, 1 Dec 2024 14:49:13 -0700
Subject: [PATCH] Wres: epsilon in args

---
 wresnet-pytorch/src/train.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py
index a0db1c7..cff770f 100644
--- a/wresnet-pytorch/src/train.py
+++ b/wresnet-pytorch/src/train.py
@@ -12,6 +12,8 @@ from tqdm import tqdm
 import opacus
 from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
+import warnings
+warnings.filterwarnings("ignore")
 
 
 def set_seed(seed=42):
@@ -76,18 +78,15 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
     return best_test_set_accuracy
 
 
-def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0):
+def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
     train_loader, test_loader = loaders
-    dp_epsilon = None
     dp_delta = 1e-5
     checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
 
-    if dp_epsilon is not None:
-        print(f"DP epsilon = {dp_epsilon}, delta = {dp_delta}")
-        #net = ModuleValidator.fix(net, replace_bn_with_in=True)
-        net = ModuleValidator.fix(net)
-        ModuleValidator.validate(net, strict=True)
+    #net = ModuleValidator.fix(net, replace_bn_with_in=True)
+    net = ModuleValidator.fix(net)
+    ModuleValidator.validate(net, strict=True)
 
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
 
@@ -105,7 +104,8 @@ def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200
             max_grad_norm=norm,
         )
 
-        print(f"Using sigma={optimizer.noise_multiplier} and C={1.0}, norm = {norm}")
+        print(f"DP epsilon = {dp_epsilon}, delta = {dp_delta}")
+        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {norm}")
 
     else:
         print("Training without differential privacy")
@@ -166,7 +166,7 @@
     net = net.to(device)
 
     epochs = training_configurations.epochs
-    best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm)
+    best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)
 
     if log:
         with open(logfile, 'a') as temp:
@@ -192,6 +192,7 @@
     parser.add_argument('-config', '--config', help='Training Configurations', required=True)
     parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
     parser.add_argument('--cuda', type=int, help='gpu index', required=False)
+    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
 
     args = parser.parse_args()