From 0d67830f7e39e66e3ad93eba49bbc3a924dabae1 Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Mon, 2 Dec 2024 23:48:50 -0700
Subject: [PATCH] O1: add training code

---
 one_run_audit/WideResNet.py | 165 ++++++++++++++++++++++++++
 one_run_audit/audit.py      | 245 +++++++++++++++++++++++++++++++++++++++
 one_run_audit/equations.py  |  53 +++++++++
 one_run_audit/plot.py       |  21 ++++
 4 files changed, 484 insertions(+)
 create mode 100644 one_run_audit/WideResNet.py
 create mode 100644 one_run_audit/audit.py
 create mode 100644 one_run_audit/equations.py
 create mode 100644 one_run_audit/plot.py

diff --git a/one_run_audit/WideResNet.py b/one_run_audit/WideResNet.py
new file mode 100644
index 0000000..024818e
--- /dev/null
+++ b/one_run_audit/WideResNet.py
@@ -0,0 +1,165 @@
+"""Wide Residual Network (WRN-d-k) built from pre-activation residual blocks."""
+import math
+
+import torch
+import torch.nn as nn
+
+
+class IndividualBlock1(nn.Module):
+    """First residual block of a group; may change channel count and stride.
+
+    The skip path is a 1x1 convolution when ``subsample_input`` or
+    ``increase_filters`` is set, and the identity otherwise.
+    """
+
+    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
+        super().__init__()
+
+        self.activation = nn.ReLU(inplace=True)
+
+        self.batch_norm1 = nn.BatchNorm2d(input_features)
+        self.batch_norm2 = nn.BatchNorm2d(output_features)
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.subsample_input = subsample_input
+        self.increase_filters = increase_filters
+        if subsample_input or increase_filters:
+            # The projection shares conv1's stride so both branches always
+            # produce the same spatial size (it used to be hard-coded).
+            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)
+
+    def forward(self, x):
+        """Return ``skip(x) + residual(x)``."""
+        if self.subsample_input or self.increase_filters:
+            # The projection skip consumes the normalised, activated input.
+            x = self.batch_norm1(x)
+            x = self.activation(x)
+            x1 = self.conv1(x)
+        else:
+            # Identity skip: leave x untouched for the final addition.
+            x1 = self.batch_norm1(x)
+            x1 = self.activation(x1)
+            x1 = self.conv1(x1)
+        x1 = self.batch_norm2(x1)
+        x1 = self.activation(x1)
+        x1 = self.conv2(x1)
+
+        if self.subsample_input or self.increase_filters:
+            return self.conv_inp(x) + x1
+        return x + x1
+
+
+class IndividualBlockN(nn.Module):
+    """Residual block with an identity skip (BN-ReLU-conv applied twice).
+
+    The skip is a plain addition, so input and output shapes must match.
+    """
+
+    def __init__(self, input_features, output_features, stride):
+        super().__init__()
+
+        self.activation = nn.ReLU(inplace=True)
+
+        self.batch_norm1 = nn.BatchNorm2d(input_features)
+        self.batch_norm2 = nn.BatchNorm2d(output_features)
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+
+    def forward(self, x):
+        """Return ``x + residual(x)``."""
+        x1 = self.batch_norm1(x)
+        x1 = self.activation(x1)
+        x1 = self.conv1(x1)
+        x1 = self.batch_norm2(x1)
+        x1 = self.activation(x1)
+        x1 = self.conv2(x1)
+
+        return x1 + x
+
+
+class Nblock(nn.Module):
+    """Group of N blocks: one IndividualBlock1, then N - 1 IndividualBlockN."""
+
+    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
+        super().__init__()
+
+        layers = []
+        for i in range(N):
+            if i == 0:
+                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
+            else:
+                layers.append(IndividualBlockN(output_features, output_features, stride=1))
+
+        self.nblockLayer = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.nblockLayer(x)
+
+
+class WideResNet(nn.Module):
+    """WRN-d-k classifier: a stem convolution, three block groups and a linear head.
+
+    ``forward`` returns ``(logits, block1_out, block2_out, block3_out)``.
+    """
+
+    def __init__(self, d, k, n_classes, input_features, output_features, strides):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
+
+        filters = [16 * k, 32 * k, 64 * k]
+        self.out_filters = filters[-1]
+        N = (d - 4) // 6
+        # NOTE: assumes output_features == 16, so group 1 only widens when k > 1.
+        increase_filters = k > 1
+        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
+        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
+        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
+
+        self.batch_norm = nn.BatchNorm2d(filters[-1])
+        self.activation = nn.ReLU(inplace=True)
+        self.avg_pool = nn.AvgPool2d(kernel_size=8)
+        self.fc = nn.Linear(filters[-1], n_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        attention1 = self.block1(x)
+        attention2 = self.block2(attention1)
+        attention3 = self.block3(attention2)
+        out = self.batch_norm(attention3)
+        out = self.activation(out)
+        out = self.avg_pool(out)
+        out = out.view(-1, self.out_filters)
+
+        return self.fc(out), attention1, attention2, attention3
+
+
+if __name__ == '__main__':
+    # Only the demo needs torchsummary, so importing this module does not.
+    from torchsummary import summary
+
+    # change d and k if you want to check a model other than WRN-40-2
+    d = 40
+    k = 2
+    strides = [1, 1, 2, 2]
+    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
+
+    # verify that an output is produced
+    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
+    net(sample_input)
+
+    # Summarize model
+    summary(net, input_size=(3, 32, 32))
diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py
new file mode 100644
index 0000000..a619025
--- /dev/null
+++ b/one_run_audit/audit.py
@@ -0,0 +1,245 @@
+"""Train a WideResNet on CIFAR-10, optionally with DP-SGD through Opacus."""
+import argparse
+import equations
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+from torch import optim
+from torch.optim.lr_scheduler import MultiStepLR
+from torch.utils.data import DataLoader, Subset
+import torch.nn.functional as F
+from pathlib import Path
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+import pytorch_lightning as pl
+import opacus
+from opacus.validators import ModuleValidator
+from opacus.utils.batch_memory_manager import BatchMemoryManager
+from WideResNet import WideResNet
+import warnings
+warnings.filterwarnings("ignore")
+
+
+DEVICE = torch.device("cpu")
+
+
+def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
+    """Build the CIFAR-10 data loaders.
+
+    Returns ``(train_dl, train_dl_p, test_dl)``: the full training set, the
+    training set with ``m`` randomly chosen examples held out, and the test set.
+    """
+    seed = np.random.randint(0, 10**9)
+    seed ^= int(time.time())
+    pl.seed_everything(seed)
+
+    train_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
+                                          (4, 4, 4, 4), mode='reflect').squeeze(0)),
+        transforms.ToPILImage(),
+        transforms.RandomCrop(32),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    datadir = Path("./data")
+    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+
+    keep = np.full(len(train_ds), True)
+    keep[:m] = False
+    np.random.shuffle(keep)
+
+    # Subset expects integer indices; a boolean mask would be read as 0/1.
+    train_ds_p = Subset(train_ds, np.flatnonzero(keep).tolist())
+    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
+
+    return train_dl, train_dl_p, test_dl
+
+
+def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
+    """Train for ``hp['epochs']`` epochs and return the best test accuracy seen.
+
+    Accuracy (in percent) is measured every 20 epochs and after the last one.
+    """
+    best_test_set_accuracy = 0
+
+    for epoch in range(hp['epochs']):
+        model.train()
+        for inputs, labels in train_loader:
+            inputs = inputs.to(DEVICE)
+            labels = labels.to(DEVICE)
+
+            optimizer.zero_grad()
+
+            wrn_outputs = model(inputs)
+            outputs = wrn_outputs[0]
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+        scheduler.step()
+
+        if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
+            with torch.no_grad():
+                correct = 0
+                total = 0
+
+                model.eval()
+                for images, labels in test_loader:
+                    images = images.to(DEVICE)
+                    labels = labels.to(DEVICE)
+
+                    wrn_outputs = model(images)
+                    outputs = wrn_outputs[0]
+                    _, predicted = torch.max(outputs.data, 1)
+                    total += labels.size(0)
+                    correct += (predicted == labels).sum().item()
+
+            epoch_accuracy = correct / total
+            epoch_accuracy = round(100 * epoch_accuracy, 2)
+            # Track the maximum; this used to stay at its initial value of 0.
+            best_test_set_accuracy = max(best_test_set_accuracy, epoch_accuracy)
+            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
+
+    return best_test_set_accuracy
+
+
+def train(hp):
+    """Build, train and return the model described by ``hp``.
+
+    DP-SGD through Opacus is used when ``hp['epsilon']`` is not None.
+    """
+    model = WideResNet(
+        d=hp["wrn_depth"],
+        k=hp["wrn_width"],
+        n_classes=10,
+        input_features=3,
+        output_features=16,
+        strides=[1, 1, 2, 2],
+    )
+    model = ModuleValidator.fix(model)
+    ModuleValidator.validate(model, strict=True)
+    model = model.to(DEVICE)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(),
+        lr=0.1,
+        momentum=0.9,
+        nesterov=True,
+        weight_decay=5e-4
+    )
+    scheduler = MultiStepLR(
+        optimizer,
+        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
+        gamma=0.2
+    )
+
+    # Honour the configured batch size; the loader default used to win.
+    train_dl, train_dl_p, test_dl = get_dataloaders(
+        train_batch_size=hp.get('batch_size', 128),
+    )
+
+    print(f"Training with {hp['epochs']} epochs")
+
+    if hp['epsilon'] is not None:
+        privacy_engine = opacus.PrivacyEngine()
+        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
+            module=model,
+            optimizer=optimizer,
+            data_loader=train_dl,
+            epochs=hp['epochs'],
+            target_epsilon=hp['epsilon'],
+            target_delta=hp['delta'],
+            max_grad_norm=hp['norm'],
+        )
+
+        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
+        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")
+
+        with BatchMemoryManager(
+            data_loader=train_loader,
+            max_physical_batch_size=1000,  # Roughly 12gb vram, uses 9.4
+            optimizer=optimizer
+        ) as memory_safe_data_loader:
+            # Iterate the managed Opacus loader: passing the raw train_dl
+            # skipped both the private loader and the physical batch cap.
+            best_test_set_accuracy = train_no_cap(
+                model,
+                hp,
+                memory_safe_data_loader,
+                test_dl,
+                optimizer,
+                criterion,
+                scheduler,
+            )
+    else:
+        print("Training without differential privacy")
+        best_test_set_accuracy = train_no_cap(
+            model,
+            hp,
+            train_dl,
+            test_dl,
+            optimizer,
+            criterion,
+            scheduler,
+        )
+
+    return model
+
+
+def main():
+    """Parse the command line, train a model and save its weights."""
+    global DEVICE
+
+    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
+    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
+    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
+    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
+    args = parser.parse_args()
+
+    # Compare against None so that ``--cuda 0`` counts as an explicit choice.
+    if torch.cuda.is_available() and args.cuda is not None:
+        DEVICE = torch.device(f'cuda:{args.cuda}')
+    elif torch.cuda.is_available():
+        DEVICE = torch.device('cuda:0')
+    else:
+        DEVICE = torch.device('cpu')
+
+    hyperparams = {
+        "wrn_depth": 16,
+        "wrn_width": 1,
+        "epsilon": args.epsilon,
+        "delta": 1e-5,
+        "norm": args.norm,
+        "batch_size": 4096,
+        "epochs": 200,
+    }
+
+    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
+        int(time.time()),
+        hyperparams['wrn_depth'],
+        hyperparams['wrn_width'],
+        hyperparams['batch_size'],
+        hyperparams['epochs'],
+        hyperparams['epsilon'],
+        hyperparams['delta'],
+        hyperparams['norm'],
+    ))
+
+    model = train(hyperparams)
+    torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/one_run_audit/equations.py b/one_run_audit/equations.py
new file mode 100644
index 0000000..b66ede9
--- /dev/null
+++ b/one_run_audit/equations.py
@@ -0,0 +1,53 @@
+# These equations come from:
+# [1] T. Steinke, M. Nasr, and M. Jagielski, “Privacy Auditing with One (1)
+# Training Run,” May 15, 2023, arXiv: arXiv:2305.08846. Accessed: Sep. 15, 2024.
+# [Online]. Available: http://arxiv.org/abs/2305.08846
+
+import math
+import scipy.stats
+
+# m = number of examples, each included independently with probability 0.5
+# r = number of guesses (i.e. excluding abstentions)
+# v = number of correct guesses by auditor
+# eps,delta = DP guarantee of null hypothesis
+# output: p-value = probability of >=v correct guesses under null hypothesis
+def p_value_DP_audit(m, r, v, eps, delta):
+    assert 0 <= v <= r <= m
+    assert eps >= 0
+    assert 0 <= delta <= 1
+    q = 1 / (1 + math.exp(-eps))  # accuracy of eps-DP randomized response
+    beta = scipy.stats.binom.sf(v - 1, r, q)  # = P[Binomial(r, q) >= v]
+    alpha = 0
+    tail = 0  # = P[v > Binomial(r, q) >= v - i]; named to avoid shadowing sum()
+    for i in range(1, v + 1):
+        tail = tail + scipy.stats.binom.pmf(v - i, r, q)
+        if tail > i * alpha:
+            alpha = tail / i
+    p = beta + alpha * delta * 2 * m
+    return min(p, 1)
+
+# m = number of examples, each included independently with probability 0.5
+# r = number of guesses (i.e. excluding abstentions)
+# v = number of correct guesses by auditor
+# p = 1-confidence e.g. p=0.05 corresponds to 95%
+# output: lower bound on eps i.e. algorithm is not (eps,delta)-DP
+def get_eps_audit(m, r, v, delta, p):
+    assert 0 <= v <= r <= m
+    assert 0 <= delta <= 1
+    assert 0 < p < 1
+    eps_min = 0  # maintain p_value_DP(eps_min) < p
+    eps_max = 1  # maintain p_value_DP(eps_max) >= p
+    while p_value_DP_audit(m, r, v, eps_max, delta) < p:
+        eps_max = eps_max + 1
+    for _ in range(30):  # binary search
+        eps = (eps_min + eps_max) / 2
+        if p_value_DP_audit(m, r, v, eps, delta) < p:
+            eps_min = eps
+        else:
+            eps_max = eps
+    return eps_min
+
+
+if __name__ == '__main__':
+    x = 100
+    print(f"For m=100 r=100 v=100 p=0.05: {get_eps_audit(x, x, x, 1e-5, 0.05)}")
diff --git a/one_run_audit/plot.py b/one_run_audit/plot.py
new file mode 100644
index 0000000..e9194c9
--- /dev/null
+++ b/one_run_audit/plot.py
@@ -0,0 +1,21 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+from equations import get_eps_audit
+
+
+delta = 1e-5
+p_value = 0.05
+
+x_values = np.unique(np.floor((1.5)**np.arange(30)).astype(int))  # floor() repeats 1
+x_values = np.concatenate([x_values[x_values < 60000], [60000]])
+y_values = [get_eps_audit(x, x, x, delta, p_value) for x in tqdm(x_values)]
+
+plt.xscale('log')
+plt.plot(x_values, y_values, marker='o')
+plt.xlabel("Number of samples guessed correctly")
+plt.ylabel("ε value audited")
+plt.title("Maximum possible ε from audit")
+
+# Save the plot as a PNG
+plt.savefig("/dev/shm/my_plot.png", dpi=300, bbox_inches='tight')