"""Train a WideResNet on CIFAR-10, optionally with DP-SGD via Opacus.

A random subset of ``m`` "target" training points is drawn; each target is
independently included in or excluded from the training set.  The inclusion
mask and the target mask are produced alongside the data loaders so they can
be inspected by the caller.
"""

import argparse
import equations
import numpy as np
import time
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset, TensorDataset
import torch.nn.functional as F
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
import warnings

warnings.filterwarnings("ignore")

# Set once by main(); read by the training/evaluation loops.
DEVICE = None


def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    """Build the CIFAR-10 train/test loaders and the target/inclusion masks.

    Args:
        m: Number of target points whose membership is randomised.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.

    Returns:
        A tuple ``(train_dl, test_dl, x_in, x_m, s)`` where

        * ``train_dl`` iterates over the included training points,
        * ``test_dl`` iterates over the CIFAR-10 test split,
        * ``x_in`` is the array of transformed training images that were
          kept (the same data ``train_dl`` serves),
        * ``x_m`` is a boolean mask over the full training set marking the
          ``m`` target points,
        * ``s`` is a boolean mask over the full training set marking which
          points were included for training.
    """
    # Fresh seed on every call: random draw mixed with the wall clock.
    seed = np.random.randint(0, 10**9)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # Reflect-pad by 4 pixels per side before the random 32x32 crop.
        transforms.Lambda(
            lambda x: F.pad(
                x.unsqueeze(0), (4, 4, 4, 4), mode='reflect'
            ).squeeze()
        ),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    datadir = Path("./data")
    train_ds = CIFAR10(
        root=datadir, train=True, download=True, transform=train_transform
    )
    test_ds = CIFAR10(
        root=datadir, train=False, download=True, transform=test_transform
    )

    # Vector determining if each point is in or out: the first m slots get a
    # random in/out decision, every other slot is always in.
    S = np.full(len(train_ds), True)
    S[:m] = np.random.choice([True, False], size=m)
    # The permutation scatters those m slots across the dataset indices.
    p = np.random.permutation(len(train_ds))

    # Indexing the dataset applies the transforms, so the augmentation is
    # drawn once here and then frozen inside the TensorDataset below.
    # np.stack needs a real sequence; a bare generator is rejected by
    # current NumPy releases.
    x_in = np.stack(
        [train_ds[i][0].numpy() for i in range(len(train_ds))]
    )
    x_in = x_in[S[p]]
    y_in = np.array(train_ds.targets).astype(np.int64)
    y_in = y_in[S[p]]

    # These are the points being guessed at.
    x_m = np.full(len(train_ds), False)
    x_m[:m] = True
    x_m = x_m[p]

    td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
    train_dl = DataLoader(
        td, batch_size=train_batch_size, shuffle=True, num_workers=4
    )
    test_dl = DataLoader(
        test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4
    )

    return train_dl, test_dl, x_in, x_m, S[p]


def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion,
                 scheduler):
    """Run the training loop and periodically evaluate on the test set.

    The model is evaluated every 10th epoch (starting with the first) and
    after the final epoch; each evaluation is printed.

    Args:
        model: Network whose forward pass returns an indexable result; the
            first element is used as the logits.
        hp: Hyperparameter dict; only ``hp['epochs']`` is read here.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` test batches.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss function called as ``criterion(logits, labels)``.
        scheduler: Learning-rate scheduler stepped once per epoch.

    Returns:
        The best test accuracy observed across the evaluated epochs, as a
        percentage rounded to two decimals (0 if no evaluation ran).
    """
    best_test_set_accuracy = 0
    n_epochs = hp['epochs']

    for epoch in range(n_epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 10 == 0 or epoch == n_epochs - 1:
            correct = 0
            total = 0
            model.eval()
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(DEVICE)
                    labels = labels.to(DEVICE)

                    outputs = model(images)[0]
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            epoch_accuracy = round(100 * (correct / total), 2)
            # Track the best score so the return value is meaningful.
            best_test_set_accuracy = max(
                best_test_set_accuracy, epoch_accuracy
            )
            print(f"Epoch {epoch+1}/{n_epochs}: {epoch_accuracy}%")

    return best_test_set_accuracy


def train(hp):
    """Build the model and data, then train with or without DP-SGD.

    Args:
        hp: Hyperparameter dict.  Keys read here: ``wrn_depth``,
            ``wrn_width``, ``epochs``, ``target_points``, ``batch_size``,
            ``epsilon``, ``delta`` and ``norm``.  When ``epsilon`` is
            ``None`` training runs without differential privacy.

    Returns:
        The trained model (the Opacus-wrapped module when DP is enabled).
    """
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    model = model.to(DEVICE)
    # Let Opacus replace unsupported layers, then verify nothing is left.
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4,
    )
    # LR is multiplied by 0.2 at 30%, 60% and 80% of the epoch budget.
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2,
    )

    train_dl, test_dl, x_in, x_m, S = get_dataloaders(
        hp['target_points'], hp['batch_size']
    )
    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
    print(f"Got x_in: {x_in.shape}")
    print(f"Got x_m: {x_m.shape}, sum={np.sum(x_m)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
    print(f"Got train dataloader: {len(train_dl)}")
    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        # NOTE(review): `scheduler` was bound to the optimizer object created
        # before this wrapping step -- confirm the LR schedule still reaches
        # the wrapped optimizer.
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model


def main():
    """Parse the command line, pick a device, train, and save the weights."""
    global DEVICE

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    parser.add_argument('--m', type=int, help='number of target points', required=True)
    args = parser.parse_args()

    # Compare against None explicitly so that `--cuda 0` counts as a choice
    # rather than as "not provided".
    if torch.cuda.is_available() and args.cuda is not None:
        DEVICE = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
    else:
        DEVICE = torch.device('cpu')

    hyperparams = {
        "target_points": args.m,
        "wrn_depth": 16,
        "wrn_width": 1,
        "epsilon": args.epsilon,
        "delta": 1e-5,
        "norm": args.norm,
        "batch_size": 4096,
        "epochs": 200,
    }

    # Run-specific file name: timestamp plus the main hyperparameters.
    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        hyperparams['wrn_depth'],
        hyperparams['wrn_width'],
        hyperparams['batch_size'],
        hyperparams['epochs'],
        hyperparams['epsilon'],
        hyperparams['delta'],
        hyperparams['norm'],
    ))

    model = train(hyperparams)
    # Weights are stored next to the log name, with a .pt suffix instead.
    torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))


if __name__ == '__main__':
    main()