Compare commits

...

4 commits

12 changed files with 483 additions and 33 deletions

143
lira-pytorch/WideResNet.py Normal file
View file

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math
class IndividualBlock1(nn.Module):
def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(IndividualBlock1, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
self.subsample_input = subsample_input
self.increase_filters = increase_filters
if subsample_input:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
elif increase_filters:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
if self.subsample_input or self.increase_filters:
x = self.batch_norm1(x)
x = self.activation(x)
x1 = self.conv1(x)
else:
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
if self.subsample_input or self.increase_filters:
return self.conv_inp(x) + x1
else:
return x + x1
class IndividualBlockN(nn.Module):
def __init__(self, input_features, output_features, stride):
super(IndividualBlockN, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
def forward(self, x):
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
return x1 + x
class Nblock(nn.Module):
def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(Nblock, self).__init__()
layers = []
for i in range(N):
if i == 0:
layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
else:
layers.append(IndividualBlockN(output_features, output_features, stride=1))
self.nblockLayer = nn.Sequential(*layers)
def forward(self, x):
return self.nblockLayer(x)
class WideResNet(nn.Module):
def __init__(self, d, k, n_classes, input_features, output_features, strides):
super(WideResNet, self).__init__()
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
filters = [16 * k, 32 * k, 64 * k]
self.out_filters = filters[-1]
N = (d - 4) // 6
increase_filters = k > 1
self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
self.batch_norm = nn.BatchNorm2d(filters[-1])
self.activation = nn.ReLU(inplace=True)
self.avg_pool = nn.AvgPool2d(kernel_size=8)
self.fc = nn.Linear(filters[-1], n_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
x = self.conv1(x)
attention1 = self.block1(x)
attention2 = self.block2(attention1)
attention3 = self.block3(attention2)
out = self.batch_norm(attention3)
out = self.activation(out)
out = self.avg_pool(out)
out = out.view(-1, self.out_filters)
return self.fc(out), attention1, attention2, attention3
if __name__ == '__main__':
# change d and k if you want to check a model other than WRN-40-2
d = 40
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
# verify that an output is produced
sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
net(sample_input)
# Summarize model
summary(net, input_size=(3, 32, 32))
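Aside (not part of this diff): the compare also adds a wresnet16-audit-cifar10.json config with wrn_depth 16 and wrn_width 1; a minimal sketch of the corresponding WRN-16-1 instantiation with this constructor, assuming the same strides as the __main__ block above:
import torch
from WideResNet import WideResNet
# WRN-16-1: N = (16 - 4) // 6 = 2 blocks per group, filters [16, 32, 64]
wrn16 = WideResNet(d=16, k=1, n_classes=10, input_features=3, output_features=16, strides=[1, 1, 2, 2])
logits, att1, att2, att3 = wrn16(torch.ones(1, 3, 32, 32))
print(logits.shape)  # torch.Size([1, 10])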

View file

@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
print(f"Train batch size: {train_batch_size}")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4), mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([
transforms.ToTensor(),
normalize
])
trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
#splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
seed = torch.Generator().manual_seed(seed_val)
subsets = random_split(trainset, [0.5, 0.5], generator=seed)
teacher_set = subsets[0]
student_set = subsets[1]
testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
return teacher_set, student_set, testset
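Aside (not part of this diff): a minimal sketch of how the training scripts below wrap these splits in loaders; the batch sizes and worker counts here are illustrative:
from torch.utils.data import DataLoader
from distillation_utils import get_teacherstudent_trainset
# Teacher and student each get half of the CIFAR-10 train split; the seed keeps the split reproducible.
teacher_set, student_set, testset = get_teacherstudent_trainset(train_batch_size=128, seed_val=42)
teacher_dl = DataLoader(teacher_set, batch_size=128, shuffle=True, num_workers=4)
student_dl = DataLoader(student_set, batch_size=128, shuffle=True, num_workers=4)
test_dl = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)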

View file

@@ -17,13 +17,14 @@ from tqdm import tqdm
 import student_model
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
 parser = argparse.ArgumentParser()
 parser.add_argument("--n_queries", default=2, type=int)
 parser.add_argument("--model", default="resnet18", type=str)
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
+SEED = 42
 @torch.no_grad()
 def run():
@@ -31,7 +32,12 @@ def run():
     dataset = "cifar10"
     # Dataset
-    train_dl, test_dl = get_loaders(dataset, 4096)
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
     # Infer the logits with multiple queries
     for path in os.listdir(args.savedir):

View file

@@ -23,11 +23,14 @@ from pathlib import Path
 import numpy as np
 from torchvision.datasets import CIFAR10
+from distillation_utils import get_teacherstudent_trainset
+from utils import json_file_to_pyobj, get_loaders
+from torch.utils.data import DataLoader
 parser = argparse.ArgumentParser()
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
+SEED = 42
 def load_one(path):
     """
@@ -56,9 +59,17 @@ def load_one(path):
 def get_labels():
+    # Dataset
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    # Get the indices of the student set
+    student_indices = studentset.indices
     datadir = Path().home() / "opt/data/cifar"
     train_ds = CIFAR10(root=datadir, train=True, download=True)
-    return np.array(train_ds.targets)
+    # Access the original targets for the student set
+    student_targets = [train_ds.targets[i] for i in student_indices]
+    return np.array(student_targets)
 def load_stats():

View file

@@ -22,9 +22,7 @@ import torch.optim as optim
 import torch.nn.functional as F
 import torchvision
 from torchvision import transforms
+from distillation_utils import get_teacherstudent_trainset
 #privacy libraries
 import opacus
 from opacus.validators import ModuleValidator
@@ -37,6 +35,7 @@ import student_model
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 42 #setting for testing
 parser = argparse.ArgumentParser()
 parser.add_argument("--lr", default=0.1, type=float)
@@ -113,7 +112,11 @@ def run(teacher, student):
     wandb.config.update(args)
     # Dataset
-    train_ds, test_ds = get_trainset()
+    #get specific student set
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
     # Compute the IN / OUT subset:
     # If we run each experiment independently then even after a lot of trials
     # there will still probably be some examples that were always included
@@ -215,7 +218,11 @@ def main():
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
     best_test_set_accuracy = 0
     dp_epsilon = 8
     dp_delta = 1e-5
@@ -224,14 +231,14 @@ def main():
         teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
             module=teacher,
             optimizer=optimizer,
-            data_loader=train_loader,
+            data_loader=trainloader,
             epochs=epochs,
             target_epsilon=dp_epsilon,
             target_delta=dp_delta,
             max_grad_norm=norm,
         )
-    teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
+    teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True))
     teacher.to(device)
     teacher.eval()
     #instantiate student "shadow model"

62
lira-pytorch/utils.py Normal file
View file

@@ -0,0 +1,62 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
return json2obj(open(filename).read())
def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
print(f"Train batch size: {train_batch_size}")
if dataset == 'cifar10':
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4), mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([
transforms.ToTensor(),
normalize
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
elif dataset == 'svhn':
normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
transform = transforms.Compose([
transforms.ToTensor(),
normalize,
])
trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
return trainloader, testloader

View file

@@ -0,0 +1,11 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 16,
"wrn_width": 1,
"checkpoint": "True",
"log": "True",
"batch_size": 4096,
"epochs": 200
}
}
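Aside (not part of this diff): the scripts read this file with json_file_to_pyobj from utils.py above, which converts JSON objects into nested namedtuples so fields are accessed as attributes; a minimal sketch:
from utils import json_file_to_pyobj
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training
print(training_configurations.wrn_depth, training_configurations.wrn_width)  # 16 1
print(training_configurations.batch_size)                                    # 4096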

View file

@@ -18,12 +18,61 @@ from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
 from WideResNet import WideResNet
 from equations import get_eps_audit
+import student_model
 import warnings
 warnings.filterwarnings("ignore")
 DEVICE = None
+STUDENTBOOL = False
+def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
+    #instantiate istudent
+    student = student_model.Model(num_classes=10).to(device)
+    ce_loss = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
+    student_init = copy.deepcopy(student)
+    student.to(device)
+    teacher.to(device)
+    teacher.eval() # Teacher set to evaluation mode
+    student.train() # Student to train mode
+    for epoch in range(epochs):
+        running_loss = 0.0
+        for inputs, labels in train_dl:
+            inputs, labels = inputs.to(device), labels.to(device)
+            optimizer.zero_grad()
+            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
+            with torch.no_grad():
+                teacher_logits, _, _, _ = teacher(inputs)
+            # Forward pass with the student model
+            student_logits = student(inputs)
+            #Soften the student logits by applying softmax first and log() second
+            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
+            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
+            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
+            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
+            # Calculate the true label loss
+            label_loss = ce_loss(student_logits, labels)
+            # Weighted sum of the two losses
+            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+        if epoch % 10 == 0:
+            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
+    return student_init, student
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     seed = np.random.randint(0, 1e9)
@@ -72,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
     train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
-    return train_dl, test_dl, x_in, x_m, y_m, S_m
+    return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
 def evaluate_on(model, dataloader):
@@ -90,7 +140,10 @@ def evaluate_on(model, dataloader):
         labels = labels.to(DEVICE)
         wrn_outputs = model(images)
-        outputs = wrn_outputs[0]
+        if STUDENTBOOL:
+            outputs = wrn_outputs
+        else:
+            outputs = wrn_outputs[0]
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
@@ -209,6 +262,8 @@ def main():
     parser.add_argument('--cuda', type=int, help='gpu index', required=False)
     parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
     parser.add_argument('--m', type=int, help='number of target points', required=True)
+    parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher")
     args = parser.parse_args()
     if torch.cuda.is_available() and args.cuda:
@@ -227,8 +282,8 @@ def main():
         "norm": args.norm,
         "batch_size": 4096,
         "epochs": 100,
-        "k+": 300,
-        "k-": 300,
+        "k+": 200,
+        "k-": 200,
         "p_value": 0.05,
     }
@@ -243,18 +298,42 @@ def main():
         hp['norm'],
     ))
-    train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
+    train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
     print(f"len train: {len(train_dl)}")
     print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
     print(f"Got x_in: {x_in.shape}")
     print(f"Got x_m: {x_m.shape}")
     print(f"Got y_m: {y_m.shape}")
-    model_init, model_trained = train(hp, train_dl, test_dl)
     # torch.save(model_init.state_dict(), "data/init_model.pt")
     # torch.save(model_trained.state_dict(), "data/trained_model.pt")
+    if args.auditmodel == "student":
+        global STUDENTBOOL
+        teacher_init, teacher_trained = train(hp, train_dl, test_dl)
+        STUDENTBOOL = True
+        # torch.save(model_init.state_dict(), "data/init_model.pt")
+        # torch.save(model_trained.state_dict(), "data/trained_model.pt")
+        #train student model
+        print("Training Student Model")
+        model_init, model_trained = train_knowledge_distillation(
+            teacher=teacher_trained,
+            train_dl=pure_train_dl,
+            epochs=100,
+            device=DEVICE,
+            learning_rate=0.001,
+            T=2,
+            soft_target_loss_weight=0.25,
+            ce_loss_weight=0.75,
+        )
+        stcorrect, sttotal = evaluate_on(model_trained, test_dl)
+        stacc = stcorrect/sttotal*100
+        print(f"Student Accuracy: {stacc}%")
+    else:
+        model_init, model_trained = train(hp, train_dl, test_dl)
     scores = list()
     criterion = nn.CrossEntropyLoss()
     with torch.no_grad():
@@ -266,9 +345,12 @@ def main():
             x_point = x_m[i].unsqueeze(0)
             y_point = y_m[i].unsqueeze(0)
             is_in = S_m[i]
-            init_loss = criterion(model_init(x_point)[0], y_point)
-            trained_loss = criterion(model_trained(x_point)[0], y_point)
+            if STUDENTBOOL:
+                init_loss = criterion(model_init(x_point), y_point)
+                trained_loss = criterion(model_trained(x_point), y_point)
+            else:
+                init_loss = criterion(model_init(x_point)[0], y_point)
+                trained_loss = criterion(model_trained(x_point)[0], y_point)
             scores.append(((init_loss - trained_loss).item(), is_in))
@@ -290,7 +372,7 @@ def main():
     print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
     print(f"p[ε < {eps_lb}] < {hp['p_value']}")
     correct, total = evaluate_on(model_init, train_dl)
     print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
     correct, total = evaluate_on(model_trained, test_dl)

View file

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedLightNNCosine, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
return x
Model = ModifiedLightNNCosine
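Aside (not part of this diff): the in_features=1024 of the first Linear layer assumes 32x32 CIFAR inputs, since two 2x2 max-pools leave a 16-channel 8x8 feature map (16 * 8 * 8 = 1024); a quick shape-check sketch:
import torch
import student_model
model = student_model.Model(num_classes=10)
x = torch.ones(1, 3, 32, 32)
print(model.features(x).shape)  # torch.Size([1, 16, 8, 8])
print(model(x).shape)           # torch.Size([1, 10])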

View file

@@ -17,9 +17,12 @@ import student_model
 import torch.optim as optim
 import torch.nn.functional as F
 import opacus
+from distillation_utils import get_teacherstudent_trainset
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 42 #setting for testing
 def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
     # Dataset
@@ -116,7 +119,12 @@ def main():
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific student set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
     best_test_set_accuracy = 0
     if args.epsilon is not None:
@@ -127,7 +135,7 @@ def main():
         teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
             module=teacher,
             optimizer=optimizer,
-            data_loader=train_loader,
+            data_loader=teachertrainloader,
             epochs=epochs,
             target_epsilon=dp_epsilon,
             target_delta=dp_delta,
@@ -144,7 +152,7 @@ def main():
     train_knowledge_distillation(
         teacher=teacher,
         student=student,
-        train_dl=train_loader,
+        train_dl=studenttrainloader,
         epochs=args.epochs,
         learning_rate=0.001,
         T=2,
@@ -157,8 +165,8 @@ def main():
     torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
     print("Testing student and teacher")
-    test_student = test(student, device, test_loader)
-    test_teacher = test(teacher, device, test_loader, True)
+    test_student = test(student, device, testloader)
+    test_teacher = test(teacher, device, testloader, True)
     print(f"Teacher accuracy: {test_teacher:.2f}%")
     print(f"Student accuracy: {test_student:.2f}%")

View file

@@ -0,0 +1,47 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from torch.utils.data import Subset
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
print(f"Train batch size: {train_batch_size}")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4), mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([
transforms.ToTensor(),
normalize
])
trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
#splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
seed = torch.Generator().manual_seed(seed_val)
subsets = random_split(trainset, [0.5, 0.5], generator=seed)
teacher_set = subsets[0]
#if removing datapoint from teacher:
if teacher_datapt_out:
teacher_indices = teacher_set.indices
size = len(teacher_set)
index_to_remove = torch.randint(0, size, (1,)).item() # Randomly select one index
keep_bool = torch.ones(size, dtype=torch.bool)
keep_bool[index_to_remove] = False
keep_indices = torch.tensor(teacher_indices)[keep_bool]
teacher_set = Subset(trainset, keep_indices.tolist())
student_set = subsets[1]
testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
return teacher_set, student_set, testset
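Aside (not part of this diff): with teacher_datapt_out=True, as used when training the audited teacher below, one randomly chosen example is dropped from the teacher's half of the split; a minimal sketch of the effect:
from distillation_utils import get_teacherstudent_trainset
teacher_set, student_set, testset = get_teacherstudent_trainset(128, 10, seed_val=42, teacher_datapt_out=True)
print(len(student_set))  # 25000 (half of the CIFAR-10 train split)
print(len(teacher_set))  # 24999 (one point removed from the teacher half)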

View file

@@ -7,14 +7,18 @@ import torch.nn as nn
 import numpy as np
 import random
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
 from WideResNet import WideResNet
 from tqdm import tqdm
 import opacus
 from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
+from torch.utils.data import DataLoader
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 0
 def set_seed(seed=42):
     torch.backends.cudnn.deterministic = True
@@ -22,7 +26,9 @@ def set_seed(seed=42):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    #doing for convenience fix later
+    global SEED
+    SEED = seed
 def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
     best_test_set_accuracy = 0
@@ -80,7 +86,6 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
 def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
     train_loader, test_loader = loaders
     dp_delta = 1e-5
     checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@@ -143,8 +148,11 @@ def train(args):
     logfile = ''
     checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
-    loaders = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
     if torch.cuda.is_available() and args.cuda:
         device = torch.device(f'cuda:{args.cuda}')
     elif torch.cuda.is_available():
@@ -165,7 +173,7 @@ def train(args):
     net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
     net = net.to(device)
-    epochs = training_configurations.epochs
+    epochs = 100
     best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)
     if log: