Compare commits

..

12 commits

16 changed files with 1153 additions and 541 deletions

View file

@ -1,143 +0,0 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math
class IndividualBlock1(nn.Module):
    """First pre-activation residual block of a WideResNet group.

    The main path is BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv. The
    shortcut is a 1x1 convolution when the block changes the spatial size
    (``subsample_input``) or the channel count (``increase_filters``), and the
    identity otherwise.

    Args:
        input_features: Number of input channels.
        output_features: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is used).
        subsample_input: Use a projection shortcut because the block
            downsamples its input.
        increase_filters: Use a projection shortcut because the block changes
            the number of channels.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super().__init__()
        self.activation = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)
        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The projection has to shrink the feature map by the same factor
            # as conv1, otherwise the two branches cannot be summed. The
            # stride used to be hard-coded (2 when subsampling, 1 otherwise),
            # which only worked for those exact values of ``stride``.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        use_projection = self.subsample_input or self.increase_filters
        out = self.activation(self.batch_norm1(x))
        if use_projection:
            # The projection shortcut branches off after the first BN/ReLU.
            x = out
        out = self.conv1(out)
        out = self.conv2(self.activation(self.batch_norm2(out)))
        shortcut = self.conv_inp(x) if use_projection else x
        return shortcut + out
class IndividualBlockN(nn.Module):
    """Pre-activation residual block with an identity shortcut.

    Two BN -> ReLU -> 3x3 conv stages whose result is added to the unchanged
    input. Both convolutions use ``stride``, so the sum only lines up when the
    block keeps the input shape.
    """

    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)
        # Both 3x3 convolutions share every setting except the input width.
        conv_settings = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(input_features, output_features, **conv_settings)
        self.conv2 = nn.Conv2d(output_features, output_features, **conv_settings)

    def forward(self, x):
        """Return the residual branch applied to ``x`` plus ``x`` itself."""
        residual = self.conv1(self.activation(self.batch_norm1(x)))
        residual = self.conv2(self.activation(self.batch_norm2(residual)))
        return residual + x
class Nblock(nn.Module):
    """A group of ``N`` residual blocks run in sequence.

    The first block is an ``IndividualBlock1`` (it may change width and
    resolution); the remaining ``N - 1`` are shape-preserving
    ``IndividualBlockN`` blocks with stride 1.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()
        blocks = []
        if N > 0:
            blocks.append(
                IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            )
            blocks.extend(
                IndividualBlockN(output_features, output_features, stride=1)
                for _ in range(N - 1)
            )
        self.nblockLayer = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed ``x`` through every block of the group."""
        return self.nblockLayer(x)
class WideResNet(nn.Module):
    """Wide residual network that also returns its three group activations.

    Args:
        d: Depth parameter; each group holds ``(d - 4) // 6`` residual blocks.
        k: Widening factor; the groups use ``16 * k``, ``32 * k`` and
            ``64 * k`` channels.
        n_classes: Output size of the final linear layer.
        input_features: Channels of the input image.
        output_features: Channels produced by the stem convolution.
        strides: Four strides, for the stem convolution and groups 1 to 3.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super().__init__()
        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        # Group 1 needs a projection shortcut whenever it changes the channel
        # count. ``k > 1`` covers the usual 16-channel stem; the second test
        # covers a stem whose width differs from 16 * k, which used to fail
        # in the residual sum of the first block.
        increase_filters = k > 1 or output_features != filters[0]
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with variance 2 / (kernel area * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps its default init.
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``(logits, attention1, attention2, attention3)``.

        The three extra values are the raw outputs of groups 1 to 3.
        """
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.activation(self.batch_norm(attention3))
        out = self.avg_pool(out)
        # NOTE(review): this reshape assumes the pooled map is 1x1 (e.g. 32x32
        # inputs with strides [1, 1, 2, 2]); larger maps would be folded into
        # the batch dimension - confirm input sizes at the call sites.
        out = out.view(-1, self.out_filters)
        return self.fc(out), attention1, attention2, attention3
if __name__ == '__main__':
    # Smoke test: build a network, push one dummy batch through it and print
    # a layer summary. Edit d and k to inspect a model other than WRN-40-2.
    d, k = 40, 2
    net = WideResNet(
        d=d,
        k=k,
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )

    # A forward pass that completes is the whole check; the output is unused.
    dummy_batch = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    net(dummy_batch)

    summary(net, input_size=(3, 32, 32))

View file

@ -1,36 +0,0 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    """Split the CIFAR-10 training set into teacher and student halves.

    The split is a seeded ``random_split`` so the same ``seed_val`` always
    yields the same two subsets. ``train_batch_size`` is only printed and
    ``test_batch_size`` is not used; no loaders are built here.

    Returns:
        ``(teacher_set, student_set, testset)``.
    """
    print(f"Train batch size: {train_batch_size}")

    def reflect_pad(img):
        # Reflect-pad 4 pixels per side; a batch axis is added for F.pad and
        # squeezed away again afterwards.
        return F.pad(img.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    augmentation_steps = [
        transforms.ToTensor(),
        transforms.Lambda(reflect_pad),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    train_transform = transforms.Compose(augmentation_steps)
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)

    # Fixed generator so teacher and student always receive the same halves.
    generator = torch.Generator().manual_seed(seed_val)
    teacher_set, student_set = random_split(trainset, [0.5, 0.5], generator=generator)

    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset

View file

@ -17,14 +17,13 @@ from tqdm import tqdm
import student_model import student_model
from utils import json_file_to_pyobj, get_loaders from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int) parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str) parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str) parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args() args = parser.parse_args()
SEED = 42
@torch.no_grad() @torch.no_grad()
def run(): def run():
@ -32,12 +31,7 @@ def run():
dataset = "cifar10" dataset = "cifar10"
# Dataset # Dataset
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json") train_dl, test_dl = get_loaders(dataset, 4096)
training_configurations = json_options.training
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
train_ds, test_ds = studentset, testset
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
# Infer the logits with multiple queries # Infer the logits with multiple queries
for path in os.listdir(args.savedir): for path in os.listdir(args.savedir):

View file

@ -23,14 +23,11 @@ from pathlib import Path
import numpy as np import numpy as np
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from distillation_utils import get_teacherstudent_trainset
from utils import json_file_to_pyobj, get_loaders
from torch.utils.data import DataLoader
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str) parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args() args = parser.parse_args()
SEED = 42
def load_one(path): def load_one(path):
""" """
@ -59,17 +56,9 @@ def load_one(path):
def get_labels(): def get_labels():
# Dataset
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
# Get the indices of the student set
student_indices = studentset.indices
datadir = Path().home() / "opt/data/cifar" datadir = Path().home() / "opt/data/cifar"
train_ds = CIFAR10(root=datadir, train=True, download=True) train_ds = CIFAR10(root=datadir, train=True, download=True)
# Access the original targets for the student set return np.array(train_ds.targets)
student_targets = [train_ds.targets[i] for i in student_indices]
return np.array(student_targets)
def load_stats(): def load_stats():

View file

@ -22,7 +22,9 @@ import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
import torchvision import torchvision
from torchvision import transforms from torchvision import transforms
from distillation_utils import get_teacherstudent_trainset
#privacy libraries #privacy libraries
import opacus import opacus
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
@ -35,7 +37,6 @@ import student_model
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
SEED = 42 #setting for testing
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float) parser.add_argument("--lr", default=0.1, type=float)
@ -112,11 +113,7 @@ def run(teacher, student):
wandb.config.update(args) wandb.config.update(args)
# Dataset # Dataset
#get specific student set train_ds, test_ds = get_trainset()
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
train_ds, test_ds = studentset, testset
# Compute the IN / OUT subset: # Compute the IN / OUT subset:
# If we run each experiment independently then even after a lot of trials # If we run each experiment independently then even after a lot of trials
# there will still probably be some examples that were always included # there will still probably be some examples that were always included
@ -218,11 +215,7 @@ def main():
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
#get specific teacher set train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
loaders = trainloader, testloader
best_test_set_accuracy = 0 best_test_set_accuracy = 0
dp_epsilon = 8 dp_epsilon = 8
dp_delta = 1e-5 dp_delta = 1e-5
@ -231,14 +224,14 @@ def main():
teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon( teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=teacher, module=teacher,
optimizer=optimizer, optimizer=optimizer,
data_loader=trainloader, data_loader=train_loader,
epochs=epochs, epochs=epochs,
target_epsilon=dp_epsilon, target_epsilon=dp_epsilon,
target_delta=dp_delta, target_delta=dp_delta,
max_grad_norm=norm, max_grad_norm=norm,
) )
teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True)) teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
teacher.to(device) teacher.to(device)
teacher.eval() teacher.eval()
#instantiate student "shadow model" #instantiate student "shadow model"

View file

@ -1,62 +0,0 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
    """Load a JSON file into nested namedtuples for attribute-style access.

    Every JSON object becomes a namedtuple (type name ``X``) whose fields are
    the object's keys, so ``cfg.training.batch_size`` works on the result.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded document, with JSON objects replaced by namedtuples.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If a key is not a valid namedtuple field name.
    """
    def _json_object_hook(d):
        return collections.namedtuple('X', d.keys())(*d.values())

    # The context manager closes the handle as soon as parsing finishes; the
    # previous open(...).read() left that to the garbage collector.
    with open(filename) as handle:
        return json.load(handle, object_hook=_json_object_hook)
def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
    """Return ``(trainloader, testloader)`` for CIFAR-10 or SVHN.

    CIFAR-10 training images are reflect-padded by 4 pixels, randomly cropped
    back to 32x32 and randomly flipped; SVHN images are only normalised. Both
    loaders shuffle and use 4 workers. Data is downloaded to ``./data`` when
    missing.

    Args:
        dataset: ``'cifar10'`` or ``'svhn'``.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    print(f"Train batch size: {train_batch_size}")

    if dataset == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    elif dataset == 'svhn':
        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        # An unknown name used to fall through to the return statement and
        # fail with an UnboundLocalError that said nothing about the cause.
        raise ValueError(f"Unsupported dataset: {dataset!r} (expected 'cifar10' or 'svhn')")

    trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
    testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
    return trainloader, testloader

View file

@ -1,11 +0,0 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 16,
"wrn_width": 1,
"checkpoint": "True",
"log": "True",
"batch_size": 4096,
"epochs": 200
}
}

