Compare commits

...

4 commits

12 changed files with 483 additions and 33 deletions

lira-pytorch/WideResNet.py (new file, 143 lines)
View file

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math


class IndividualBlock1(nn.Module):

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
        elif increase_filters:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        if self.subsample_input or self.increase_filters:
            x = self.batch_norm1(x)
            x = self.activation(x)
            x1 = self.conv1(x)
        else:
            x1 = self.batch_norm1(x)
            x1 = self.activation(x1)
            x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        if self.subsample_input or self.increase_filters:
            return self.conv_inp(x) + x1
        else:
            return x + x1


class IndividualBlockN(nn.Module):

    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)

    def forward(self, x):
        x1 = self.batch_norm1(x)
        x1 = self.activation(x1)
        x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        return x1 + x


class Nblock(nn.Module):

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()

        layers = []
        for i in range(N):
            if i == 0:
                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
            else:
                layers.append(IndividualBlockN(output_features, output_features, stride=1))

        self.nblockLayer = nn.Sequential(*layers)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        out = out.view(-1, self.out_filters)

        return self.fc(out), attention1, attention2, attention3


if __name__ == '__main__':
    # change d and k if you want to check a model other than WRN-40-2
    d = 40
    k = 2
    strides = [1, 1, 2, 2]
    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)

    # verify that an output is produced
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    net(sample_input)

    # Summarize model
    summary(net, input_size=(3, 32, 32))

View file

@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F


def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset

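Note (not part of the diff): a minimal usage sketch for the helper above. Wrapping the returned subsets in DataLoaders mirrors how the other scripts in this compare use them, and a fixed seed_val reproduces the same 25,000/25,000 teacher/student split of the CIFAR-10 training set on every call.

from torch.utils.data import DataLoader
from distillation_utils import get_teacherstudent_trainset

teacher_set, student_set, testset = get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42)
teacher_dl = DataLoader(teacher_set, batch_size=128, shuffle=True, num_workers=4)
student_dl = DataLoader(student_set, batch_size=128, shuffle=True, num_workers=4)
test_dl = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)
assert len(teacher_set) == len(student_set) == 25000  # CIFAR-10 train set split in half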
View file

@@ -17,13 +17,14 @@ from tqdm import tqdm
import student_model
from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset
parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()
SEED = 42
@torch.no_grad()
def run():
@@ -31,7 +32,12 @@ def run():
    dataset = "cifar10"
    # Dataset
    train_dl, test_dl = get_loaders(dataset, 4096)
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
    train_ds, test_ds = studentset, testset
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
    # Infer the logits with multiple queries
    for path in os.listdir(args.savedir):

View file

@@ -23,11 +23,14 @@ from pathlib import Path
import numpy as np
from torchvision.datasets import CIFAR10
from distillation_utils import get_teacherstudent_trainset
from utils import json_file_to_pyobj, get_loaders
from torch.utils.data import DataLoader
parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()
SEED = 42
def load_one(path):
"""
@ -56,9 +59,17 @@ def load_one(path):
def get_labels():
# Dataset
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
# Get the indices of the student set
student_indices = studentset.indices
datadir = Path().home() / "opt/data/cifar"
train_ds = CIFAR10(root=datadir, train=True, download=True)
return np.array(train_ds.targets)
# Access the original targets for the student set
student_targets = [train_ds.targets[i] for i in student_indices]
return np.array(student_targets)
def load_stats():

View file

@@ -22,9 +22,7 @@ import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from distillation_utils import get_teacherstudent_trainset
#privacy libraries
import opacus
from opacus.validators import ModuleValidator
@@ -37,6 +35,7 @@ import student_model
import warnings
warnings.filterwarnings("ignore")
SEED = 42 #setting for testing
parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float)
@@ -113,7 +112,11 @@ def run(teacher, student):
    wandb.config.update(args)
    # Dataset
    train_ds, test_ds = get_trainset()
    # get specific student set
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
    train_ds, test_ds = studentset, testset
    # Compute the IN / OUT subset:
    # If we run each experiment independently then even after a lot of trials
    # there will still probably be some examples that were always included
@@ -215,7 +218,11 @@ def main():
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
    # get specific teacher set
    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    loaders = trainloader, testloader
    best_test_set_accuracy = 0
    dp_epsilon = 8
    dp_delta = 1e-5
@@ -224,14 +231,14 @@
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
        data_loader=train_loader,
        data_loader=trainloader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )
    teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
    teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True))
    teacher.to(device)
    teacher.eval()
    # instantiate student "shadow model"

lira-pytorch/utils.py (new file, 62 lines)
View file

@@ -0,0 +1,62 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
    def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
    def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
    return json2obj(open(filename).read())


def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
    print(f"Train batch size: {train_batch_size}")
    if dataset == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
    elif dataset == 'svhn':
        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
        testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return trainloader, testloader

View file

@@ -0,0 +1,11 @@
{
    "training": {
        "dataset": "CIFAR10",
        "wrn_depth": 16,
        "wrn_width": 1,
        "checkpoint": "True",
        "log": "True",
        "batch_size": 4096,
        "epochs": 200
    }
}

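Note (not part of the diff): this 11-line JSON appears to be the training config that the scripts in this compare load through json_file_to_pyobj from utils.py above; its filename is not captured in this view, so the path below is assumed to be the wresnet16-audit-cifar10.json referenced elsewhere in the diff.

from utils import json_file_to_pyobj

json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")  # assumed filename
training_configurations = json_options.training
print(training_configurations.batch_size)   # 4096
print(training_configurations.wrn_depth, training_configurations.wrn_width)  # 16 1
print(training_configurations.epochs)       # 200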
View file

@@ -18,12 +18,61 @@ from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
from equations import get_eps_audit
import student_model
import warnings
warnings.filterwarnings("ignore")
DEVICE = None
STUDENTBOOL = False
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    # instantiate student
    student = student_model.Model(num_classes=10).to(device)
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    student_init = copy.deepcopy(student)
    student.to(device)
    teacher.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)
            # Forward pass with the student model
            student_logits = student(inputs)
            # Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)
            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if epoch % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
    return student_init, student
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    seed = np.random.randint(0, 1e9)
@@ -72,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
    pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
    return train_dl, test_dl, x_in, x_m, y_m, S_m
    return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
def evaluate_on(model, dataloader):
@@ -90,6 +140,9 @@ def evaluate_on(model, dataloader):
            labels = labels.to(DEVICE)
            wrn_outputs = model(images)
            if STUDENTBOOL:
                outputs = wrn_outputs
            else:
                outputs = wrn_outputs[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
@@ -209,6 +262,8 @@ def main():
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    parser.add_argument('--m', type=int, help='number of target points', required=True)
    parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher")
    args = parser.parse_args()
    if torch.cuda.is_available() and args.cuda:
@@ -227,8 +282,8 @@ def main():
        "norm": args.norm,
        "batch_size": 4096,
        "epochs": 100,
        "k+": 300,
        "k-": 300,
        "k+": 200,
        "k-": 200,
        "p_value": 0.05,
    }
@@ -243,18 +298,42 @@ def main():
        hp['norm'],
    ))
    train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
    train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
    print(f"len train: {len(train_dl)}")
    print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
    print(f"Got x_in: {x_in.shape}")
    print(f"Got x_m: {x_m.shape}")
    print(f"Got y_m: {y_m.shape}")
    model_init, model_trained = train(hp, train_dl, test_dl)
    # torch.save(model_init.state_dict(), "data/init_model.pt")
    # torch.save(model_trained.state_dict(), "data/trained_model.pt")
    if args.auditmodel == "student":
        global STUDENTBOOL
        teacher_init, teacher_trained = train(hp, train_dl, test_dl)
        STUDENTBOOL = True
        # torch.save(model_init.state_dict(), "data/init_model.pt")
        # torch.save(model_trained.state_dict(), "data/trained_model.pt")
        # train student model
        print("Training Student Model")
        model_init, model_trained = train_knowledge_distillation(
            teacher=teacher_trained,
            train_dl=pure_train_dl,
            epochs=100,
            device=DEVICE,
            learning_rate=0.001,
            T=2,
            soft_target_loss_weight=0.25,
            ce_loss_weight=0.75,
        )
        stcorrect, sttotal = evaluate_on(model_trained, test_dl)
        stacc = stcorrect/sttotal*100
        print(f"Student Accuracy: {stacc}%")
    else:
        model_init, model_trained = train(hp, train_dl, test_dl)
    scores = list()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
@@ -266,7 +345,10 @@ def main():
            x_point = x_m[i].unsqueeze(0)
            y_point = y_m[i].unsqueeze(0)
            is_in = S_m[i]
            if STUDENTBOOL:
                init_loss = criterion(model_init(x_point), y_point)
                trained_loss = criterion(model_trained(x_point), y_point)
            else:
                init_loss = criterion(model_init(x_point)[0], y_point)
                trained_loss = criterion(model_trained(x_point)[0], y_point)

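Note (not part of the diff): a minimal standalone sketch of the distillation loss used in train_knowledge_distillation above, i.e. the temperature-softened soft-target term scaled by T**2 plus the ordinary cross-entropy on the true labels; the tensor shapes below are made up for illustration.

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=2.0, soft_w=0.25, ce_w=0.75):
    # soft-target term: divergence between temperature-softened teacher and student distributions, scaled by T**2
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_prob = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / student_logits.size(0) * (T ** 2)
    # hard-label term: standard cross-entropy against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, labels)
    return soft_w * soft_loss + ce_w * hard_loss

student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
print(kd_loss(student_logits, teacher_logits, labels))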
View file

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn


# Student model: a small CNN; forward() returns a single tensor, and no pooling is applied after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x


Model = ModifiedLightNNCosine

View file

@@ -17,9 +17,12 @@ import student_model
import torch.optim as optim
import torch.nn.functional as F
import opacus
from distillation_utils import get_teacherstudent_trainset
import warnings
warnings.filterwarnings("ignore")
SEED = 42 #setting for testing
def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    # Dataset
@@ -116,7 +119,12 @@ def main():
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
    # get specific student set
    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
    teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    best_test_set_accuracy = 0
    if args.epsilon is not None:
@@ -127,7 +135,7 @@ def main():
        teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=teacher,
            optimizer=optimizer,
            data_loader=train_loader,
            data_loader=teachertrainloader,
            epochs=epochs,
            target_epsilon=dp_epsilon,
            target_delta=dp_delta,
@@ -144,7 +152,7 @@ def main():
    train_knowledge_distillation(
        teacher=teacher,
        student=student,
        train_dl=train_loader,
        train_dl=studenttrainloader,
        epochs=args.epochs,
        learning_rate=0.001,
        T=2,
@@ -157,8 +165,8 @@ def main():
    torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
    print("Testing student and teacher")
    test_student = test(student, device, test_loader)
    test_teacher = test(teacher, device, test_loader, True)
    test_student = test(student, device, testloader)
    test_teacher = test(teacher, device, testloader, True)
    print(f"Teacher accuracy: {test_teacher:.2f}%")
    print(f"Student accuracy: {test_student:.2f}%")

View file

@@ -0,0 +1,47 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from torch.utils.data import Subset


def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    # if removing datapoint from teacher:
    if teacher_datapt_out:
        teacher_indices = teacher_set.indices
        size = len(teacher_set)
        index_to_remove = torch.randint(0, size, (1,)).item()  # Randomly select one index
        keep_bool = torch.ones(size, dtype=torch.bool)
        keep_bool[index_to_remove] = False
        keep_indices = torch.tensor(teacher_indices)[keep_bool]
        teacher_set = Subset(trainset, keep_indices.tolist())
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset

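Note (not part of the diff): a small sketch of the teacher_datapt_out behaviour added in this 47-line version of get_teacherstudent_trainset (the training diff below passes the flag as its fourth argument). With the flag set, exactly one randomly chosen example is dropped from the teacher half while the student half and the test set are unchanged.

from distillation_utils import get_teacherstudent_trainset  # assumes the version defining teacher_datapt_out

teacher_full, student_full, _ = get_teacherstudent_trainset(seed_val=42, teacher_datapt_out=False)
teacher_loo, student_loo, _ = get_teacherstudent_trainset(seed_val=42, teacher_datapt_out=True)
assert len(teacher_loo) == len(teacher_full) - 1  # one training point removed from the teacher set
assert len(student_loo) == len(student_full)      # student split is untouched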
View file

@@ -7,14 +7,18 @@ import torch.nn as nn
import numpy as np
import random
from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset
from WideResNet import WideResNet
from tqdm import tqdm
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
SEED = 0
def set_seed(seed=42):
    torch.backends.cudnn.deterministic = True
@@ -22,7 +26,9 @@ def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # doing this for convenience, fix later
    global SEED
    SEED = seed
def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
    best_test_set_accuracy = 0
@@ -80,7 +86,6 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
    train_loader, test_loader = loaders
    dp_delta = 1e-5
    checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@@ -143,8 +148,11 @@ def train(args):
    logfile = ''
    checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
    loaders = get_loaders(dataset, training_configurations.batch_size)
    # get specific teacher set
    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
    loaders = trainloader, testloader
    if torch.cuda.is_available() and args.cuda:
        device = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
@@ -165,7 +173,7 @@ def train(args):
    net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
    net = net.to(device)
    epochs = training_configurations.epochs
    epochs = 100
    best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)
    if log: