Compare commits

...

12 commits

7 changed files with 1156 additions and 94 deletions


@@ -7,25 +7,34 @@ import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data import DataLoader, Subset, TensorDataset, ConcatDataset
import torch.nn.functional as F
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import opacus
import random
from tqdm import tqdm
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from concurrent.futures import ProcessPoolExecutor, as_completed
from WideResNet import WideResNet
from equations import get_eps_audit
import student_model
import fast_model
import convnet_classifier
import wrn
import warnings
warnings.filterwarnings("ignore")
DEVICE = None
DTYPE = None
DATADIR = Path("./data")
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
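# Overview of this loader (inferred from the code below): sample m candidate target
# points (drawn from the test set, topped up from the training set if m exceeds it),
# deliberately mislabel them, include each candidate in the training data with
# probability 1/2 (recorded in S), and return the combined train loader together
# with the candidate points, their labels, and the inclusion vector S.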
seed = np.random.randint(0, 1e9)
seed ^= int(time.time())
pl.seed_everything(seed)
@@ -44,37 +53,129 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datadir = Path("./data")
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
# Original dataset
x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
p = np.random.permutation(len(train_ds))
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
y_train = np.array(train_ds.targets).astype(np.int64)
# Choose m points to randomly exclude at chance
S = np.full(len(train_ds), True)
S[:m] = np.random.choice([True, False], size=m) # Vector determining whether each point is in or out
x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))]) # Applies transforms
y = np.array(test_ds.targets).astype(np.int64)
# Pull points from the training set when m exceeds the test set size
if m > len(x):
k = m - len(x)
mask = np.full(len(x_train), False)
mask[:k] = True
x = np.concatenate([x_train[mask], x])
y = np.concatenate([y_train[mask], y])
x_train = x_train[~mask]
y_train = y_train[~mask]
# Store the m points which could have been included/excluded
mask = np.full(len(train_ds), False)
mask = np.full(len(x), False)
mask[:m] = True
mask = mask[p]
mask = mask[np.random.permutation(len(x))]
x_m = x[mask] # These are the points being guessed at
y_m = np.array(train_ds.targets)[mask].astype(np.int64)
S_m = S[p][mask] # Ground truth of inclusion/exclusion for x_m
adv_points = x[mask]
adv_labels = y[mask]
# Remove excluded points from dataset
x_in = x[S[p]]
y_in = np.array(train_ds.targets).astype(np.int64)
y_in = y_in[S[p]]
# Mislabel inclusion/exclusion examples intentionally!
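# (Presumably so the target points act as easy-to-detect canaries: a mislabeled
# example only reaches low loss if the model actually trained on it.)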
for i in range(len(adv_labels)):
while True:
c = np.random.choice(range(10))
if adv_labels[i] != c:
adv_labels[i] = c
break
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
# Choose m points to randomly exclude at chance
S = np.random.choice([True, False], size=m) # Vector determining whether each point is in or out
assert len(adv_points) == m
inc_points = adv_points[S]
inc_labels = adv_labels[S]
td = TensorDataset(torch.from_numpy(inc_points).float(), torch.from_numpy(inc_labels).long())
td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
td = ConcatDataset([td, td2])
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
return train_dl, test_dl, x_in, x_m, y_m, S_m
return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
def preprocess_data(data):
data = torch.tensor(data)#.to(DTYPE)
data = data / 255.0
data = data.permute(0, 3, 1, 2)
data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
data = nn.ReflectionPad2d(4)(data)
data = transforms.RandomCrop(size=(32, 32))(data)
data = transforms.RandomHorizontalFlip()(data)
return data
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
train_x = train_ds.data
test_x = test_ds.data
train_y = np.array(train_ds.targets)
test_y = np.array(test_ds.targets)
if m > len(test_x):
k = m - len(test_x)
mask = np.full(len(train_x), False)
mask[:k] = True
mask = mask[np.random.permutation(len(train_x))]
test_x = np.concatenate([train_x[mask], test_x])
test_y = np.concatenate([train_y[mask], test_y])
train_y = train_y[~mask]
train_x = train_x[~mask]
mask = np.full(len(test_x), False)
mask[:m] = True
mask = mask[np.random.permutation(len(test_x))]
S = np.random.choice([True, False], size=m)
attack_x = test_x[mask]
attack_y = test_y[mask]
for i in range(len(attack_y)):
while True:
c = np.random.choice(range(10))
if attack_y[i] != c:
attack_y[i] = c
break
train_x = np.concatenate([train_x, attack_x[S]])
train_y = np.concatenate([train_y, attack_y[S]])
train_x = preprocess_data(train_x)
test_x = preprocess_data(test_x)
attack_x = preprocess_data(attack_x)
train_y = torch.tensor(train_y)
test_y = torch.tensor(test_y)
attack_y = torch.tensor(attack_y)
train_dl = DataLoader(
TensorDataset(train_x, train_y.long()),
batch_size=train_batch_size,
shuffle=True,
drop_last=True,
num_workers=4
)
test_dl = DataLoader(
TensorDataset(test_x, test_y.long()),
batch_size=train_batch_size,
shuffle=True,
num_workers=4
)
return train_dl, test_dl, train_x, attack_x.numpy(), attack_y.numpy(), S
def evaluate_on(model, dataloader):
@@ -90,7 +191,11 @@ def evaluate_on(model, dataloader):
labels = labels.to(DEVICE)
wrn_outputs = model(images)
outputs = wrn_outputs[0]
if len(wrn_outputs) == 4:
outputs = wrn_outputs[0]
else:
outputs = wrn_outputs
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
@@ -98,7 +203,54 @@ def evaluate_on(model, dataloader):
return correct, total
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
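# Standard knowledge-distillation loop: the teacher is frozen and the student is
# trained on a weighted sum of (a) the KL divergence between temperature-softened
# teacher and student distributions (scaled by T**2) and (b) cross entropy on the
# hard labels. Returns a copy of the student's initial weights and the trained student.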
# instantiate the student model
student = student_model.Model(num_classes=10).to(device)
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
student_init = copy.deepcopy(student)
student.to(device)
teacher.to(device)
teacher.eval() # Teacher set to evaluation mode
student.train() # Student to train mode
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_dl:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
with torch.no_grad():
teacher_logits, _, _, _ = teacher(inputs)
# Forward pass with the student model
student_logits = student(inputs)
#Soften the student logits by applying softmax first and log() second
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)
# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
running_loss += loss.item()
if epoch % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
return student_init, student
def train_no_cap(model, model_init, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S):
best_test_set_accuracy = 0
for epoch in range(hp['epochs']):
@@ -111,7 +263,10 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
optimizer.zero_grad()
wrn_outputs = model(inputs)
outputs = wrn_outputs[0]
if len(wrn_outputs) == 4:
outputs = wrn_outputs[0]
else:
outputs = wrn_outputs
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
@@ -121,12 +276,304 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
correct, total = evaluate_on(model, test_dl)
epoch_accuracy = round(100 * correct / total, 2)
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
scores = score_model(model_init, model, adv_points, adv_labels, S)
audits = audit_model(hp, scores)
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}% | Audit : {audits[2]}/{2*audits[1]}/{audits[3]} | p[ε < {audits[0]}] < {hp['p_value']} @ ε={hp['epsilon']}")
return best_test_set_accuracy
def train(hp, train_dl, test_dl):
def load(hp, model_path, train_dl):
init_model = model_path / "init_model.pt"
trained_model = model_path / "trained_model.pt"
model = WideResNet(
d=hp["wrn_depth"],
k=hp["wrn_width"],
n_classes=10,
input_features=3,
output_features=16,
strides=[1, 1, 2, 2],
)
model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=True)
model_init = copy.deepcopy(model)
privacy_engine = opacus.PrivacyEngine()
optimizer = optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9,
nesterov=True,
weight_decay=5e-4
)
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dl,
epochs=hp['epochs'],
target_epsilon=hp['epsilon'],
target_delta=hp['delta'],
max_grad_norm=hp['norm'],
)
model_init.load_state_dict(torch.load(init_model, weights_only=True))
model.load_state_dict(torch.load(trained_model, weights_only=True))
model_init = model_init.to(DEVICE)
model = model.to(DEVICE)
adv_points = np.load("data/adv_points.npy")
adv_labels = np.load("data/adv_labels.npy")
S = np.load("data/S.npy")
return model_init, model, adv_points, adv_labels, S
def train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S):
model = wrn.WideResNet(16, 10, 4)
model = model.to(DEVICE)
ModuleValidator.validate(model, strict=True)
model_init = copy.deepcopy(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(),
lr=0.12,
momentum=0.9,
weight_decay=1e-4
)
scheduler = MultiStepLR(
optimizer,
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
gamma=0.1
)
print(f"Training with {hp['epochs']} epochs")
if hp['epsilon'] is not None:
privacy_engine = opacus.PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dl,
epochs=hp['epochs'],
target_epsilon=hp['epsilon'],
target_delta=hp['delta'],
max_grad_norm=hp['norm'],
)
print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")
with BatchMemoryManager(
data_loader=train_loader,
max_physical_batch_size=10, # 1000 ~= 9.4GB vram
optimizer=optimizer
) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
memory_safe_data_loader,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
else:
print("Training without differential privacy")
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
train_dl,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
return model_init, model
def train_small(hp, train_dl, test_dl, adv_points, adv_labels, S):
model = student_model.Model(num_classes=10).to(DEVICE)
model = model.to(DEVICE)
model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=True)
model_init = copy.deepcopy(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = MultiStepLR(
optimizer,
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
gamma=0.2
)
print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs")
if hp['epsilon'] is not None:
privacy_engine = opacus.PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dl,
epochs=hp['epochs'],
target_epsilon=hp['epsilon'],
target_delta=hp['delta'],
max_grad_norm=hp['norm'],
)
print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")
with BatchMemoryManager(
data_loader=train_loader,
max_physical_batch_size=2000, # 1000 ~= 9.4GB vram
optimizer=optimizer
) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
memory_safe_data_loader,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
else:
print("Training without differential privacy")
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
train_dl,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
return model_init, model
def train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S):
epochs = hp['epochs']
momentum = 0.9
weight_decay = 0.256
weight_decay_bias = 0.004
ema_update_freq = 5
ema_rho = 0.99**ema_update_freq
dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
print("=========================")
print("Training a fast model")
print("=========================")
weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
model.to(DEVICE)
init_model = copy.deepcopy(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9,
nesterov=True,
weight_decay=5e-4
)
scheduler = MultiStepLR(
optimizer,
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
gamma=0.2
)
train_no_cap(model, init_model, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S)
return init_model, model
def train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S):
model = convnet_classifier.ConvNet()
model = model.to(DEVICE)
ModuleValidator.validate(model, strict=True)
model_init = copy.deepcopy(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[10, 25], gamma=0.1)
print(f"Training with {hp['epochs']} epochs")
if hp['epsilon'] is not None:
privacy_engine = opacus.PrivacyEngine(accountant='rdp')
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dl,
epochs=hp['epochs'],
target_epsilon=hp['epsilon'],
target_delta=hp['delta'],
max_grad_norm=hp['norm'],
)
print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")
with BatchMemoryManager(
data_loader=train_loader,
max_physical_batch_size=2000, # 1000 ~= 9.4GB vram
optimizer=optimizer
) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
memory_safe_data_loader,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
else:
print("Training without differential privacy")
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
train_dl,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
return model_init, model
def train(hp, train_dl, test_dl, adv_points, adv_labels, S):
model = WideResNet(
d=hp["wrn_depth"],
k=hp["wrn_width"],
@@ -179,56 +626,136 @@ def train(hp, train_dl, test_dl):
) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
memory_safe_data_loader,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
else:
print("Training without differential privacy")
best_test_set_accuracy = train_no_cap(
model,
model_init,
hp,
train_dl,
test_dl,
optimizer,
criterion,
scheduler,
adv_points,
adv_labels,
S,
)
return model_init, model
def get_k_audit(k, scores, hp):
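# For a given k, guess "out" for the k lowest-scoring points and "in" for the k
# highest-scoring points, count how many guesses match the ground truth S, and
# convert that count into an audited epsilon lower bound via get_eps_audit.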
correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
eps_lb = get_eps_audit(
hp['target_points'],
2*k,
correct,
hp['delta'],
hp['p_value']
)
return eps_lb, k, correct, len(scores)
def score_model(model_init, model_trained, adv_points, adv_labels, S):
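# Membership score per target point: the drop in loss between the initial and the
# trained model. Points the model trained on are expected to show a larger drop,
# so sorting by this score separates the likely members from the non-members.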
scores = list()
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
model_init.eval()
x_m = torch.from_numpy(adv_points).to(DEVICE)
y_m = torch.from_numpy(adv_labels).long().to(DEVICE)
for i in range(len(x_m)):
x_point = x_m[i].unsqueeze(0).to(DEVICE)
y_point = y_m[i].unsqueeze(0).to(DEVICE)
is_in = S[i]
wrn_outputs = model_init(x_point)
outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
init_loss = criterion(outputs, y_point)
wrn_outputs = model_trained(x_point)
outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
trained_loss = criterion(outputs, y_point)
scores.append(((init_loss - trained_loss).item(), is_in))
scores = sorted(scores, key=lambda x: x[0])
scores = np.array([x[1] for x in scores])
return scores
def audit_model(hp, scores):
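# Sweep the number of guesses k over a schedule, evaluating each k in its own
# worker process, and keep the largest epsilon lower bound found.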
audits = (0, 0, 0, 0)
k_schedule = np.linspace(1, hp['target_points']//2, 40)
k_schedule = np.floor(k_schedule).astype(int)
with ProcessPoolExecutor() as executor:
futures = {
executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
}
for future in as_completed(futures):
try:
eps_lb, k, correct, total = future.result()
if eps_lb > audits[0]:
audits = (eps_lb, k, correct, total)
except Exception as exc:
k = futures[future]
print(f"'k={k}' generated an exception: {exc}")
return audits
def main():
global DEVICE
global DTYPE
parser = argparse.ArgumentParser(description='WideResNet O1 audit')
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
parser.add_argument('--cuda', type=int, help='gpu index', required=False)
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
parser.add_argument('--m', type=int, help='number of target points', required=True)
parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
parser.add_argument('--load', type=Path, help='path of saved models to load', required=False)
parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
parser.add_argument('--distill', action='store_true', help='train a distilled student', required=False)
parser.add_argument('--fast', action='store_true', help='train the fast model', required=False)
parser.add_argument('--wrn2', action='store_true', help='Train a groupnormed wrn', required=False)
parser.add_argument('--convnet', action='store_true', help='Train a convnet', required=False)
args = parser.parse_args()
if torch.cuda.is_available() and args.cuda:
DEVICE = torch.device(f'cuda:{args.cuda}')
DTYPE = torch.float16
elif torch.cuda.is_available():
DEVICE = torch.device('cuda:0')
DTYPE = torch.float16
else:
DEVICE = torch.device('cpu')
DTYPE = torch.float32
hp = {
"target_points": args.m,
"wrn_depth": 16,
"wrn_width": 1,
"epsilon": args.epsilon,
"delta": 1e-5,
"delta": 1e-6,
"norm": args.norm,
"batch_size": 4096,
"epochs": 100,
"k+": 300,
"k-": 300,
"batch_size": 50 if args.convnet else 4096,
"epochs": args.epochs,
"p_value": 0.05,
}
@@ -243,58 +770,68 @@ def main():
hp['norm'],
))
train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
print(f"len train: {len(train_dl)}")
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
print(f"Got x_in: {x_in.shape}")
print(f"Got x_m: {x_m.shape}")
print(f"Got y_m: {y_m.shape}")
if args.load:
train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
test_dl = None
elif args.fast:
train_dl, test_dl, train_x, adv_points, adv_labels, S = get_dataloaders_raw(hp['target_points'])
model_init, model_trained = train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S)
else:
train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
if args.wrn2:
print("=========================")
print("Training wrn2 model from meta")
print("=========================")
model_init, model_trained = train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S)
elif args.convnet:
print("=========================")
print("Training a simple convnet")
print("=========================")
model_init, model_trained = train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S)
elif args.studentraw:
print("=========================")
print("Training a raw student model")
print("=========================")
model_init, model_trained = train_small(hp, train_dl, test_dl, adv_points, adv_labels, S)
elif args.distill:
print("=========================")
print("Training a distilled student model")
print("=========================")
teacher_init, teacher_trained = train(hp, train_dl, test_dl, adv_points, adv_labels, S)
model_init, model_trained = train_knowledge_distillation(
teacher=teacher_trained,
train_dl=train_dl,
epochs=hp['epochs'],
device=DEVICE,
learning_rate=0.001,
T=2,
soft_target_loss_weight=0.25,
ce_loss_weight=0.75,
)
else:
print("=========================")
print("Training teacher model")
print("=========================")
model_init, model_trained = train(hp, train_dl, test_dl)
model_init, model_trained = train(hp, train_dl, test_dl, adv_points, adv_labels, S)
np.save("data/adv_points", adv_points)
np.save("data/adv_labels", adv_labels)
np.save("data/S", S)
torch.save(model_init.state_dict(), "data/init_model.pt")
torch.save(model_trained.state_dict(), "data/trained_model.pt")
# torch.save(model_init.state_dict(), "data/init_model.pt")
# torch.save(model_trained.state_dict(), "data/trained_model.pt")
# scores = score_model(model_init, model_trained, adv_points, adv_labels, S)
# audits = audit_model(hp, scores)
scores = list()
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
model_init.eval()
x_m = torch.from_numpy(x_m).to(DEVICE)
y_m = torch.from_numpy(y_m).long().to(DEVICE)
# print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
# print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")
for i in range(len(x_m)):
x_point = x_m[i].unsqueeze(0)
y_point = y_m[i].unsqueeze(0)
is_in = S_m[i]
init_loss = criterion(model_init(x_point)[0], y_point)
trained_loss = criterion(model_trained(x_point)[0], y_point)
scores.append(((init_loss - trained_loss).item(), is_in))
scores = sorted(scores, key=lambda x: x[0])
scores = np.array([x[1] for x in scores])
print(scores[:10])
correct = np.sum(~scores[:hp['k-']]) + np.sum(scores[-hp['k+']:])
total = len(scores)
eps_lb = get_eps_audit(
hp['target_points'],
hp['k+'] + hp['k-'],
correct,
hp['delta'],
hp['p_value']
)
print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
print(f"p[ε < {eps_lb}] < {hp['p_value']}")
correct, total = evaluate_on(model_init, train_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
correct, total = evaluate_on(model_trained, test_dl)
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
if test_dl is not None:
correct, total = evaluate_on(model_init, test_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
correct, total = evaluate_on(model_trained, test_dl)
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
if __name__ == '__main__':


@@ -0,0 +1,51 @@
# Name: Peng Cheng
# UIN: 674792652
#
# Code adapted from:
# https://github.com/jameschengpeng/PyTorch-CNN-on-CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=(3,3), padding=(1,1))
self.conv2 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=(3,3), padding=(1,1))
self.conv3 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=(3,3), padding=(1,1))
self.conv4 = nn.Conv2d(in_channels=192, out_channels=256, kernel_size=(3,3), padding=(1,1))
self.pool = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(in_features=8*8*256, out_features=512)
self.fc2 = nn.Linear(in_features=512, out_features=64)
self.Dropout = nn.Dropout(0.25)
self.fc3 = nn.Linear(in_features=64, out_features=10)
def forward(self, x):
x = F.relu(self.conv1(x)) #32*32*48
x = F.relu(self.conv2(x)) #32*32*96
x = self.pool(x) #16*16*96
x = self.Dropout(x)
x = F.relu(self.conv3(x)) #16*16*192
x = F.relu(self.conv4(x)) #16*16*256
x = self.pool(x) # 8*8*256
x = self.Dropout(x)
x = x.view(-1, 8*8*256) # reshape x
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.Dropout(x)
x = self.fc3(x)
return x


@@ -49,5 +49,4 @@ def get_eps_audit(m, r, v, delta, p):
if __name__ == '__main__':
x = 100
print(f"For m=100 r=100 v=100 p=0.05: {get_eps_audit(x, x, x, 1e-5, 0.05)}")
print(get_eps_audit(1000, 600, 600, 1e-5, 0.05))

one_run_audit/fast_model.py Normal file

@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def label_smoothing_loss(inputs, targets, alpha):
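# Label-smoothing cross entropy: (1 - alpha) * NLL(targets) + alpha * mean(-log p),
# i.e. the smoothed target mixes the one-hot label with a uniform distribution.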
log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
kl = -log_probs.mean(dim=1)
xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
loss = (1 - alpha) * xent + alpha * kl
return loss
class GhostBatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, num_splits, **kw):
super().__init__(num_features, **kw)
running_mean = torch.zeros(num_features * num_splits)
running_var = torch.ones(num_features * num_splits)
self.weight.requires_grad = False
self.num_splits = num_splits
self.register_buffer("running_mean", running_mean)
self.register_buffer("running_var", running_var)
def train(self, mode=True):
if (self.training is True) and (mode is False):
# lazily collate stats when we are going to use them
self.running_mean = torch.mean(
self.running_mean.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
self.running_var = torch.mean(
self.running_var.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
return super().train(mode)
def forward(self, input):
n, c, h, w = input.shape
if self.training or not self.track_running_stats:
assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
return F.batch_norm(
input.view(-1, c * self.num_splits, h, w),
self.running_mean,
self.running_var,
self.weight.repeat(self.num_splits),
self.bias.repeat(self.num_splits),
True,
self.momentum,
self.eps,
).view(n, c, h, w)
else:
return F.batch_norm(
input,
self.running_mean[: self.num_features],
self.running_var[: self.num_features],
self.weight,
self.bias,
False,
self.momentum,
self.eps,
)
def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
return nn.Sequential(
nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
GhostBatchNorm(c_out, num_splits=16),
nn.CELU(alpha=0.3),
)
def conv_pool_norm_act(c_in, c_out):
return nn.Sequential(
nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.MaxPool2d(kernel_size=2, stride=2),
GhostBatchNorm(c_out, num_splits=16),
nn.CELU(alpha=0.3),
)
def patch_whitening(data, patch_size=(3, 3)):
# Compute weights from data such that
# torch.std(F.conv2d(data, weights), dim=(2, 3))
# is close to 1.
h, w = patch_size
c = data.size(1)
patches = data.unfold(2, h, 1).unfold(3, w, 1)
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
n, c, h, w = patches.shape
X = patches.reshape(n, c * h * w)
X = X / (X.size(0) - 1) ** 0.5
covariance = X.t() @ X
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
class ResNetBagOfTricks(nn.Module):
def __init__(self, first_layer_weights, c_in, c_out, scale_out):
super().__init__()
c = first_layer_weights.size(0)
conv1 = nn.Conv2d(c_in, c, kernel_size=(3, 3), padding=(1, 1), bias=False)
conv1.weight.data = first_layer_weights
conv1.weight.requires_grad = False
self.conv1 = conv1
self.conv2 = conv_bn_relu(c, 64, kernel_size=(1, 1), padding=0)
self.conv3 = conv_pool_norm_act(64, 128)
self.conv4 = conv_bn_relu(128, 128)
self.conv5 = conv_bn_relu(128, 128)
self.conv6 = conv_pool_norm_act(128, 256)
self.conv7 = conv_pool_norm_act(256, 512)
self.conv8 = conv_bn_relu(512, 512)
self.conv9 = conv_bn_relu(512, 512)
self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
self.linear11 = nn.Linear(512, c_out, bias=False)
self.scale_out = scale_out
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + self.conv5(self.conv4(x))
x = self.conv6(x)
x = self.conv7(x)
x = x + self.conv9(self.conv8(x))
x = self.pool10(x)
x = x.reshape(x.size(0), x.size(1))
x = self.linear11(x)
x = self.scale_out * x
return x
Model = ResNetBagOfTricks


@@ -1,21 +1,94 @@
import time
import math
import concurrent.futures
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from equations import get_eps_audit
delta = 1e-5
p_value = 0.05
def compute_y(x_values, p, delta, proportion_correct, key):
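# Worker helper: for each sample count x, compute the audited epsilon lower bound
# when all x points are guessed and a fraction proportion_correct of the guesses
# are correct, at significance p and privacy parameter delta.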
return key, [get_eps_audit(x, x, math.floor(x *proportion_correct), delta, p) for x in x_values]
x_values = np.floor((1.5)**np.arange(30)).astype(int)
x_values = np.concatenate([x_values[x_values < 60000], [60000]])
y_values = [get_eps_audit(x, x, x, delta, p_value) for x in tqdm(x_values)]
plt.xscale('log')
plt.plot(x_values, y_values, marker='o')
plt.xlabel("Number of samples guessed correctly")
plt.ylabel("ε value audited")
plt.title("Maximum possible ε from audit")
def get_plots():
final_values = dict()
mul = 1.5 #1.275 #1.5
max = 60000 #2000 #60000
# 5. Save the plot as a PNG
plt.savefig("/dev/shm/my_plot.png", dpi=300, bbox_inches='tight')
x_values = np.floor((mul)**np.arange(30)).astype(int)
x_values = np.concatenate([x_values[x_values < max], [max]])
with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
start_time = time.time()
futures = [
executor.submit(compute_y, x_values, 0.05, 0.0, 1.0, "y11"),
executor.submit(compute_y, x_values, 0.05, 1e-6, 1.0, "y12"),
executor.submit(compute_y, x_values, 0.05, 1e-4, 1.0, "y13"),
executor.submit(compute_y, x_values, 0.05, 1e-2, 1.0, "y14"),
executor.submit(compute_y, x_values, 0.01, 0.0, 1.0, "y21"),
executor.submit(compute_y, x_values, 0.01, 1e-6, 1.0, "y22"),
executor.submit(compute_y, x_values, 0.01, 1e-4, 1.0, "y23"),
executor.submit(compute_y, x_values, 0.01, 1e-2, 1.0, "y24"),
executor.submit(compute_y, x_values, 0.05, 0.0, 0.9, "y31"),
executor.submit(compute_y, x_values, 0.05, 1e-6, 0.9, "y32"),
executor.submit(compute_y, x_values, 0.05, 1e-4, 0.9, "y33"),
executor.submit(compute_y, x_values, 0.05, 1e-2, 0.9, "y34"),
executor.submit(compute_y, x_values, 0.01, 0.0, 0.9, "y41"),
executor.submit(compute_y, x_values, 0.01, 1e-6, 0.9, "y42"),
executor.submit(compute_y, x_values, 0.01, 1e-4, 0.9, "y43"),
executor.submit(compute_y, x_values, 0.01, 1e-2, 0.9, "y44"),
]
for future in concurrent.futures.as_completed(futures):
k, v = future.result()
final_values[k] = v
print(f"Took: {time.time()-start_time}s")
return final_values, x_values
def plot_to(value_set, x_values, title, fig_name):
plt.xscale('log')
plt.plot(x_values, value_set[0], marker='o', label='δ=0')
plt.plot(x_values, value_set[1], marker='o', label='δ=1e-6')
plt.plot(x_values, value_set[2], marker='o', label='δ=1e-4')
plt.plot(x_values, value_set[3], marker='o', label='δ=1e-2')
plt.xlabel("Number of samples attacked")
plt.ylabel("Maximum ε lower-bound from audit")
plt.title(title)
plt.legend()
plt.savefig(fig_name, dpi=300, bbox_inches='tight')
def main():
final_values, x_values = get_plots()
plot_to(
[final_values[f"y1{i}"] for i in range(1,5)],
x_values,
"Maximum ε audit with p-value=0.05 and 100% MIA accuracy",
"/dev/shm/plot_05_100.png"
)
plot_to(
[final_values[f"y1{i}"] for i in range(1,5)],
x_values,
"Maximum ε audit with p-value=0.01 and 100% MIA accuracy",
"/dev/shm/plot_01_100.png"
)
plot_to(
[final_values[f"y1{i}"] for i in range(1,5)],
x_values,
"Maximum ε audit with p-value=0.05 and 90% MIA accuracy",
"/dev/shm/plot_05_90.png"
)
plot_to(
[final_values[f"y1{i}"] for i in range(1,5)],
x_values,
"Maximum ε audit with p-value=0.01 and 90% MIA accuracy"
"/dev/shm/plot_01_90.png"
)
if __name__ == '__main__':
main()


@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
# Small student CNN; forward returns only the logits (no tuple, and no pooling after flattening).
class ModifiedLightNNCosine(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedLightNNCosine, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
return x
Model = ModifiedLightNNCosine

one_run_audit/wrn.py Normal file

@@ -0,0 +1,232 @@
"""
Adapted from:
https://github.com/facebookresearch/tan/blob/main/src/models/wideresnet.py
"""
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Adapted from timm:
https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
def forward(self, x):
return x / x.norm(p=2, dim=1, keepdim=True)
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, nb_groups, order):
super(BasicBlock, self).__init__()
self.order = order
self.bn1 = nn.GroupNorm(nb_groups, in_planes) if nb_groups else nn.Identity()
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1
)
self.bn2 = nn.GroupNorm(nb_groups, out_planes) if nb_groups else nn.Identity()
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(
out_planes, out_planes, kernel_size=3, stride=1, padding=1
)
self.equalInOut = in_planes == out_planes
self.bnShortcut = (
(not self.equalInOut)
and nb_groups
and nn.GroupNorm(nb_groups, in_planes)
or (not self.equalInOut)
and nn.Identity()
or None
)
self.convShortcut = (
(not self.equalInOut)
and nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, padding=0
)
) or None
def forward(self, x):
skip = x
assert self.order in [0, 1, 2, 3]
if self.order == 0: # DM accuracy good
if not self.equalInOut:
skip = self.convShortcut(self.bnShortcut(self.relu1(x)))
out = self.conv1(self.bn1(self.relu1(x)))
out = self.conv2(self.bn2(self.relu2(out)))
elif self.order == 1: # classic accuracy bad
if not self.equalInOut:
skip = self.convShortcut(self.relu1(self.bnShortcut(x)))
out = self.conv1(self.relu1(self.bn1(x)))
out = self.conv2(self.relu2(self.bn2(out)))
elif self.order == 2: # DM IN RESIDUAL, normal other
if not self.equalInOut:
skip = self.convShortcut(self.bnShortcut(self.relu1(x)))
out = self.conv1(self.relu1(self.bn1(x)))
out = self.conv2(self.relu2(self.bn2(out)))
elif self.order == 3: # normal in residual, DM in others
if not self.equalInOut:
skip = self.convShortcut(self.relu1(self.bnShortcut(x)))
out = self.conv1(self.bn1(self.relu1(x)))
out = self.conv2(self.bn2(self.relu2(out)))
return torch.add(skip, out)
class NetworkBlock(nn.Module):
def __init__(
self, nb_layers, in_planes, out_planes, block, stride, nb_groups, order
):
super(NetworkBlock, self).__init__()
self.layer = self._make_layer(
block, in_planes, out_planes, nb_layers, stride, nb_groups, order
)
def _make_layer(
self, block, in_planes, out_planes, nb_layers, stride, nb_groups, order
):
layers = []
for i in range(int(nb_layers)):
layers.append(
block(
i == 0 and in_planes or out_planes,
out_planes,
i == 0 and stride or 1,
nb_groups,
order,
)
)
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
class WideResNet(nn.Module):
def __init__(
self,
depth,
feat_dim,
#num_classes,
widen_factor=1,
nb_groups=16,
init=0,
order1=0,
order2=0,
):
if order1 == 0:
print("order1=0: In the blocks: like in DM, BN on top of relu")
if order1 == 1:
print("order1=1: In the blocks: not like in DM, relu on top of BN")
if order1 == 2:
print(
"order1=2: In the blocks: BN on top of relu in residual (DM), relu on top of BN ortherplace (clqssique)"
)
if order1 == 3:
print(
"order1=3: In the blocks: relu on top of BN in residual (classic), BN on top of relu otherplace (DM)"
)
if order2 == 0:
print("order2=0: outside the blocks: like in DM, BN on top of relu")
if order2 == 1:
print("order2=1: outside the blocks: not like in DM, relu on top of BN")
super(WideResNet, self).__init__()
nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
assert (depth - 4) % 6 == 0
n = (depth - 4) / 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1)
# 1st block
self.block1 = NetworkBlock(
n, nChannels[0], nChannels[1], block, 1, nb_groups, order1
)
# 2nd block
self.block2 = NetworkBlock(
n, nChannels[1], nChannels[2], block, 2, nb_groups, order1
)
# 3rd block
self.block3 = NetworkBlock(
n, nChannels[2], nChannels[3], block, 2, nb_groups, order1
)
# global average pooling and classifier
"""
self.bn1 = nn.GroupNorm(nb_groups, nChannels[3]) if nb_groups else nn.Identity()
self.relu = nn.ReLU()
self.fc = nn.Linear(nChannels[3], num_classes)
"""
self.nChannels = nChannels[3]
self.block4 = nn.Sequential(
nn.Flatten(),
nn.Linear(256 * 8 * 8, 4096, bias=False), # 256 * 6 * 6 if 224 * 224
nn.GroupNorm(16, 4096),
nn.ReLU(inplace=True),
)
# fc7
self.block5 = nn.Sequential(
nn.Linear(4096, 4096, bias=False),
nn.GroupNorm(16, 4096),
nn.ReLU(inplace=True),
)
# fc8
self.block6 = nn.Sequential(
nn.Linear(4096, feat_dim),
L2Norm(),
)
if init == 0: # as in Deep Mind's paper
for m in self.modules():
if isinstance(m, nn.Conv2d):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
s = 1 / (max(fan_in, 1)) ** 0.5
nn.init.trunc_normal_(m.weight, std=s)
m.bias.data.zero_()
elif isinstance(m, nn.GroupNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
s = 1 / (max(fan_in, 1)) ** 0.5
nn.init.trunc_normal_(m.weight, std=s)
#m.bias.data.zero_()
if init == 1: # old version
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, nn.GroupNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
self.order2 = order2
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.block4(out)
out = self.block5(out)
out = self.block6(out)
if out.ndim == 4:
out = out.mean(dim=-1)
if out.ndim == 3:
out = out.mean(dim=-1)
#out = self.bn1(self.relu(out)) if self.order2 == 0 else self.relu(self.bn1(out))
#out = F.avg_pool2d(out, 8)
#out = out.view(-1, self.nChannels)
return out#self.fc(out)