View file

@ -7,26 +7,201 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch.optim.lr_scheduler import MultiStepLR from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset, TensorDataset from torch.utils.data import DataLoader, Subset, TensorDataset, ConcatDataset
import torch.nn.functional as F import torch.nn.functional as F
from pathlib import Path from pathlib import Path
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
import pytorch_lightning as pl import pytorch_lightning as pl
import opacus import opacus
import random
from tqdm import tqdm
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager from opacus.utils.batch_memory_manager import BatchMemoryManager
from concurrent.futures import ProcessPoolExecutor, as_completed
from WideResNet import WideResNet from WideResNet import WideResNet
from equations import get_eps_audit from equations import get_eps_audit
import student_model import student_model
import fast_model
import convnet_classifier
import wrn
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
DEVICE = None DEVICE = None
STUDENTBOOL = False DTYPE = None
DATADIR = Path("./data")
def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
    """Build CIFAR-10 loaders with ``m`` mislabeled audit points.

    ``m`` candidate points are picked at random from the transformed test set
    (topped up with the first training images when ``m`` exceeds the test set
    size). Each candidate gets a label different from its true one and is
    added to the training data with probability 1/2.

    Args:
        m: Number of audit candidates.
        train_batch_size: Batch size of both training loaders.
        test_batch_size: Batch size of the test loader.

    Returns:
        ``(train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S)``:
        ``train_dl`` iterates the included candidates plus the training
        images (transforms applied once, up front), ``pure_train_dl`` iterates
        the untouched CIFAR-10 training set, ``adv_points`` / ``adv_labels``
        are all ``m`` candidates with their wrong labels, and ``S[i]`` tells
        whether candidate ``i`` was included in ``train_dl``.

    Raises:
        ValueError: If fewer than ``m`` candidate points are available.
    """
    seed = np.random.randint(0, 1e9)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)

    # Materialise the transformed datasets. np.stack needs a real sequence:
    # passing a generator is rejected by current NumPy releases.
    x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
    y_train = np.array(train_ds.targets).astype(np.int64)
    x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
    y = np.array(test_ds.targets).astype(np.int64)

    # Pull points from the training set when m exceeds the test set size.
    if m > len(x):
        k = m - len(x)
        mask = np.full(len(x_train), False)
        mask[:k] = True
        x = np.concatenate([x_train[mask], x])
        y = np.concatenate([y_train[mask], y])
        x_train = x_train[~mask]
        y_train = y_train[~mask]

    # Pick the m candidates that may or may not be included in training.
    mask = np.full(len(x), False)
    mask[:m] = True
    mask = mask[np.random.permutation(len(x))]
    adv_points = x[mask]
    adv_labels = y[mask]
    if len(adv_points) != m:
        # Checked before the arrays are indexed with the length-m vector S.
        raise ValueError(f"requested m={m} audit points but only {len(adv_points)} are available")

    # Mislabel the candidates on purpose: adding an offset in 1..9 modulo 10
    # gives a uniformly random label that differs from the true one.
    adv_labels = (adv_labels + np.random.randint(1, 10, size=m)) % 10

    # Include each candidate with probability 1/2; S is the ground truth.
    S = np.random.choice([True, False], size=m)
    inc_points = adv_points[S]
    inc_labels = adv_labels[S]

    included = TensorDataset(torch.from_numpy(inc_points).float(), torch.from_numpy(inc_labels).long())
    regular = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
    td = ConcatDataset([included, regular])

    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
    pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
    """Build in-memory CIFAR-10 loaders with ``m`` mislabeled audit points.

    Works on the raw uint8 arrays of the dataset instead of per-sample
    transforms. ``m`` candidates are picked at random from the test images
    (topped up with random training images when ``m`` exceeds the test set
    size), given a label different from their true one, and each is added to
    the training data with probability 1/2.

    Args:
        m: Number of audit candidates.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.

    Returns:
        ``(train_dl, test_dl, train_x, attack_x, attack_y, S)``: ``train_x``
        is the preprocessed training tensor (included candidates appended),
        ``attack_x`` / ``attack_y`` are NumPy arrays holding all ``m``
        candidates with their wrong labels, and ``S[i]`` tells whether
        candidate ``i`` was included in training.

    Raises:
        ValueError: If fewer than ``m`` candidate points are available.
    """
    def preprocess_data(data):
        # uint8 NHWC array -> normalised float NCHW tensor, then augmentation.
        data = torch.tensor(data)
        data = data / 255.0
        data = data.permute(0, 3, 1, 2)
        data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
        data = nn.ReflectionPad2d(4)(data)
        # NOTE(review): crop and flip are applied to the whole tensor in one
        # call (and to the test/attack data too), so they look like they are
        # sampled once per call rather than per image - confirm this is
        # intended.
        data = transforms.RandomCrop(size=(32, 32))(data)
        data = transforms.RandomHorizontalFlip()(data)
        return data

    train_ds = CIFAR10(root=DATADIR, train=True, download=True)
    test_ds = CIFAR10(root=DATADIR, train=False, download=True)

    train_x = train_ds.data
    test_x = test_ds.data
    train_y = np.array(train_ds.targets)
    test_y = np.array(test_ds.targets)

    # Pull random points from the training set when m exceeds the test set.
    if m > len(test_x):
        k = m - len(test_x)
        mask = np.full(len(train_x), False)
        mask[:k] = True
        mask = mask[np.random.permutation(len(train_x))]
        test_x = np.concatenate([train_x[mask], test_x])
        test_y = np.concatenate([train_y[mask], test_y])
        train_y = train_y[~mask]
        train_x = train_x[~mask]

    # Pick the m candidates that may or may not be included in training.
    mask = np.full(len(test_x), False)
    mask[:m] = True
    mask = mask[np.random.permutation(len(test_x))]
    attack_x = test_x[mask]
    attack_y = test_y[mask]
    if len(attack_x) != m:
        raise ValueError(f"requested m={m} audit points but only {len(attack_x)} are available")

    S = np.random.choice([True, False], size=m)

    # Mislabel every candidate: an offset in 1..9 modulo 10 gives a uniformly
    # random label that differs from the true one.
    attack_y = (attack_y + np.random.randint(1, 10, size=m)) % 10

    # Only the candidates selected by S are trained on. All m candidates are
    # returned, so that attack_x[i] / attack_y[i] line up with S[i];
    # returning just the included subset left the scorer pairing points with
    # unrelated membership bits.
    train_x = np.concatenate([train_x, attack_x[S]])
    train_y = np.concatenate([train_y, attack_y[S]])

    train_x = preprocess_data(train_x)
    test_x = preprocess_data(test_x)
    attack_x = preprocess_data(attack_x)

    train_y = torch.tensor(train_y)
    test_y = torch.tensor(test_y)
    attack_y = torch.tensor(attack_y)

    train_dl = DataLoader(
        TensorDataset(train_x, train_y.long()),
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4
    )

    # The test loader now honours test_batch_size; it previously reused
    # train_batch_size and ignored the parameter.
    test_dl = DataLoader(
        TensorDataset(test_x, test_y.long()),
        batch_size=test_batch_size,
        shuffle=True,
        num_workers=4
    )

    return train_dl, test_dl, train_x, attack_x.numpy(), attack_y.numpy(), S
