# Train a WideResNet on CIFAR-10, optionally with DP-SGD via Opacus, for the O1 audit.
import argparse
import equations
import numpy as np
import time
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
import warnings

warnings.filterwarnings("ignore")

DEVICE = torch.device("cpu")


def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    """Return CIFAR-10 loaders: the full train set, the train set with m points removed, and the test set."""
    seed = np.random.randint(0, int(1e9))
    seed ^= int(time.time())
    pl.seed_everything(seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    datadir = Path("./data")
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    # Mark m randomly chosen training points for exclusion from the pruned dataset.
    keep = np.full(len(train_ds), True)
    keep[:m] = False
    np.random.shuffle(keep)
    # Subset expects integer indices, not a boolean mask, so convert the mask first.
    train_ds_p = Subset(train_ds, np.where(keep)[0])

    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, train_dl_p, test_dl


def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
    best_test_set_accuracy = 0

    for epoch in range(hp['epochs']):
        model.train()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()

            # The WideResNet returns a tuple; the logits are the first element.
            wrn_outputs = model(inputs)
            outputs = wrn_outputs[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Evaluate on the test set every 20 epochs and on the final epoch.
        if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
            with torch.no_grad():
                correct = 0
                total = 0
                model.eval()
                for data in test_loader:
                    images, labels = data
                    images = images.to(DEVICE)
                    labels = labels.to(DEVICE)

                    wrn_outputs = model(images)
                    outputs = wrn_outputs[0]
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                epoch_accuracy = correct / total
                epoch_accuracy = round(100 * epoch_accuracy, 2)
                best_test_set_accuracy = max(best_test_set_accuracy, epoch_accuracy)
                print(f"Epoch {epoch + 1}/{hp['epochs']}: {epoch_accuracy}%")

    return best_test_set_accuracy


def train(hp):
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    # Replace modules that are incompatible with DP-SGD (e.g. BatchNorm), then validate.
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4,
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2,
    )

    train_dl, train_dl_p, test_dl = get_dataloaders()

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        # Calibrate the noise multiplier so the full run satisfies (epsilon, delta)-DP.
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        # Split logical batches into physical batches small enough to fit in memory.
        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=1000,  # Roughly 12 GB of VRAM, uses 9.4
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            # Train on the memory-safe, privately sampled loader.
            best_test_set_accuracy = train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model


def main():
    global DEVICE

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    args = parser.parse_args()

    # Use `is not None` so that --cuda 0 selects GPU 0 instead of falling through.
    if torch.cuda.is_available() and args.cuda is not None:
        DEVICE = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
    else:
        DEVICE = torch.device('cpu')

    hyperparams = {
        "wrn_depth": 16,
        "wrn_width": 1,
        "epsilon": args.epsilon,
        "delta": 1e-5,
        "norm": args.norm,
        "batch_size": 4096,
        "epochs": 200,
    }

    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        hyperparams['wrn_depth'],
        hyperparams['wrn_width'],
        hyperparams['batch_size'],
        hyperparams['epochs'],
        hyperparams['epsilon'],
        hyperparams['delta'],
        hyperparams['norm'],
    ))

    model = train(hyperparams)
    torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))


if __name__ == '__main__':
    main()
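
# Example invocation (the script filename below is assumed; substitute the actual one):
#   python wideresnet_dp_train.py --norm 1.0 --epsilon 8.0 --cuda 0
# Omitting --epsilon trains the same WideResNet without differential privacy.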