O1: add training code
This commit is contained in:
parent
2eef211415
commit
0d67830f7e
4 changed files with 439 additions and 0 deletions
143
one_run_audit/WideResNet.py
Normal file
143
one_run_audit/WideResNet.py
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchsummary import summary
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualBlock1(nn.Module):
    """First residual block of a group: BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv3x3.

    When ``subsample_input`` or ``increase_filters`` is set, the skip
    connection goes through a 1x1 convolution (``conv_inp``) applied to the
    pre-activated input; otherwise the raw input is added back unchanged.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the 1x1 shortcut).
        subsample_input: request a projection shortcut because the block
            changes the spatial resolution.
        increase_filters: request a projection shortcut because the block
            changes the channel count.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The shortcut must shrink the feature map exactly like conv1 does,
            # otherwise the residual addition in forward() fails on a shape
            # mismatch.  It used to be hard-coded to stride 2 (or 1), which only
            # worked when `stride` happened to have that value.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` for the input feature map ``x``."""
        # batch_norm1 allocates a new tensor, so the in-place ReLU never
        # modifies the caller's `x`.
        pre_activated = self.activation(self.batch_norm1(x))

        residual = self.conv1(pre_activated)
        residual = self.activation(self.batch_norm2(residual))
        residual = self.conv2(residual)

        if self.subsample_input or self.increase_filters:
            # Projection shortcut operates on the pre-activated input.
            return self.conv_inp(pre_activated) + residual
        return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualBlockN(nn.Module):
    """Residual block with an identity shortcut.

    Applies BN -> ReLU -> conv3x3 twice and adds the untouched input back.
    Both convolutions use ``stride``; the identity addition only lines up
    when the convolutions keep the spatial size (the visible caller always
    passes ``stride=1``) and ``input_features == output_features``.
    """

    def __init__(self, input_features, output_features, stride):
        super().__init__()

        conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, **conv_options)
        self.conv2 = nn.Conv2d(output_features, output_features, **conv_options)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution residual branch."""
        stages = (
            (self.batch_norm1, self.conv1),
            (self.batch_norm2, self.conv2),
        )
        branch = x
        for norm, conv in stages:
            # norm() yields a fresh tensor, so the in-place ReLU leaves `x` intact.
            branch = conv(self.activation(norm(branch)))
        return branch + x
|
||||||
|
|
||||||
|
|
||||||
|
class Nblock(nn.Module):
    """A group of ``N`` residual blocks executed one after another.

    The first block is an ``IndividualBlock1`` (it receives ``stride`` and the
    shortcut flags); each of the remaining ``N - 1`` blocks is an
    ``IndividualBlockN`` with ``stride=1`` operating on ``output_features``
    channels.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super().__init__()

        blocks = [
            IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            if position == 0
            else IndividualBlockN(output_features, output_features, stride=1)
            for position in range(N)
        ]
        self.nblockLayer = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block of the group in order."""
        return self.nblockLayer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class WideResNet(nn.Module):
    """Wide residual network: a stem convolution, three ``Nblock`` groups and a linear head.

    Args:
        d: depth parameter; each group holds ``(d - 4) // 6`` blocks.
        k: widening factor; the groups output ``16 * k``, ``32 * k`` and
            ``64 * k`` channels.
        n_classes: size of the final linear layer's output.
        input_features: channel count of the network input.
        output_features: channel count produced by the stem convolution.
        strides: indexable with at least four entries - stride of the stem
            convolution followed by the strides of the three groups.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super().__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        widths = [base * k for base in (16, 32, 64)]
        self.out_filters = widths[-1]
        blocks_per_group = (d - 4) // 6

        # The first group keeps the resolution; it only needs a projection
        # shortcut when the widening factor changes the channel count.
        self.block1 = Nblock(
            blocks_per_group,
            input_features=output_features,
            output_features=widths[0],
            stride=strides[1],
            subsample_input=False,
            increase_filters=k > 1,
        )
        self.block2 = Nblock(blocks_per_group, input_features=widths[0], output_features=widths[1], stride=strides[2])
        self.block3 = Nblock(blocks_per_group, input_features=widths[1], output_features=widths[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(self.out_filters)
        self.activation = nn.ReLU(inplace=True)
        # Fixed 8x8 pooling window; presumably sized for 32x32 inputs with
        # strides [1, 1, 2, 2] - confirm before feeding other resolutions.
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(self.out_filters, n_classes)

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialise conv weights, batch-norm affine terms and linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.bias.data.zero_()

    def forward(self, x):
        """Return ``(logits, group1_out, group2_out, group3_out)`` for input ``x``."""
        stem = self.conv1(x)
        group1 = self.block1(stem)
        group2 = self.block2(group1)
        group3 = self.block3(group2)

        pooled = self.avg_pool(self.activation(self.batch_norm(group3)))
        flattened = pooled.view(-1, self.out_filters)

        return self.fc(flattened), group1, group2, group3
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test; edit depth/width to inspect a model other than WRN-40-2.
    depth, width = 40, 2
    net = WideResNet(
        d=depth,
        k=width,
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )

    # A forward pass on a dummy image confirms the network produces an output.
    net(torch.ones(size=(1, 3, 32, 32), requires_grad=False))

    # Print the layer-by-layer summary.
    summary(net, input_size=(3, 32, 32))
|
222
one_run_audit/audit.py
Normal file
222
one_run_audit/audit.py
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
import argparse
|
||||||
|
import equations
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import optim
|
||||||
|
from torch.optim.lr_scheduler import MultiStepLR
|
||||||
|
from torch.utils.data import DataLoader, Subset
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from pathlib import Path
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import CIFAR10
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import opacus
|
||||||
|
from opacus.validators import ModuleValidator
|
||||||
|
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
||||||
|
from WideResNet import WideResNet
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    """Build the CIFAR-10 data loaders used for training and auditing.

    Args:
        m: number of training examples left out of the partial training set;
            which ones are dropped is chosen at random.
        train_batch_size: batch size of both training loaders.
        test_batch_size: batch size of the test loader.

    Returns:
        ``(train_dl, train_dl_p, test_dl)`` - a loader over the full training
        set, a loader over the training set minus the ``m`` held-out examples,
        and a loader over the test set.
    """
    # Fresh seed on every call: a random draw mixed with the wall-clock time.
    seed = np.random.randint(0, 10**9)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # Reflect-pad by 4 pixels per side.  Only the temporary batch axis is
        # squeezed away so a size-1 channel axis would survive.
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze(0)),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    datadir = Path("./data")
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    # Boolean membership mask with exactly `m` entries switched off at random.
    keep = np.full(len(train_ds), True)
    keep[:m] = False
    np.random.shuffle(keep)

    # Subset expects integer indices, not a boolean mask: handing it `keep`
    # directly produced a "subset" of full length whose items were looked up
    # with True/False instead of real positions.
    kept_indices = np.flatnonzero(keep).tolist()
    train_ds_p = Subset(train_ds, kept_indices)

    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, train_dl_p, test_dl
|
||||||
|
|
||||||
|
|
||||||
|
def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
    """Train ``model`` for ``hp['epochs']`` epochs and report test accuracy.

    The model is expected to return an indexable whose first element holds the
    class logits.  Accuracy on ``test_loader`` is measured (and printed) every
    20th epoch and after the final epoch.

    Args:
        model: network to optimise; called as ``model(inputs)[0]``.
        hp: hyperparameter mapping; only ``hp['epochs']`` is read here.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        optimizer: optimizer stepped once per training batch.
        criterion: loss called as ``criterion(logits, labels)``.
        scheduler: learning-rate scheduler stepped once per epoch.

    Returns:
        The highest test accuracy (percent, rounded to two decimals) among the
        evaluated epochs, or 0 if no evaluation took place.
    """
    # Run on whatever device the model already lives on instead of relying on
    # a module-level global being in sync with the model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_test_set_accuracy = 0
    n_epochs = hp['epochs']

    for epoch in range(n_epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 20 == 0 or epoch == n_epochs - 1:
            correct = 0
            total = 0

            model.eval()
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    predicted = model(images)[0].argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            epoch_accuracy = round(100 * (correct / total), 2)
            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")

            # Previously this value was never updated, so the function always
            # returned 0 regardless of how well the model did.
            best_test_set_accuracy = max(best_test_set_accuracy, epoch_accuracy)

    return best_test_set_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def train(hp):
    """Build a WideResNet and train it on CIFAR-10, optionally with DP-SGD.

    Args:
        hp: hyperparameter mapping.  Reads ``wrn_depth``, ``wrn_width``,
            ``epochs`` and ``epsilon``; when ``epsilon`` is not ``None`` it
            also reads ``delta`` and ``norm`` (the gradient clipping norm).

    Returns:
        The trained model.  In the differentially private case this is the
        module returned by Opacus' ``make_private_with_epsilon``.
    """
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    # Let Opacus replace the layers it cannot handle, then verify the result.
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    # Learning rate is multiplied by 0.2 at 30%, 60% and 80% of training.
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    # NOTE(review): hp['batch_size'] is not forwarded here, so the loaders use
    # get_dataloaders' default batch size - confirm whether that is intended.
    train_dl, _, test_dl = get_dataloaders()

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=1000,  # Roughly 12gb vram, uses 9.4
            optimizer=optimizer
        ) as memory_safe_data_loader:
            # Train on the loader produced by the privacy engine and wrapped
            # by the memory manager.  Passing the original `train_dl` here
            # bypassed both of them.
            train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line, pick a device, train a model and save its weights."""
    global DEVICE

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    for flag, options in (
        ('--norm', dict(type=float, help='dpsgd norm clip factor', required=True)),
        ('--cuda', dict(type=int, help='gpu index', required=False)),
        ('--epsilon', dict(type=float, help='dp epsilon', required=False, default=None)),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # CPU without CUDA; otherwise the requested GPU, with index 0 as fallback
    # (a missing or zero --cuda both end up on cuda:0).
    if not torch.cuda.is_available():
        DEVICE = torch.device('cpu')
    elif args.cuda:
        DEVICE = torch.device(f'cuda:{args.cuda}')
    else:
        DEVICE = torch.device('cuda:0')

    hyperparams = dict(
        wrn_depth=16,
        wrn_width=1,
        epsilon=args.epsilon,
        delta=1e-5,
        norm=args.norm,
        batch_size=4096,
        epochs=200,
    )

    # The file name records a timestamp followed by the run's settings.
    name_fields = ('wrn_depth', 'wrn_width', 'batch_size', 'epochs', 'epsilon', 'delta', 'norm')
    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        *(hyperparams[field] for field in name_fields),
    ))

    model = train(hyperparams)
    weights_path = hyperparams['logfile'].with_suffix('.pt')
    torch.save(model.state_dict(), weights_path)


if __name__ == '__main__':
    main()
|
53
one_run_audit/equations.py
Normal file
53
one_run_audit/equations.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# These equations come from:
|
||||||
|
# [1] T. Steinke, M. Nasr, and M. Jagielski, “Privacy Auditing with One (1)
|
||||||
|
# Training Run,” May 15, 2023, arXiv: arXiv:2305.08846. Accessed: Sep. 15, 2024.
|
||||||
|
# [Online]. Available: http://arxiv.org/abs/2305.08846
|
||||||
|
|
||||||
|
import math
|
||||||
|
import scipy.stats
|
||||||
|
|
||||||
|
def p_value_DP_audit(m, r, v, eps, delta):
    """Return the p-value of observing at least ``v`` correct guesses.

    Args:
        m: number of examples, each included independently with probability 0.5.
        r: number of guesses (i.e. excluding abstentions).
        v: number of correct guesses by the auditor.
        eps: epsilon of the (eps, delta)-DP guarantee forming the null hypothesis.
        delta: delta of that guarantee.

    Returns:
        The probability of ``>= v`` correct guesses under the null hypothesis,
        capped at 1.
    """
    assert 0 <= v <= r <= m
    assert eps >= 0
    assert 0 <= delta <= 1

    # Accuracy of eps-DP randomized response.
    q = 1 / (1 + math.exp(-eps))
    # beta = P[Binomial(r, q) >= v]
    beta = scipy.stats.binom.sf(v - 1, r, q)

    alpha = 0
    # window_mass = P[v > Binomial(r, q) >= v - width]
    window_mass = 0
    for width, successes in enumerate(range(v - 1, -1, -1), start=1):
        window_mass = window_mass + scipy.stats.binom.pmf(successes, r, q)
        # Keep the largest average mass per step seen so far.
        if window_mass > width * alpha:
            alpha = window_mass / width

    return min(beta + alpha * delta * 2 * m, 1)
|
||||||
|
|
||||||
|
def get_eps_audit(m, r, v, delta, p):
    """Return a lower bound on epsilon supported by the audit outcome.

    Args:
        m: number of examples, each included independently with probability 0.5.
        r: number of guesses (i.e. excluding abstentions).
        v: number of correct guesses by the auditor.
        delta: delta of the DP guarantee being tested.
        p: 1 - confidence, e.g. ``p=0.05`` corresponds to 95%.

    Returns:
        A lower bound on eps, i.e. the algorithm is not (eps, delta)-DP.
    """
    assert 0 <= v <= r <= m
    assert 0 <= delta <= 1
    assert 0 < p < 1

    def rejected(eps):
        # True when the null hypothesis "(eps, delta)-DP" is rejected at level p.
        return p_value_DP_audit(m, r, v, eps, delta) < p

    lower = 0  # invariant: p-value at `lower` stays below p
    upper = 1  # invariant: p-value at `upper` is at least p
    # Grow the upper end in unit steps until it is no longer rejected.
    while rejected(upper):
        upper = upper + 1

    # 30 rounds of bisection between the two ends.
    for _ in range(30):
        midpoint = (lower + upper) / 2
        if rejected(midpoint):
            lower = midpoint
        else:
            upper = midpoint

    return lower
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Demo: every one of 100 examples guessed, all guesses correct.
    count = 100
    audited_eps = get_eps_audit(count, count, count, 1e-5, 0.05)
    print(f"For m=100 r=100 v=100 p=0.05: {audited_eps}")
|
21
one_run_audit/plot.py
Normal file
21
one_run_audit/plot.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tqdm import tqdm
|
||||||
|
from equations import get_eps_audit
|
||||||
|
|
||||||
|
|
||||||
|
# Audit settings shared by every point on the curve.
delta = 1e-5
p_value = 0.05

# Sample counts: floored powers of 1.5 below 60000, plus 60000 itself.
powers = np.floor((1.5) ** np.arange(30)).astype(int)
x_values = np.concatenate([powers[powers < 60000], [60000]])

# Audited epsilon when all `count` examples are guessed and all guesses are correct.
y_values = [
    get_eps_audit(count, count, count, delta, p_value)
    for count in tqdm(x_values)
]

plt.xlabel("Number of samples guessed correctly")
plt.ylabel("ε value audited")
plt.title("Maximum possible ε from audit")
plt.xscale('log')
plt.plot(x_values, y_values, marker='o')

# Write the figure out as a PNG.
plt.savefig("/dev/shm/my_plot.png", dpi=300, bbox_inches='tight')
|
Loading…
Reference in a new issue