def evaluate_on(model, dataloader):
    """Count the top-1 hits of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradients on
    ``DEVICE``. Models may return either the logits or a tuple/list whose
    first element is the logits (as ``WideResNet`` does).

    Returns:
        ``(correct, total)``: number of correct predictions and of samples.
    """
    correct = 0
    total = 0

    with torch.no_grad():
        model.eval()
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(images)
            # Decide by container type. The old ``len(outputs) == 4`` test
            # also matched a plain logits tensor for a batch of exactly four
            # samples and then took its first row as "the logits".
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct, total
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75): def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
#instantiate istudent #instantiate istudent
@ -74,84 +249,8 @@ def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rat
return student_init, student return student_init, student
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
seed = np.random.randint(0, 1e9)
seed ^= int(time.time())
pl.seed_everything(seed)
train_transform = transforms.Compose([ def train_no_cap(model, model_init, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S):
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4), mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datadir = Path("./data")
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
# Original dataset
x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
p = np.random.permutation(len(train_ds))
# Choose m points to randomly exclude at chance
S = np.full(len(train_ds), True)
S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out
# Store the m points which could have been included/excluded
mask = np.full(len(train_ds), False)
mask[:m] = True
mask = mask[p]
x_m = x[mask] # These are the points being guessed at
y_m = np.array(train_ds.targets)[mask].astype(np.int64)
S_m = S[p][mask] # Ground truth of inclusion/exclusion for x_m
# Remove excluded points from dataset
x_in = x[S[p]]
y_in = np.array(train_ds.targets).astype(np.int64)
y_in = y_in[S[p]]
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
def evaluate_on(model, dataloader):
    """Return ``(correct, total)`` top-1 counts of ``model`` on ``dataloader``.

    Runs in eval mode without gradients on ``DEVICE``. When the module flag
    ``STUDENTBOOL`` is set the model output is used as the logits directly;
    otherwise its first element is.
    """
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        model.eval()
        for images, labels in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            raw_output = model(images)
            logits = raw_output if STUDENTBOOL else raw_output[0]

            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return n_correct, n_seen
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
best_test_set_accuracy = 0 best_test_set_accuracy = 0
for epoch in range(hp['epochs']): for epoch in range(hp['epochs']):
@ -164,7 +263,10 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
optimizer.zero_grad() optimizer.zero_grad()
wrn_outputs = model(inputs) wrn_outputs = model(inputs)
if len(wrn_outputs) == 4:
outputs = wrn_outputs[0] outputs = wrn_outputs[0]
else:
outputs = wrn_outputs
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -174,12 +276,304 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
if epoch % 10 == 0 or epoch == hp['epochs'] - 1: if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
correct, total = evaluate_on(model, test_dl) correct, total = evaluate_on(model, test_dl)
epoch_accuracy = round(100 * correct / total, 2) epoch_accuracy = round(100 * correct / total, 2)
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%") scores = score_model(model_init, model, adv_points, adv_labels, S)
audits = audit_model(hp, scores)
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}% | Audit : {audits[2]}/{2*audits[1]}/{audits[3]} | p[ε < {audits[0]}] < {hp['p_value']} @ ε={hp['epsilon']}")
return best_test_set_accuracy return best_test_set_accuracy
# load: rebuild a (model_init, model) pair from checkpoints stored under
# model_path, together with the saved audit arrays.
# NOTE(review): the line below is a side-by-side diff rendering that shows the
# old `def train(...)` header next to the new `def load(...)` header; the body
# that follows belongs to `load`.
def train(hp, train_dl, test_dl): def load(hp, model_path, train_dl):
# Checkpoint files expected inside model_path.
init_model = model_path / "init_model.pt"
trained_model = model_path / "trained_model.pt"
# Same architecture arguments as used for training: depth and width come from hp.
model = WideResNet(
d=hp["wrn_depth"],
k=hp["wrn_width"],
n_classes=10,
input_features=3,
output_features=16,
strides=[1, 1, 2, 2],
)
# Let Opacus rewrite unsupported layers, then require a clean validation.
model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=True)
# Copy taken before the privacy wrapper is applied; it receives the initial weights.
model_init = copy.deepcopy(model)
privacy_engine = opacus.PrivacyEngine()
optimizer = optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9,
nesterov=True,
weight_decay=5e-4
)
# The model is wrapped by the privacy engine before its weights are loaded;
# presumably so the module structure matches the saved trained checkpoint - TODO confirm.
# The returned optimizer and train_loader are not used afterwards.
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dl,
epochs=hp['epochs'],
target_epsilon=hp['epsilon'],
target_delta=hp['delta'],
max_grad_norm=hp['norm'],
)
model_init.load_state_dict(torch.load(init_model, weights_only=True))
model.load_state_dict(torch.load(trained_model, weights_only=True))
model_init = model_init.to(DEVICE)
model = model.to(DEVICE)
# Audit arrays are read from fixed paths relative to the working directory,
# not from model_path.
adv_points = np.load("data/adv_points.npy")
adv_labels = np.load("data/adv_labels.npy")
S = np.load("data/S.npy")
return model_init, model, adv_points, adv_labels, S
def train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S):
    """Train a ``wrn.WideResNet(16, 10, 4)`` with the audit-aware loop.

    Uses SGD (lr 0.12, momentum 0.9, weight decay 1e-4) and a step schedule
    at 30/60/80% of ``hp['epochs']``. When ``hp['epsilon']`` is not ``None``
    the model, optimizer and loader are made private with Opacus and training
    runs under a ``BatchMemoryManager``.

    Returns:
        ``(model_init, model)``: a copy of the untrained model and the
        trained (possibly Opacus-wrapped) model.
    """
    model = wrn.WideResNet(16, 10, 4).to(DEVICE)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.12, momentum=0.9, weight_decay=1e-4)
    milestones = [int(fraction * hp['epochs']) for fraction in [0.3, 0.6, 0.8]]
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    print(f"Training with {hp['epochs']} epochs")

    # Non-private path: train on the loader exactly as given.
    if hp['epsilon'] is None:
        print("Training without differential privacy")
        train_no_cap(
            model, model_init, hp, train_dl, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )
        return model_init, model

    # Private path: let Opacus pick the noise level for the target epsilon.
    privacy_engine = opacus.PrivacyEngine()
    model, optimizer, private_dl = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dl,
        epochs=hp['epochs'],
        target_epsilon=hp['epsilon'],
        target_delta=hp['delta'],
        max_grad_norm=hp['norm'],
    )

    print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
    print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

    # Physical batches are capped at 10 samples (original note: a physical
    # batch of 1000 needs roughly 9.4GB of VRAM).
    with BatchMemoryManager(
        data_loader=private_dl,
        max_physical_batch_size=10,
        optimizer=optimizer,
    ) as memory_safe_data_loader:
        train_no_cap(
            model, model_init, hp, memory_safe_data_loader, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )

    return model_init, model
