O1: add training code
parent 2eef211415
commit 0d67830f7e
4 changed files with 439 additions and 0 deletions
143 one_run_audit/WideResNet.py Normal file
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math


class IndividualBlock1(nn.Module):

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
        elif increase_filters:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):

        if self.subsample_input or self.increase_filters:
            x = self.batch_norm1(x)
            x = self.activation(x)
            x1 = self.conv1(x)
        else:
            x1 = self.batch_norm1(x)
            x1 = self.activation(x1)
            x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        if self.subsample_input or self.increase_filters:
            return self.conv_inp(x) + x1
        else:
            return x + x1


class IndividualBlockN(nn.Module):

    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)

    def forward(self, x):
        x1 = self.batch_norm1(x)
        x1 = self.activation(x1)
        x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        return x1 + x


class Nblock(nn.Module):

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()

        layers = []
        for i in range(N):
            if i == 0:
                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
            else:
                layers.append(IndividualBlockN(output_features, output_features, stride=1))

        self.nblockLayer = nn.Sequential(*layers)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        out = out.view(-1, self.out_filters)

        return self.fc(out), attention1, attention2, attention3


if __name__ == '__main__':

    # change d and k if you want to check a model other than WRN-40-2
    d = 40
    k = 2
    strides = [1, 1, 2, 2]
    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)

    # verify that an output is produced
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    net(sample_input)

    # Summarize model
    summary(net, input_size=(3, 32, 32))
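Not part of the commit: as a quick sanity check on the network above, a minimal sketch (assuming the WRN-40-2 settings from the __main__ block, d=40, k=2, strides=[1, 1, 2, 2]) of the shapes the forward pass returns for a single 32x32 CIFAR-10 input:

# Hypothetical shape check, not in the commit; reuses the WideResNet class defined above.
import torch
from WideResNet import WideResNet

net = WideResNet(d=40, k=2, n_classes=10, input_features=3, output_features=16, strides=[1, 1, 2, 2])
logits, att1, att2, att3 = net(torch.ones(1, 3, 32, 32))
print(logits.shape)  # torch.Size([1, 10])          final fc over 64*k = 128 pooled filters
print(att1.shape)    # torch.Size([1, 32, 32, 32])  block1: 16*k filters, stride 1
print(att2.shape)    # torch.Size([1, 64, 16, 16])  block2: 32*k filters, stride 2
print(att3.shape)    # torch.Size([1, 128, 8, 8])   block3: 64*k filters, stride 2

audit.py below keeps only wrn_outputs[0] (the logits) and ignores the three attention maps.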
222 one_run_audit/audit.py Normal file
@@ -0,0 +1,222 @@
import argparse
import equations
import numpy as np
import time
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
import warnings
warnings.filterwarnings("ignore")


DEVICE = torch.device("cpu")


def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    # m: number of training examples withheld from the partial training set
    seed = np.random.randint(0, 1e9)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    datadir = Path("./data")
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    # Randomly mark m training examples as excluded from the partial set.
    keep = np.full(len(train_ds), True)
    keep[:m] = False
    np.random.shuffle(keep)

    # Subset expects indices, not a boolean mask.
    train_ds_p = Subset(train_ds, np.flatnonzero(keep))
    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, train_dl_p, test_dl


def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
    best_test_set_accuracy = 0

    for epoch in range(hp['epochs']):
        model.train()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()

            wrn_outputs = model(inputs)
            outputs = wrn_outputs[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
            with torch.no_grad():
                correct = 0
                total = 0

                model.eval()
                for data in test_loader:
                    images, labels = data
                    images = images.to(DEVICE)
                    labels = labels.to(DEVICE)

                    wrn_outputs = model(images)
                    outputs = wrn_outputs[0]
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                epoch_accuracy = correct / total
                epoch_accuracy = round(100 * epoch_accuracy, 2)
                best_test_set_accuracy = max(best_test_set_accuracy, epoch_accuracy)
                print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")

    return best_test_set_accuracy


def train(hp):
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    # Replace modules that are incompatible with DP-SGD (e.g. BatchNorm -> GroupNorm).
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    train_dl, train_dl_p, test_dl = get_dataloaders()

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=1000,  # Roughly 12gb vram, uses 9.4
            optimizer=optimizer
        ) as memory_safe_data_loader:
            # Train on the memory-safe wrapper of the private dataloader.
            best_test_set_accuracy = train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model


def main():
    global DEVICE

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    args = parser.parse_args()

    if torch.cuda.is_available() and args.cuda is not None:
        DEVICE = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
    else:
        DEVICE = torch.device('cpu')

    hyperparams = {
        "wrn_depth": 16,
        "wrn_width": 1,
        "epsilon": args.epsilon,
        "delta": 1e-5,
        "norm": args.norm,
        "batch_size": 4096,
        "epochs": 200,
    }

    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        hyperparams['wrn_depth'],
        hyperparams['wrn_width'],
        hyperparams['batch_size'],
        hyperparams['epochs'],
        hyperparams['epsilon'],
        hyperparams['delta'],
        hyperparams['norm'],
    ))

    model = train(hyperparams)
    torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))


if __name__ == '__main__':
    main()
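For reference (not part of the commit), the argparse flags above give invocations of the following form; the concrete values are illustrative placeholders only:

# Illustrative only; flag names come from the parser above, values are placeholders.
#   python audit.py --norm 1.0                        # non-private baseline (epsilon stays None)
#   python audit.py --norm 1.0 --epsilon 8 --cuda 0   # DP-SGD targeting epsilon = 8 on cuda:0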
53 one_run_audit/equations.py Normal file
@@ -0,0 +1,53 @@
# These equations come from:
# [1] T. Steinke, M. Nasr, and M. Jagielski, “Privacy Auditing with One (1)
# Training Run,” May 15, 2023, arXiv: arXiv:2305.08846. Accessed: Sep. 15, 2024.
# [Online]. Available: http://arxiv.org/abs/2305.08846

import math
import scipy.stats


# m = number of examples, each included independently with probability 0.5
# r = number of guesses (i.e. excluding abstentions)
# v = number of correct guesses by auditor
# eps,delta = DP guarantee of null hypothesis
# output: p-value = probability of >=v correct guesses under null hypothesis
def p_value_DP_audit(m, r, v, eps, delta):
    assert 0 <= v <= r <= m
    assert eps >= 0
    assert 0 <= delta <= 1
    q = 1 / (1 + math.exp(-eps))  # accuracy of eps-DP randomized response
    beta = scipy.stats.binom.sf(v - 1, r, q)  # = P[Binomial(r, q) >= v]
    alpha = 0
    sum = 0  # = P[v > Binomial(r, q) >= v - i]
    for i in range(1, v + 1):
        sum = sum + scipy.stats.binom.pmf(v - i, r, q)
        if sum > i * alpha:
            alpha = sum / i
    p = beta + alpha * delta * 2 * m
    return min(p, 1)


# m = number of examples, each included independently with probability 0.5
# r = number of guesses (i.e. excluding abstentions)
# v = number of correct guesses by auditor
# p = 1-confidence e.g. p=0.05 corresponds to 95%
# output: lower bound on eps i.e. algorithm is not (eps,delta)-DP
def get_eps_audit(m, r, v, delta, p):
    assert 0 <= v <= r <= m
    assert 0 <= delta <= 1
    assert 0 < p < 1
    eps_min = 0  # maintain p_value_DP(eps_min) < p
    eps_max = 1  # maintain p_value_DP(eps_max) >= p
    while p_value_DP_audit(m, r, v, eps_max, delta) < p:
        eps_max = eps_max + 1
    for _ in range(30):  # binary search
        eps = (eps_min + eps_max) / 2
        if p_value_DP_audit(m, r, v, eps, delta) < p:
            eps_min = eps
        else:
            eps_max = eps
    return eps_min


if __name__ == '__main__':
    x = 100
    print(f"For m=100 r=100 v=100 p=0.05: {get_eps_audit(x, x, x, 1e-5, 0.05)}")
21 one_run_audit/plot.py Normal file
@@ -0,0 +1,21 @@
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from equations import get_eps_audit


delta = 1e-5
p_value = 0.05

x_values = np.floor((1.5)**np.arange(30)).astype(int)
x_values = np.concatenate([x_values[x_values < 60000], [60000]])
y_values = [get_eps_audit(x, x, x, delta, p_value) for x in tqdm(x_values)]

plt.xscale('log')
plt.plot(x_values, y_values, marker='o')
plt.xlabel("Number of samples guessed correctly")
plt.ylabel("ε value audited")
plt.title("Maximum possible ε from audit")

# Save the plot as a PNG
plt.savefig("/dev/shm/my_plot.png", dpi=300, bbox_inches='tight')