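"""Privacy auditing of CIFAR-10 training runs.

(Orientation summary; the functions below are the authoritative behavior.)
The script injects m intentionally mislabeled target points into CIFAR-10, includes
each one in the training set with probability 1/2 (the membership vector S), trains
one of several models (a WideResNet teacher, a GroupNorm WideResNet, a small student,
a fast patch-whitened model, or a convnet), optionally with DP-SGD via Opacus, and
scores each target point by the drop in its loss between the initial and the trained
model. The most confident membership guesses are then converted into a lower bound on
epsilon via equations.get_eps_audit.
"""
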
import argparse
import copy
import random
import time
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset, TensorDataset, ConcatDataset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
import pytorch_lightning as pl
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager

import equations
from equations import get_eps_audit
from WideResNet import WideResNet
import student_model
import fast_model
import convnet_classifier
import wrn

warnings.filterwarnings("ignore")


DEVICE = None
DTYPE = None
DATADIR = Path("./data")


def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
    seed = np.random.randint(0, int(1e9))
    seed ^= int(time.time())
    pl.seed_everything(seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)

    # Original dataset
    x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
    y_train = np.array(train_ds.targets).astype(np.int64)

    x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])  # Applies transforms
    y = np.array(test_ds.targets).astype(np.int64)

    # Pull points from training set when m > test set
    if m > len(x):
        k = m - len(x)
        mask = np.full(len(x_train), False)
        mask[:k] = True

        x = np.concatenate([x_train[mask], x])
        y = np.concatenate([y_train[mask], y])
        x_train = x_train[~mask]
        y_train = y_train[~mask]

    # Store the m points which could have been included/excluded
    mask = np.full(len(x), False)
    mask[:m] = True
    mask = mask[np.random.permutation(len(x))]

    adv_points = x[mask]
    adv_labels = y[mask]

    # Mislabel inclusion/exclusion examples intentionally!
    for i in range(len(adv_labels)):
        while True:
            c = np.random.choice(range(10))
            if adv_labels[i] != c:
                adv_labels[i] = c
                break

    # Randomly include or exclude each of the m points with probability 1/2
    S = np.random.choice([True, False], size=m)  # Vector determining whether each point is in or out

    assert len(adv_points) == m
    inc_points = adv_points[S]
    inc_labels = adv_labels[S]

    td = TensorDataset(torch.from_numpy(inc_points).float(), torch.from_numpy(inc_labels).long())
    td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
    td = ConcatDataset([td, td2])
    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
    pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S


def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
    def preprocess_data(data):
        data = torch.tensor(data)  # .to(DTYPE)
        data = data / 255.0
        data = data.permute(0, 3, 1, 2)
        data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
        data = nn.ReflectionPad2d(4)(data)
        data = transforms.RandomCrop(size=(32, 32))(data)
        data = transforms.RandomHorizontalFlip()(data)
        return data

    train_ds = CIFAR10(root=DATADIR, train=True, download=True)
    test_ds = CIFAR10(root=DATADIR, train=False, download=True)

    train_x = train_ds.data
    test_x = test_ds.data
    train_y = np.array(train_ds.targets)
    test_y = np.array(test_ds.targets)

    if m > len(test_x):
        k = m - len(test_x)
        mask = np.full(len(train_x), False)
        mask[:k] = True
        mask = mask[np.random.permutation(len(train_x))]

        test_x = np.concatenate([train_x[mask], test_x])
        test_y = np.concatenate([train_y[mask], test_y])
        train_y = train_y[~mask]
        train_x = train_x[~mask]

    mask = np.full(len(test_x), False)
    mask[:m] = True
    mask = mask[np.random.permutation(len(test_x))]
    S = np.random.choice([True, False], size=m)

    attack_x = test_x[mask][S]
    attack_y = test_y[mask][S]

    for i in range(len(attack_y)):
        while True:
            c = np.random.choice(range(10))
            if attack_y[i] != c:
                attack_y[i] = c
                break

    train_x = np.concatenate([train_x, attack_x])
    train_y = np.concatenate([train_y, attack_y])

    train_x = preprocess_data(train_x)
    test_x = preprocess_data(test_x)
    attack_x = preprocess_data(attack_x)
    train_y = torch.tensor(train_y)
    test_y = torch.tensor(test_y)
    attack_y = torch.tensor(attack_y)

    train_dl = DataLoader(
        TensorDataset(train_x, train_y.long()),
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4
    )
    test_dl = DataLoader(
        TensorDataset(test_x, test_y.long()),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=4
    )
    return train_dl, test_dl, train_x, attack_x.numpy(), attack_y.numpy(), S


def evaluate_on(model, dataloader):
    correct = 0
    total = 0

    with torch.no_grad():
        model.eval()

        for data in dataloader:
            images, labels = data
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            wrn_outputs = model(images)
            if len(wrn_outputs) == 4:
                # The WideResNet used here returns a 4-tuple whose first element is
                # the logits; other models return the logits tensor directly.
                outputs = wrn_outputs[0]
            else:
                outputs = wrn_outputs

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct, total


def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    # Instantiate the student model
    student = student_model.Model(num_classes=10).to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    student_init = copy.deepcopy(student)
    student.to(device)
    teacher.to(device)
    teacher.eval()   # Teacher set to evaluation mode
    student.train()  # Student set to train mode
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model; no gradients are needed since the
            # teacher's weights are not updated
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the teacher logits with softmax and the student logits with log-softmax
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Soft-target (KL-divergence) loss, scaled by T**2 as suggested by the
            # authors of "Distilling the Knowledge in a Neural Network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # True-label cross-entropy loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        if epoch % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

    return student_init, student


def train_no_cap(model, model_init, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S):
    best_test_set_accuracy = 0

    for epoch in range(hp['epochs']):
        model.train()
        for i, data in enumerate(train_dl, 0):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()

            wrn_outputs = model(inputs)
            if len(wrn_outputs) == 4:
                outputs = wrn_outputs[0]
            else:
                outputs = wrn_outputs
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
            correct, total = evaluate_on(model, test_dl)
            epoch_accuracy = round(100 * correct / total, 2)
            scores = score_model(model_init, model, adv_points, adv_labels, S)
            audits = audit_model(hp, scores)
            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}% | Audit : {audits[2]}/{2*audits[1]}/{audits[3]} | p[ε < {audits[0]}] < {hp['p_value']} @ ε={hp['epsilon']}")

    return best_test_set_accuracy


def load(hp, model_path, train_dl):
    init_model = model_path / "init_model.pt"
    trained_model = model_path / "trained_model.pt"

    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    privacy_engine = opacus.PrivacyEngine()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dl,
        epochs=hp['epochs'],
        target_epsilon=hp['epsilon'],
        target_delta=hp['delta'],
        max_grad_norm=hp['norm'],
    )

    model_init.load_state_dict(torch.load(init_model, weights_only=True))
    model.load_state_dict(torch.load(trained_model, weights_only=True))

    model_init = model_init.to(DEVICE)
    model = model.to(DEVICE)

    adv_points = np.load("data/adv_points.npy")
    adv_labels = np.load("data/adv_labels.npy")
    S = np.load("data/S.npy")

    return model_init, model, adv_points, adv_labels, S


def train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S):
    model = wrn.WideResNet(16, 10, 4)
    model = model.to(DEVICE)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.12,
        momentum=0.9,
        weight_decay=1e-4
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.1
    )

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=10,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                model_init,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
                adv_points,
                adv_labels,
                S,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            model_init,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
            adv_points,
            adv_labels,
            S,
        )

    return model_init, model


def train_small(hp, train_dl, test_dl, adv_points, adv_labels, S):
    model = student_model.Model(num_classes=10).to(DEVICE)
    model = model.to(DEVICE)
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)

    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                model_init,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
                adv_points,
                adv_labels,
                S,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            model_init,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
            adv_points,
            adv_labels,
            S,
        )

    return model_init, model


def train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S):
    epochs = hp['epochs']
    momentum = 0.9
    weight_decay = 0.256
    weight_decay_bias = 0.004
    ema_update_freq = 5
    ema_rho = 0.99**ema_update_freq
    dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32

    print("=========================")
    print("Training a fast model")
    print("=========================")
    weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
    model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)

    model.to(DEVICE)
    init_model = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    train_no_cap(model, init_model, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S)
    return init_model, model


def train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S):
    model = convnet_classifier.ConvNet()
    model = model.to(DEVICE)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[10, 25], gamma=0.1)

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine(accountant='rdp')
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                model_init,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
                adv_points,
                adv_labels,
                S,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            model_init,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
            adv_points,
            adv_labels,
            S,
        )

    return model_init, model


def train(hp, train_dl, test_dl, adv_points, adv_labels, S):
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    model = model.to(DEVICE)
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)

    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                model_init,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
                adv_points,
                adv_labels,
                S,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            model_init,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
            adv_points,
            adv_labels,
            S,
        )

    return model_init, model


def get_k_audit(k, scores, hp):
    # Guess "out" for the k lowest-scored target points and "in" for the k
    # highest-scored ones, then count how many of those 2k guesses are correct.
    correct = np.sum(~scores[:k]) + np.sum(scores[-k:])

    eps_lb = get_eps_audit(
        hp['target_points'],
        2*k,
        correct,
        hp['delta'],
        hp['p_value']
    )
    return eps_lb, k, correct, len(scores)


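# Illustration of the accounting above (hypothetical numbers): with
# hp['target_points'] = 1000 and k = 50, the auditor makes 2k = 100 membership
# guesses on the most extreme scores. If, say, 85 of them are correct, then
#
#     eps_lb, k, correct, total = get_k_audit(50, scores, hp)
#
# calls get_eps_audit(1000, 100, 85, hp['delta'], hp['p_value']) and returns a bound
# eps_lb such that p[ε < eps_lb] < hp['p_value'], matching the audit line printed in
# train_no_cap.

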
def score_model(model_init, model_trained, adv_points, adv_labels, S):
    scores = list()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        model_init.eval()
        x_m = torch.from_numpy(adv_points).to(DEVICE)
        y_m = torch.from_numpy(adv_labels).long().to(DEVICE)

        for i in range(len(x_m)):
            x_point = x_m[i].unsqueeze(0).to(DEVICE)
            y_point = y_m[i].unsqueeze(0).to(DEVICE)
            is_in = S[i]

            wrn_outputs = model_init(x_point)
            outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
            init_loss = criterion(outputs, y_point)

            wrn_outputs = model_trained(x_point)
            outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
            trained_loss = criterion(outputs, y_point)

            # Membership score: how much the loss on this point dropped during
            # training. A larger drop suggests the point was in the training set.
            scores.append(((init_loss - trained_loss).item(), is_in))

    # Sort by score and keep only the ground-truth membership bits, ordered from
    # least to most likely "in".
    scores = sorted(scores, key=lambda x: x[0])
    scores = np.array([x[1] for x in scores])
    return scores


def audit_model(hp, scores):
    audits = (0, 0, 0, 0)
    k_schedule = np.linspace(1, hp['target_points']//2, 40)
    k_schedule = np.floor(k_schedule).astype(int)

    with ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
        }

        for future in as_completed(futures):
            try:
                eps_lb, k, correct, total = future.result()
                if eps_lb > audits[0]:
                    audits = (eps_lb, k, correct, total)
            except Exception as exc:
                k = futures[future]
                print(f"'k={k}' generated an exception: {exc}")

    return audits


def main():
    global DEVICE
    global DTYPE

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    parser.add_argument('--m', type=int, help='number of target points', required=True)
    parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
    parser.add_argument('--load', type=Path, help='path to saved models to load instead of training', required=False)
    parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
    parser.add_argument('--distill', action='store_true', help='train a distilled student', required=False)
    parser.add_argument('--fast', action='store_true', help='train the fast model', required=False)
    parser.add_argument('--wrn2', action='store_true', help='Train a groupnormed wrn', required=False)
    parser.add_argument('--convnet', action='store_true', help='Train a convnet', required=False)
    args = parser.parse_args()

    if torch.cuda.is_available() and args.cuda is not None:
        DEVICE = torch.device(f'cuda:{args.cuda}')
        DTYPE = torch.float16
    elif torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
        DTYPE = torch.float16
    else:
        DEVICE = torch.device('cpu')
        DTYPE = torch.float32

    hp = {
        "target_points": args.m,
        "wrn_depth": 16,
        "wrn_width": 1,
        "epsilon": args.epsilon,
        "delta": 1e-6,
        "norm": args.norm,
        "batch_size": 50 if args.convnet else 4096,
        "epochs": args.epochs,
        "p_value": 0.05,
    }

    hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        hp['wrn_depth'],
        hp['wrn_width'],
        hp['batch_size'],
        hp['epochs'],
        hp['epsilon'],
        hp['delta'],
        hp['norm'],
    ))

    if args.load:
        train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
        model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
        test_dl = None
    elif args.fast:
        train_dl, test_dl, train_x, adv_points, adv_labels, S = get_dataloaders_raw(hp['target_points'])
        model_init, model_trained = train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S)
    else:
        train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
        if args.wrn2:
            print("=========================")
            print("Training wrn2 model from meta")
            print("=========================")
            model_init, model_trained = train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S)
        elif args.convnet:
            print("=========================")
            print("Training a simple convnet")
            print("=========================")
            model_init, model_trained = train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S)
        elif args.studentraw:
            print("=========================")
            print("Training a raw student model")
            print("=========================")
            model_init, model_trained = train_small(hp, train_dl, test_dl, adv_points, adv_labels, S)
        elif args.distill:
            print("=========================")
            print("Training a distilled student model")
            print("=========================")
            teacher_init, teacher_trained = train(hp, train_dl, test_dl, adv_points, adv_labels, S)
            model_init, model_trained = train_knowledge_distillation(
                teacher=teacher_trained,
                train_dl=train_dl,
                epochs=hp['epochs'],
                device=DEVICE,
                learning_rate=0.001,
                T=2,
                soft_target_loss_weight=0.25,
                ce_loss_weight=0.75,
            )
        else:
            print("=========================")
            print("Training teacher model")
            print("=========================")
            model_init, model_trained = train(hp, train_dl, test_dl, adv_points, adv_labels, S)

    np.save("data/adv_points", adv_points)
    np.save("data/adv_labels", adv_labels)
    np.save("data/S", S)
    torch.save(model_init.state_dict(), "data/init_model.pt")
    torch.save(model_trained.state_dict(), "data/trained_model.pt")

    # scores = score_model(model_init, model_trained, adv_points, adv_labels, S)
    # audits = audit_model(hp, scores)

    # print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
    # print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")

    if test_dl is not None:
        correct, total = evaluate_on(model_init, test_dl)
        print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
        correct, total = evaluate_on(model_trained, test_dl)
        print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")


if __name__ == '__main__':
    main()
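
# Example invocation (values and the script filename are illustrative only):
#
#   python audit.py --m 1000 --epochs 100 --norm 1.0 --epsilon 8.0 --wrn2
#
# Omit --epsilon to train without differential privacy; pass --load <dir> to load a
# previously saved init_model.pt / trained_model.pt pair instead of training.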