def train_small(hp, train_dl, test_dl, adv_points, adv_labels, S):
    """Train the student architecture from scratch (no distillation).

    Uses Adam (lr 0.001) and a step schedule at 30/60/80% of
    ``hp['epochs']``. When ``hp['epsilon']`` is not ``None`` the model,
    optimizer and loader are made private with Opacus and training runs under
    a ``BatchMemoryManager``.

    Returns:
        ``(model_init, model)``: a copy of the untrained model and the
        trained (possibly Opacus-wrapped) model.
    """
    model = student_model.Model(num_classes=10).to(DEVICE)
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    milestones = [int(fraction * hp['epochs']) for fraction in [0.3, 0.6, 0.8]]
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

    print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs")

    # Non-private path: train on the loader exactly as given.
    if hp['epsilon'] is None:
        print("Training without differential privacy")
        train_no_cap(
            model, model_init, hp, train_dl, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )
        return model_init, model

    # Private path: let Opacus pick the noise level for the target epsilon.
    privacy_engine = opacus.PrivacyEngine()
    model, optimizer, private_dl = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dl,
        epochs=hp['epochs'],
        target_epsilon=hp['epsilon'],
        target_delta=hp['delta'],
        max_grad_norm=hp['norm'],
    )

    print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
    print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

    # Physical batches are capped at 2000 samples (original note: a physical
    # batch of 1000 needs roughly 9.4GB of VRAM).
    with BatchMemoryManager(
        data_loader=private_dl,
        max_physical_batch_size=2000,
        optimizer=optimizer,
    ) as memory_safe_data_loader:
        train_no_cap(
            model, model_init, hp, memory_safe_data_loader, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )

    return model_init, model
def train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S):
    """Train ``fast_model.Model`` with the audit-aware loop (no DP wrapper).

    The model's whitening weights are computed by
    ``fast_model.patch_whitening`` from the first 10000 training images with a
    4-pixel border removed. Training uses Nesterov SGD (lr 0.1, momentum 0.9,
    weight decay 5e-4) and a step schedule at 30/60/80% of ``hp['epochs']``.

    Args:
        hp: Hyper-parameter dict; ``hp['epochs']`` is read here, the rest is
            forwarded to ``train_no_cap``.
        train_dl: Training loader.
        test_dl: Evaluation loader.
        train_x: Training images as an indexable 4-D array/tensor.
        adv_points, adv_labels, S: Audit points, their labels and their
            inclusion flags, forwarded to ``train_no_cap``.

    Returns:
        ``(init_model, model)``: a copy of the untrained model and the
        trained model.
    """
    print("=========================")
    print("Training a fast model")
    print("=========================")

    weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
    model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
    model.to(DEVICE)
    init_model = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )

    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    # The untrained copy is called ``init_model`` in this function; the call
    # used to pass the undefined name ``model_init`` and raised NameError.
    train_no_cap(model, init_model, hp, train_dl, test_dl, optimizer, criterion, scheduler, adv_points, adv_labels, S)

    return init_model, model
def train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S):
    """Train ``convnet_classifier.ConvNet`` with the audit-aware loop.

    Uses Adam (lr 1e-3) and a step schedule with fixed milestones at epochs
    10 and 25. When ``hp['epsilon']`` is not ``None`` the model, optimizer and
    loader are made private with Opacus (RDP accountant) and training runs
    under a ``BatchMemoryManager``.

    Returns:
        ``(model_init, model)``: a copy of the untrained model and the
        trained (possibly Opacus-wrapped) model.
    """
    model = convnet_classifier.ConvNet().to(DEVICE)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[10, 25], gamma=0.1)

    print(f"Training with {hp['epochs']} epochs")

    # Non-private path: train on the loader exactly as given.
    if hp['epsilon'] is None:
        print("Training without differential privacy")
        train_no_cap(
            model, model_init, hp, train_dl, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )
        return model_init, model

    # Private path: let Opacus pick the noise level for the target epsilon.
    privacy_engine = opacus.PrivacyEngine(accountant='rdp')
    model, optimizer, private_dl = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dl,
        epochs=hp['epochs'],
        target_epsilon=hp['epsilon'],
        target_delta=hp['delta'],
        max_grad_norm=hp['norm'],
    )

    print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
    print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

    # Physical batches are capped at 2000 samples (original note: a physical
    # batch of 1000 needs roughly 9.4GB of VRAM).
    with BatchMemoryManager(
        data_loader=private_dl,
        max_physical_batch_size=2000,
        optimizer=optimizer,
    ) as memory_safe_data_loader:
        train_no_cap(
            model, model_init, hp, memory_safe_data_loader, test_dl,
            optimizer, criterion, scheduler,
            adv_points, adv_labels, S,
        )

    return model_init, model
def train(hp, train_dl, test_dl, adv_points, adv_labels, S):
model = WideResNet( model = WideResNet(
d=hp["wrn_depth"], d=hp["wrn_depth"],
k=hp["wrn_width"], k=hp["wrn_width"],
@ -232,58 +626,136 @@ def train(hp, train_dl, test_dl):
) as memory_safe_data_loader: ) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap( best_test_set_accuracy = train_no_cap(
model, model,
model_init,
hp, hp,
memory_safe_data_loader, memory_safe_data_loader,
test_dl, test_dl,
optimizer, optimizer,
criterion, criterion,
scheduler, scheduler,
adv_points,
adv_labels,
S,
) )
else: else:
print("Training without differential privacy") print("Training without differential privacy")
best_test_set_accuracy = train_no_cap( best_test_set_accuracy = train_no_cap(
model, model,
model_init,
hp, hp,
train_dl, train_dl,
test_dl, test_dl,
optimizer, optimizer,
criterion, criterion,
scheduler, scheduler,
adv_points,
adv_labels,
S,
) )
return model_init, model return model_init, model
def get_k_audit(k, scores, hp):
    """Compute the audited epsilon lower bound when guessing on the k extremes.

    The first ``k`` entries of ``scores`` are guessed as non-members and the
    last ``k`` as members.
    NOTE(review): ``~`` is only a logical negation if ``scores`` is a boolean
    array -- confirm the dtype produced by the caller.

    Returns:
        ``(eps_lb, k, correct, len(scores))``.
    """
    lowest = scores[:k]
    highest = scores[-k:]
    correct = np.sum(~lowest) + np.sum(highest)
    guesses = 2*k
    eps_lb = get_eps_audit(
        hp['target_points'],
        guesses,
        correct,
        hp['delta'],
        hp['p_value'],
    )
    total = len(scores)
    return eps_lb, k, correct, total
def score_model(model_init, model_trained, adv_points, adv_labels, S):
    """Rank audit points by how much training reduced their loss.

    For every audit point the cross-entropy loss is computed under the initial
    and the trained model; the score is ``init_loss - trained_loss``.

    Args:
        model_init: Model snapshot taken before training.
        model_trained: Model after training.
        adv_points: NumPy array of audit inputs.
        adv_labels: NumPy array of audit labels (cast to ``long``).
        S: Per-point membership flags, indexed like ``adv_points``.

    Returns:
        NumPy array holding the entries of ``S`` ordered by ascending score.
    """
    criterion = nn.CrossEntropyLoss()
    x_m = torch.from_numpy(adv_points).to(DEVICE)
    y_m = torch.from_numpy(adv_labels).long().to(DEVICE)

    def point_loss(model, x_point, y_point):
        raw = model(x_point)
        # Some models return a 4-tuple whose first element is the logits.
        logits = raw[0] if len(raw) == 4 else raw
        return criterion(logits, y_point)

    # Both models must be in eval mode; otherwise dropout / batch-statistics
    # layers make the trained model's loss differ from its inference behaviour.
    model_init.eval()
    model_trained.eval()

    scores = []
    with torch.no_grad():
        for i in range(len(x_m)):
            # Tensors already live on DEVICE; just add the batch dimension.
            x_point = x_m[i].unsqueeze(0)
            y_point = y_m[i].unsqueeze(0)
            init_loss = point_loss(model_init, x_point, y_point)
            trained_loss = point_loss(model_trained, x_point, y_point)
            scores.append(((init_loss - trained_loss).item(), S[i]))

    scores.sort(key=lambda pair: pair[0])
    return np.array([is_in for _, is_in in scores])
def audit_model(hp, scores):
    """Search a schedule of k values for the largest audited epsilon lower bound.

    Up to 40 values of k between 1 and ``hp['target_points'] // 2`` are
    evaluated in parallel worker processes via ``get_k_audit``.

    Args:
        hp: Hyper-parameter dict; reads ``target_points`` here and is passed on
            to ``get_k_audit``.
        scores: Array of membership flags sorted by attack score.

    Returns:
        ``(eps_lb, k, correct, total)`` for the best k, or ``(0, 0, 0, 0)`` if
        no k produced a positive bound.
    """
    audits = (0, 0, 0, 0)
    k_schedule = np.linspace(1, hp['target_points']//2, 40)
    k_schedule = np.floor(k_schedule).astype(int)
    # Flooring yields repeated k for small budgets, and k == 0 when
    # target_points < 2. k == 0 must be dropped because scores[-0:] is the
    # whole array, not an empty slice.
    k_schedule = np.unique(k_schedule[k_schedule >= 1])

    with ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
        }
        for future in as_completed(futures):
            try:
                eps_lb, k, correct, total = future.result()
            except Exception as exc:
                # Best effort: one failing k must not abort the whole search.
                k = futures[future]
                print(f"'k={k}' generated an exception: {exc}")
            else:
                if eps_lb > audits[0]:
                    audits = (eps_lb, k, correct, total)
    return audits
def main(): def main():
global DEVICE global DEVICE
global DTYPE
parser = argparse.ArgumentParser(description='WideResNet O1 audit') parser = argparse.ArgumentParser(description='WideResNet O1 audit')
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True) parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
parser.add_argument('--cuda', type=int, help='gpu index', required=False) parser.add_argument('--cuda', type=int, help='gpu index', required=False)
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None) parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
parser.add_argument('--m', type=int, help='number of target points', required=True) parser.add_argument('--m', type=int, help='number of target points', required=True)
parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher") parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
parser.add_argument('--load', type=Path, help='number of epochs', required=False)
parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
parser.add_argument('--fast', action='store_true', help='train the fast model', required=False)
parser.add_argument('--wrn2', action='store_true', help='Train a groupnormed wrn', required=False)
parser.add_argument('--convnet', action='store_true', help='Train a convnet', required=False)
args = parser.parse_args() args = parser.parse_args()
if torch.cuda.is_available() and args.cuda: if torch.cuda.is_available() and args.cuda:
DEVICE = torch.device(f'cuda:{args.cuda}') DEVICE = torch.device(f'cuda:{args.cuda}')
DTYPE = torch.float16
elif torch.cuda.is_available(): elif torch.cuda.is_available():
DEVICE = torch.device('cuda:0') DEVICE = torch.device('cuda:0')
DTYPE = torch.float16
else: else:
DEVICE = torch.device('cpu') DEVICE = torch.device('cpu')
DTYPE = torch.float32
hp = { hp = {
"target_points": args.m, "target_points": args.m,
"wrn_depth": 16, "wrn_depth": 16,
"wrn_width": 1, "wrn_width": 1,
"epsilon": args.epsilon, "epsilon": args.epsilon,
"delta": 1e-5, "delta": 1e-6,
"norm": args.norm, "norm": args.norm,
"batch_size": 4096, "batch_size": 50 if args.convnet else 4096,
"epochs": 100, "epochs": args.epochs,
"k+": 200,
"k-": 200,
"p_value": 0.05, "p_value": 0.05,
} }
@ -298,82 +770,65 @@ def main():
hp['norm'], hp['norm'],
)) ))
train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size']) if args.load:
print(f"len train: {len(train_dl)}") train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}") model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
print(f"Got x_in: {x_in.shape}") test_dl = None
print(f"Got x_m: {x_m.shape}") elif args.fast:
print(f"Got y_m: {y_m.shape}") train_dl, test_dl, train_x, adv_points, adv_labels, S = get_dataloaders_raw(hp['target_points'])
model_init, model_trained = train_fast(hp, train_dl, test_dl, train_x, adv_points, adv_labels, S)
# torch.save(model_init.state_dict(), "data/init_model.pt") else:
# torch.save(model_trained.state_dict(), "data/trained_model.pt") train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
if args.wrn2:
if args.auditmodel == "student": print("=========================")
global STUDENTBOOL print("Training wrn2 model from meta")
teacher_init, teacher_trained = train(hp, train_dl, test_dl) print("=========================")
STUDENTBOOL = True model_init, model_trained = train_wrn2(hp, train_dl, test_dl, adv_points, adv_labels, S)
# torch.save(model_init.state_dict(), "data/init_model.pt") elif args.convnet:
# torch.save(model_trained.state_dict(), "data/trained_model.pt") print("=========================")
print("Training a simple convnet")
print("=========================")
#train student model model_init, model_trained = train_convnet(hp, train_dl, test_dl, adv_points, adv_labels, S)
print("Training Student Model") elif args.studentraw:
print("=========================")
print("Training a raw student model")
print("=========================")
model_init, model_trained = train_small(hp, train_dl, test_dl, adv_points, adv_labels, S)
elif args.distill:
print("=========================")
print("Training a distilled student model")
print("=========================")
teacher_init, teacher_trained = train(hp, train_dl, test_dl, adv_points, adv_labels, S)
model_init, model_trained = train_knowledge_distillation( model_init, model_trained = train_knowledge_distillation(
teacher=teacher_trained, teacher=teacher_trained,
train_dl=pure_train_dl, train_dl=train_dl,
epochs=100, epochs=hp['epochs'],
device=DEVICE, device=DEVICE,
learning_rate=0.001, learning_rate=0.001,
T=2, T=2,
soft_target_loss_weight=0.25, soft_target_loss_weight=0.25,
ce_loss_weight=0.75, ce_loss_weight=0.75,
) )
stcorrect, sttotal = evaluate_on(model_trained, test_dl)
stacc = stcorrect/sttotal*100
print(f"Student Accuracy: {stacc}%")
else: else:
print("=========================")
print("Training teacher model")
print("=========================")
model_init, model_trained = train(hp, train_dl, test_dl) model_init, model_trained = train(hp, train_dl, test_dl)
scores = list() np.save("data/adv_points", adv_points)
criterion = nn.CrossEntropyLoss() np.save("data/adv_labels", adv_labels)
with torch.no_grad(): np.save("data/S", S)
model_init.eval() torch.save(model_init.state_dict(), "data/init_model.pt")
x_m = torch.from_numpy(x_m).to(DEVICE) torch.save(model_trained.state_dict(), "data/trained_model.pt")
y_m = torch.from_numpy(y_m).long().to(DEVICE)
for i in range(len(x_m)): # scores = score_model(model_init, model_trained, adv_points, adv_labels, S)
x_point = x_m[i].unsqueeze(0) # audits = audit_model(hp, scores)
y_point = y_m[i].unsqueeze(0)
is_in = S_m[i]
if STUDENTBOOL:
init_loss = criterion(model_init(x_point), y_point)
trained_loss = criterion(model_trained(x_point), y_point)
else:
init_loss = criterion(model_init(x_point)[0], y_point)
trained_loss = criterion(model_trained(x_point)[0], y_point)
scores.append(((init_loss - trained_loss).item(), is_in)) # print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
# print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")
scores = sorted(scores, key=lambda x: x[0]) if test_dl is not None:
scores = np.array([x[1] for x in scores]) correct, total = evaluate_on(model_init, test_dl)
print(scores[:10])
correct = np.sum(~scores[:hp['k-']]) + np.sum(scores[-hp['k+']:])
total = len(scores)
eps_lb = get_eps_audit(
hp['target_points'],
hp['k+'] + hp['k-'],
correct,
hp['delta'],
hp['p_value']
)
print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
print(f"p[ε < {eps_lb}] < {hp['p_value']}")
correct, total = evaluate_on(model_init, train_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
correct, total = evaluate_on(model_trained, test_dl) correct, total = evaluate_on(model_trained, test_dl)
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")

View file

@ -0,0 +1,51 @@
# Name: Peng Cheng
# UIN: 674792652
#
# Code adapted from:
# https://github.com/jameschengpeng/PyTorch-CNN-on-CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# Per-channel mean / std applied by both pipelines (presumably CIFAR-10
# statistics, given the repository this file is adapted from).
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4-pixel padding, random horizontal
# flip, conversion to tensor, then normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])
class ConvNet(nn.Module):
    """Four-convolution, three-linear-layer classifier with 10 outputs.

    The flatten step in ``forward`` assumes 32x32 inputs, which the two 2x2
    max-pools reduce to 8x8 feature maps with 256 channels.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation and state-dict ordering are unchanged.
        self.conv1 = nn.Conv2d(3, 48, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(48, 96, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(96, 192, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(192, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8 * 8 * 256, 512)
        self.fc2 = nn.Linear(512, 64)
        self.Dropout = nn.Dropout(0.25)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4))
        for first, second in stages:
            # conv -> relu -> conv -> relu -> pool (halves H and W) -> dropout
            x = F.relu(second(F.relu(first(x))))
            x = self.Dropout(self.pool(x))
        x = x.view(-1, 8 * 8 * 256)  # flatten to (batch, 16384)
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(self.Dropout(x))

View file

@ -49,5 +49,4 @@ def get_eps_audit(m, r, v, delta, p):
if __name__ == '__main__': if __name__ == '__main__':
x = 100 print(get_eps_audit(1000, 600, 600, 1e-5, 0.05))
print(f"For m=100 r=100 v=100 p=0.05: {get_eps_audit(x, x, x, 1e-5, 0.05)}")

141
one_run_audit/fast_model.py Normal file
View file

@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def label_smoothing_loss(inputs, targets, alpha):
    """Return the per-sample label-smoothed cross-entropy (no reduction).

    Args:
        inputs: Logits; class scores are taken along dimension 1.
        targets: Integer class indices, one per sample.
        alpha: Smoothing weight; ``0`` gives plain cross-entropy.

    Returns:
        Tensor with one loss value per sample:
        ``(1 - alpha) * xent + alpha * uniform``, where ``uniform`` is the
        mean negative log-probability over classes.
    """
    # The private `_stacklevel` kwarg only tunes warning attribution; use the
    # public API instead.
    log_probs = F.log_softmax(inputs, dim=1)
    uniform = -log_probs.mean(dim=1)
    xent = F.nll_loss(log_probs, targets, reduction="none")
    return (1 - alpha) * xent + alpha * uniform
class GhostBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that normalises ``num_splits`` slices of each batch separately.

    Running statistics are kept per split (buffers of length
    ``num_features * num_splits``) and averaged across splits when the module
    switches from training to evaluation. The scale weight is frozen.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.weight.requires_grad = False
        self.num_splits = num_splits
        # Replace the parent's buffers with one slot per (split, feature) pair.
        stat_len = num_features * num_splits
        self.register_buffer("running_mean", torch.zeros(stat_len))
        self.register_buffer("running_var", torch.ones(stat_len))

    def train(self, mode=True):
        leaving_training = (self.training is True) and (mode is False)
        if leaving_training:
            # lazily collate stats when we are going to use them
            for name in ("running_mean", "running_var"):
                per_split = getattr(self, name).view(self.num_splits, self.num_features)
                collated = torch.mean(per_split, dim=0).repeat(self.num_splits)
                setattr(self, name, collated)
        return super().train(mode)

    def forward(self, input):
        n, c, h, w = input.shape
        use_batch_stats = self.training or not self.track_running_stats
        if not use_batch_stats:
            # Evaluation: after collation every split holds the same values,
            # so the first `num_features` entries are sufficient.
            return F.batch_norm(
                input,
                self.running_mean[: self.num_features],
                self.running_var[: self.num_features],
                self.weight,
                self.bias,
                False,
                self.momentum,
                self.eps,
            )

        assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
        # Fold the splits into the channel axis so each one is normalised with
        # its own statistics, then restore the original shape.
        folded = input.view(-1, c * self.num_splits, h, w)
        normalised = F.batch_norm(
            folded,
            self.running_mean,
            self.running_var,
            self.weight.repeat(self.num_splits),
            self.bias.repeat(self.num_splits),
            True,
            self.momentum,
            self.eps,
        )
        return normalised.view(n, c, h, w)
def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
    """Build a bias-free conv -> GhostBatchNorm(16 splits) -> CELU(0.3) stack."""
    layers = [
        nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
        GhostBatchNorm(c_out, num_splits=16),
        nn.CELU(alpha=0.3),
    ]
    return nn.Sequential(*layers)
def conv_pool_norm_act(c_in, c_out):
    """Build a 3x3 bias-free conv -> 2x2 max-pool -> GhostBatchNorm -> CELU stack."""
    layers = [
        nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
        nn.MaxPool2d(kernel_size=2, stride=2),
        GhostBatchNorm(c_out, num_splits=16),
        nn.CELU(alpha=0.3),
    ]
    return nn.Sequential(*layers)
def patch_whitening(data, patch_size=(3, 3)):
    """Derive whitening convolution filters from the patches of ``data``.

    The returned weights are chosen so that
    ``torch.std(F.conv2d(data, weights), dim=(2, 3))`` is close to 1.

    Args:
        data: 4-D tensor; dimension 1 is treated as channels and dimensions
            2 and 3 are slid over with stride 1.
        patch_size: ``(height, width)`` of the patches.

    Returns:
        Tensor of shape ``(c*h*w, c, h, w)``: eigenvectors of the patch
        covariance, ordered by descending eigenvalue and scaled by
        ``1 / sqrt(eigenvalue + 1e-2)``.
    """
    kh, kw = patch_size
    channels = data.size(1)

    # Collect every stride-1 window and stack them as (num_patches, c, kh, kw).
    windows = data.unfold(2, kh, 1).unfold(3, kw, 1)
    windows = windows.transpose(1, 3).reshape(-1, channels, kh, kw).to(torch.float32)

    dim = channels * kh * kw
    flat = windows.reshape(windows.size(0), dim)
    flat = flat / (flat.size(0) - 1) ** 0.5
    cov = flat.t() @ flat

    # eigh returns ascending eigenvalues with eigenvectors as columns; flip
    # both so the largest component comes first.
    evals, evecs = torch.linalg.eigh(cov)
    evals = evals.flip(0)
    filters = evecs.t().reshape(dim, channels, kh, kw).flip(0)
    return filters / torch.sqrt(evals + 1e-2).view(-1, 1, 1, 1)
class ResNetBagOfTricks(nn.Module):
    """Small residual network with a fixed (non-trainable) first convolution.

    Args:
        first_layer_weights: Weights installed into the first 3x3 convolution;
            its first dimension sets that layer's output channel count.
        c_in: Number of input channels.
        c_out: Number of output logits.
        scale_out: Scalar multiplied onto the logits.
    """

    def __init__(self, first_layer_weights, c_in, c_out, scale_out):
        super().__init__()
        width = first_layer_weights.size(0)

        # Frozen first layer carrying the supplied weights.
        stem = nn.Conv2d(c_in, width, kernel_size=(3, 3), padding=(1, 1), bias=False)
        stem.weight.data = first_layer_weights
        stem.weight.requires_grad = False
        self.conv1 = stem

        self.conv2 = conv_bn_relu(width, 64, kernel_size=(1, 1), padding=0)
        self.conv3 = conv_pool_norm_act(64, 128)
        self.conv4 = conv_bn_relu(128, 128)
        self.conv5 = conv_bn_relu(128, 128)
        self.conv6 = conv_pool_norm_act(128, 256)
        self.conv7 = conv_pool_norm_act(256, 512)
        self.conv8 = conv_bn_relu(512, 512)
        self.conv9 = conv_bn_relu(512, 512)
        self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.linear11 = nn.Linear(512, c_out, bias=False)
        self.scale_out = scale_out

    def forward(self, x):
        """Return scaled logits for a batch of images."""
        x = self.conv3(self.conv2(self.conv1(x)))
        # First residual pair.
        residual = self.conv5(self.conv4(x))
        x = x + residual
        x = self.conv7(self.conv6(x))
        # Second residual pair.
        residual = self.conv9(self.conv8(x))
        x = x + residual
        x = self.pool10(x)
        x = x.reshape(x.size(0), x.size(1))  # drop the pooled spatial dims
        return self.scale_out * self.linear11(x)


Model = ResNetBagOfTricks

View file

@ -1,21 +1,94 @@
import time
import math
import concurrent.futures
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from tqdm import tqdm
from equations import get_eps_audit from equations import get_eps_audit
delta = 1e-5 def compute_y(x_values, p, delta, proportion_correct, key):
p_value = 0.05 return key, [get_eps_audit(x, x, math.floor(x *proportion_correct), delta, p) for x in x_values]
x_values = np.floor((1.5)**np.arange(30)).astype(int)
x_values = np.concatenate([x_values[x_values < 60000], [60000]])
y_values = [get_eps_audit(x, x, x, delta, p_value) for x in tqdm(x_values)]
def get_plots():
    """Compute the audit curves for every (p-value, accuracy, delta) setting.

    Sixteen ``compute_y`` jobs are run in a process pool. Their results are
    stored under keys ``"y<row><col>"``: the row selects the
    (p-value, proportion-correct) pair and the column selects delta.

    Returns:
        ``(final_values, x_values)``: the key -> curve mapping and the sample
        counts the curves were evaluated at.
    """
    final_values = {}
    growth = 1.5  # 1.275 #1.5
    upper = 60000  # 2000 #60000
    x_values = np.floor(growth ** np.arange(30)).astype(int)
    x_values = np.concatenate([x_values[x_values < upper], [upper]])

    # Rows of the key grid: (p-value, proportion of correct guesses).
    settings = [(0.05, 1.0), (0.01, 1.0), (0.05, 0.9), (0.01, 0.9)]
    # Columns of the key grid.
    deltas = [0.0, 1e-6, 1e-4, 1e-2]

    with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
        start_time = time.time()
        futures = []
        for row, (p, proportion) in enumerate(settings, start=1):
            for col, delta in enumerate(deltas, start=1):
                futures.append(
                    executor.submit(
                        compute_y, x_values, p, delta, proportion, f"y{row}{col}"
                    )
                )
        for future in concurrent.futures.as_completed(futures):
            key, curve = future.result()
            final_values[key] = curve
        print(f"Took: {time.time()-start_time}s")
    return final_values, x_values
def plot_to(value_set, x_values, title, fig_name):
plt.xscale('log') plt.xscale('log')
plt.plot(x_values, y_values, marker='o') plt.plot(x_values, value_set[0], marker='o', label='δ=0')
plt.xlabel("Number of samples guessed correctly") plt.plot(x_values, value_set[1], marker='o', label='δ=1e-6')
plt.ylabel("ε value audited") plt.plot(x_values, value_set[2], marker='o', label='δ=1e-4')
plt.title("Maximum possible ε from audit") plt.plot(x_values, value_set[3], marker='o', label='δ=1e-2')
# 5. Save the plot as a PNG plt.xlabel("Number of samples attacked")
plt.savefig("/dev/shm/my_plot.png", dpi=300, bbox_inches='tight') plt.ylabel("Maximum ε lower-bound from audit")
plt.title(title)
plt.legend()
plt.savefig(fig_name, dpi=300, bbox_inches='tight')
def main():
    """Compute all audit curves and save one figure per (p-value, accuracy) pair.

    Each figure shows the four delta curves of one key family produced by
    ``get_plots`` (``y1*`` .. ``y4*``).
    """
    final_values, x_values = get_plots()

    # (key prefix, title, output path). Previously every plot read the "y1"
    # family, and the last entry lacked the comma between title and path.
    plot_specs = [
        (
            "y1",
            "Maximum ε audit with p-value=0.05 and 100% MIA accuracy",
            "/dev/shm/plot_05_100.png",
        ),
        (
            "y2",
            "Maximum ε audit with p-value=0.01 and 100% MIA accuracy",
            "/dev/shm/plot_01_100.png",
        ),
        (
            "y3",
            "Maximum ε audit with p-value=0.05 and 90% MIA accuracy",
            "/dev/shm/plot_05_90.png",
        ),
        (
            "y4",
            "Maximum ε audit with p-value=0.01 and 90% MIA accuracy",
            "/dev/shm/plot_01_90.png",
        ),
    ]

    for prefix, title, fig_name in plot_specs:
        # Start from a fresh figure so curves from the previous plot do not
        # accumulate on the implicit pyplot figure.
        plt.figure()
        plot_to(
            [final_values[f"{prefix}{i}"] for i in range(1, 5)],
            x_values,
            title,
            fig_name,
        )
        plt.close()
if __name__ == '__main__':
main()

232
one_run_audit/wrn.py Normal file
View file

@ -0,0 +1,232 @@
"""
Adapted from:
https://github.com/facebookresearch/tan/blob/main/src/models/wideresnet.py
"""
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Adapted from timm:
https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
    """Scale each sample to unit L2 norm along dimension 1."""

    def forward(self, x):
        # F.normalize divides by max(norm, eps), so an all-zero vector maps to
        # zeros instead of NaN; non-degenerate inputs get the same result as
        # `x / x.norm(p=2, dim=1, keepdim=True)`.
        return F.normalize(x, p=2, dim=1)
class BasicBlock(nn.Module):
    """Pre-activation residual block with GroupNorm and a selectable op order.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut.
        nb_groups: GroupNorm group count; a falsy value replaces every norm
            with ``nn.Identity``.
        order: Placement of norm relative to ReLU (must be 0, 1, 2 or 3):
            0 -> norm after ReLU in both the main path and the shortcut;
            1 -> ReLU after norm in both;
            2 -> ReLU after norm in the main path, norm after ReLU in the shortcut;
            3 -> norm after ReLU in the main path, ReLU after norm in the shortcut.
    """

    def __init__(self, in_planes, out_planes, stride, nb_groups, order):
        super().__init__()
        self.order = order

        def make_norm(channels):
            return nn.GroupNorm(nb_groups, channels) if nb_groups else nn.Identity()

        # Creation order is kept so seeded init and state-dict order match.
        self.bn1 = make_norm(in_planes)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1
        )
        self.bn2 = make_norm(out_planes)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=3, stride=1, padding=1
        )

        self.equalInOut = in_planes == out_planes
        # A projection shortcut only exists when the channel count changes.
        if self.equalInOut:
            self.bnShortcut = None
            self.convShortcut = None
        else:
            self.bnShortcut = make_norm(in_planes)
            self.convShortcut = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, padding=0
            )

    def forward(self, x):
        assert self.order in [0, 1, 2, 3]
        norm_after_relu_main = self.order in (0, 3)
        norm_after_relu_skip = self.order in (0, 2)

        skip = x
        if not self.equalInOut:
            if norm_after_relu_skip:
                skip = self.convShortcut(self.bnShortcut(self.relu1(x)))
            else:
                skip = self.convShortcut(self.relu1(self.bnShortcut(x)))

        if norm_after_relu_main:
            out = self.conv1(self.bn1(self.relu1(x)))
            out = self.conv2(self.bn2(self.relu2(out)))
        else:
            out = self.conv1(self.relu1(self.bn1(x)))
            out = self.conv2(self.relu2(self.bn2(out)))
        return torch.add(skip, out)
class NetworkBlock(nn.Module):
    """A sequence of ``nb_layers`` residual blocks forming one network stage.

    Only the first block receives ``in_planes`` and ``stride``; every later
    block maps ``out_planes -> out_planes`` with stride 1.

    Args:
        nb_layers: Number of blocks (converted with ``int``).
        in_planes: Input channels of the stage.
        out_planes: Output channels of the stage.
        block: Callable ``block(in_planes, out_planes, stride, nb_groups, order)``
            returning an ``nn.Module``.
        stride: Stride applied by the first block.
        nb_groups, order: Forwarded unchanged to every block.
    """

    def __init__(
        self, nb_layers, in_planes, out_planes, block, stride, nb_groups, order
    ):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, nb_groups, order
        )

    def _make_layer(
        self, block, in_planes, out_planes, nb_layers, stride, nb_groups, order
    ):
        layers = []
        for i in range(int(nb_layers)):
            first = i == 0
            # Conditional expressions replace `cond and a or b`, which picks
            # the wrong operand whenever `a` is falsy.
            layers.append(
                block(
                    in_planes if first else out_planes,
                    out_planes,
                    stride if first else 1,
                    nb_groups,
                    order,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
class WideResNet(nn.Module):
    """GroupNorm wide residual network that outputs an L2-normalised embedding.

    Three residual stages are followed by a flatten + two 4096-wide
    Linear/GroupNorm/ReLU layers and a final Linear + ``L2Norm`` projection to
    ``feat_dim``.

    Args:
        depth: Network depth; ``(depth - 4)`` must be divisible by 6.
        feat_dim: Size of the output embedding.
        widen_factor: Channel multiplier for the three residual stages.
        nb_groups: GroupNorm group count used inside the residual blocks.
        init: Weight initialisation scheme: ``0`` uses truncated-normal with
            ``std = 1/sqrt(fan_in)``; ``1`` uses Kaiming-normal for convolutions.
        order1: Norm/ReLU ordering inside the residual blocks (0-3).
        order2: Stored on the instance; not read by ``forward``.

    NOTE(review): the head assumes the last stage emits 8x8 feature maps,
    i.e. 32x32 inputs -- confirm against the data pipeline.
    """

    def __init__(
        self,
        depth,
        feat_dim,
        #num_classes,
        widen_factor=1,
        nb_groups=16,
        init=0,
        order1=0,
        order2=0,
    ):
        if order1 == 0:
            print("order1=0: In the blocks: like in DM, BN on top of relu")
        if order1 == 1:
            print("order1=1: In the blocks: not like in DM, relu on top of BN")
        if order1 == 2:
            print(
                "order1=2: In the blocks: BN on top of relu in residual (DM), relu on top of BN ortherplace (clqssique)"
            )
        if order1 == 3:
            print(
                "order1=3: In the blocks: relu on top of BN in residual (classic), BN on top of relu otherplace (DM)"
            )
        if order2 == 0:
            print("order2=0: outside the blocks: like in DM, BN on top of relu")
        if order2 == 1:
            print("order2=1: outside the blocks: not like in DM, relu on top of BN")
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0
        # Blocks per stage; integer division because it is a layer count.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1)
        # 1st block
        self.block1 = NetworkBlock(
            n, nChannels[0], nChannels[1], block, 1, nb_groups, order1
        )
        # 2nd block
        self.block2 = NetworkBlock(
            n, nChannels[1], nChannels[2], block, 2, nb_groups, order1
        )
        # 3rd block
        self.block3 = NetworkBlock(
            n, nChannels[2], nChannels[3], block, 2, nb_groups, order1
        )
        self.nChannels = nChannels[3]
        # Fully connected head. The input width follows the last stage's
        # channel count; the previous hard-coded 256 * 8 * 8 only matched
        # widen_factor=4.
        self.block4 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(nChannels[3] * 8 * 8, 4096, bias=False),
            nn.GroupNorm(16, 4096),
            nn.ReLU(inplace=True),
        )
        # fc7
        self.block5 = nn.Sequential(
            nn.Linear(4096, 4096, bias=False),
            nn.GroupNorm(16, 4096),
            nn.ReLU(inplace=True),
        )
        # fc8
        self.block6 = nn.Sequential(
            nn.Linear(4096, feat_dim),
            L2Norm(),
        )

        if init == 0:  # as in Deep Mind's paper
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
                    s = 1 / (max(fan_in, 1)) ** 0.5
                    nn.init.trunc_normal_(m.weight, std=s)
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.GroupNorm):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
                    s = 1 / (max(fan_in, 1)) ** 0.5
                    nn.init.trunc_normal_(m.weight, std=s)
                    # Linear biases keep their default initialisation here.
        if init == 1:  # old version
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(
                        m.weight, mode="fan_out", nonlinearity="relu"
                    )
                elif isinstance(m, nn.GroupNorm):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    # The two 4096-wide layers are built with bias=False, so
                    # their bias is None and must be skipped.
                    if m.bias is not None:
                        m.bias.data.zero_()
        self.order2 = order2

    def forward(self, x):
        """Return the L2-normalised embedding for a batch of images."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.block5(out)
        out = self.block6(out)
        # Collapse any remaining trailing dimensions by averaging.
        if out.ndim == 4:
            out = out.mean(dim=-1)
        if out.ndim == 3:
            out = out.mean(dim=-1)
        return out

View file

@ -17,12 +17,9 @@ import student_model
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
import opacus import opacus
from distillation_utils import get_teacherstudent_trainset
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
SEED = 42 #setting for testing
def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device): def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
# Dataset # Dataset
@ -119,12 +116,7 @@ def main():
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
#get specific student set train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
best_test_set_accuracy = 0 best_test_set_accuracy = 0
if args.epsilon is not None: if args.epsilon is not None:
@ -135,7 +127,7 @@ def main():
teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon( teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=teacher, module=teacher,
optimizer=optimizer, optimizer=optimizer,
data_loader=teachertrainloader, data_loader=train_loader,
epochs=epochs, epochs=epochs,
target_epsilon=dp_epsilon, target_epsilon=dp_epsilon,
target_delta=dp_delta, target_delta=dp_delta,
@ -152,7 +144,7 @@ def main():
train_knowledge_distillation( train_knowledge_distillation(
teacher=teacher, teacher=teacher,
student=student, student=student,
train_dl=studenttrainloader, train_dl=train_loader,
epochs=args.epochs, epochs=args.epochs,
learning_rate=0.001, learning_rate=0.001,
T=2, T=2,
@ -165,8 +157,8 @@ def main():
torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt") torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
print("Testing student and teacher") print("Testing student and teacher")
test_student = test(student, device, testloader) test_student = test(student, device, test_loader)
test_teacher = test(teacher, device, testloader, True) test_teacher = test(teacher, device, test_loader, True)
print(f"Teacher accuracy: {test_teacher:.2f}%") print(f"Teacher accuracy: {test_teacher:.2f}%")
print(f"Student accuracy: {test_student:.2f}%") print(f"Student accuracy: {test_student:.2f}%")

View file

@ -1,47 +0,0 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from torch.utils.data import Subset
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
    """Split the CIFAR-10 training set in half for a teacher and a student.

    Args:
        train_batch_size: Only echoed in the log line.
        test_batch_size: Not used by this function.
        seed_val: Seed of the generator driving the 50/50 split.
        teacher_datapt_out: When true, one randomly chosen sample is dropped
            from the teacher half. The choice uses the global torch RNG, not
            the seeded split generator.

    Returns:
        ``(teacher_set, student_set, testset)``.
    """
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    def reflect_pad(img):
        # Pad 4 pixels on every side by reflection before the random crop.
        return F.pad(img.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(reflect_pad),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)

    # Halve the data between teacher and student; seeded for reproducibility.
    split_rng = torch.Generator().manual_seed(seed_val)
    teacher_set, student_set = random_split(trainset, [0.5, 0.5], generator=split_rng)

    if teacher_datapt_out:
        teacher_indices = teacher_set.indices
        size = len(teacher_set)
        index_to_remove = torch.randint(0, size, (1,)).item()  # Randomly select one index
        kept = [
            idx for pos, idx in enumerate(teacher_indices) if pos != index_to_remove
        ]
        teacher_set = Subset(trainset, kept)

    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset

View file

@ -7,18 +7,14 @@ import torch.nn as nn
import numpy as np import numpy as np
import random import random
from utils import json_file_to_pyobj, get_loaders from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset
from WideResNet import WideResNet from WideResNet import WideResNet
from tqdm import tqdm from tqdm import tqdm
import opacus import opacus
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager from opacus.utils.batch_memory_manager import BatchMemoryManager
from torch.utils.data import DataLoader
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
SEED = 0
def set_seed(seed=42): def set_seed(seed=42):
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
@ -26,9 +22,7 @@ def set_seed(seed=42):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
#doing for convenience fix later
global SEED
SEED = seed
def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile): def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
best_test_set_accuracy = 0 best_test_set_accuracy = 0
@ -86,6 +80,7 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None): def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
train_loader, test_loader = loaders train_loader, test_loader = loaders
dp_delta = 1e-5 dp_delta = 1e-5
checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm) checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@ -148,11 +143,8 @@ def train(args):
logfile = '' logfile = ''
checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
#get specific teacher set loaders = get_loaders(dataset, training_configurations.batch_size)
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
loaders = trainloader, testloader
if torch.cuda.is_available() and args.cuda: if torch.cuda.is_available() and args.cuda:
device = torch.device(f'cuda:{args.cuda}') device = torch.device(f'cuda:{args.cuda}')
elif torch.cuda.is_available(): elif torch.cuda.is_available():
@ -173,7 +165,7 @@ def train(args):
net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides) net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
net = net.to(device) net = net.to(device)
epochs = 100 epochs = training_configurations.epochs
best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon) best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)
if log: if log: