Use separate teacher/student subsets, each containing half of the training data
parent 369249ce69
commit 3fda6dd727
10 changed files with 350 additions and 22 deletions

lira-pytorch/WideResNet.py  (new file, 143 lines added)
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math


class IndividualBlock1(nn.Module):
    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
        elif increase_filters:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):

        if self.subsample_input or self.increase_filters:
            x = self.batch_norm1(x)
            x = self.activation(x)
            x1 = self.conv1(x)
        else:
            x1 = self.batch_norm1(x)
            x1 = self.activation(x1)
            x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        if self.subsample_input or self.increase_filters:
            return self.conv_inp(x) + x1
        else:
            return x + x1


class IndividualBlockN(nn.Module):
    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)

    def forward(self, x):
        x1 = self.batch_norm1(x)
        x1 = self.activation(x1)
        x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)

        return x1 + x


class Nblock(nn.Module):

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()

        layers = []
        for i in range(N):
            if i == 0:
                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
            else:
                layers.append(IndividualBlockN(output_features, output_features, stride=1))

        self.nblockLayer = nn.Sequential(*layers)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        out = out.view(-1, self.out_filters)

        return self.fc(out), attention1, attention2, attention3


if __name__ == '__main__':

    # change d and k if you want to check a model other than WRN-40-2
    d = 40
    k = 2
    strides = [1, 1, 2, 2]
    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)

    # verify that an output is produced
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    net(sample_input)

    # Summarize model
    summary(net, input_size=(3, 32, 32))
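
Note (not part of the diff): the audit config added below, wresnet16-audit-cifar10.json, selects a WRN-16-1, so a minimal usage sketch of this class with those settings would be:

# Illustrative sketch only, assuming the WRN-16-1 settings from wresnet16-audit-cifar10.json
import torch
from WideResNet import WideResNet

net = WideResNet(d=16, k=1, n_classes=10, input_features=3, output_features=16, strides=[1, 1, 2, 2])
logits, att1, att2, att3 = net(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10]); att1-att3 are the intermediate block activations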

lira-pytorch/distillation_utils.py  (new file, 36 lines added)
@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F

def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)

    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)

    return teacher_set, student_set, testset
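
A quick property check (illustrative only, not part of the commit): because the split uses a manually seeded torch.Generator, every script that calls get_teacherstudent_trainset with the same seed_val sees the same disjoint halves. Assuming CIFAR-10 is already available under ~/data:

# Illustrative check: same seed -> identical split; teacher and student halves are disjoint
t1, s1, _ = get_teacherstudent_trainset(seed_val=42)
t2, s2, _ = get_teacherstudent_trainset(seed_val=42)
assert t1.indices == t2.indices and s1.indices == s2.indices
assert set(t1.indices).isdisjoint(s1.indices)
assert len(t1) + len(s1) == 50000  # the full CIFAR-10 training set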

@@ -17,13 +17,14 @@ from tqdm import tqdm

import student_model
from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()

+SEED = 42

@torch.no_grad()
def run():
@@ -31,7 +32,12 @@ def run():
    dataset = "cifar10"

    # Dataset
-    train_dl, test_dl = get_loaders(dataset, 4096)
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple queries
    for path in os.listdir(args.savedir):

@@ -23,11 +23,14 @@ from pathlib import Path

import numpy as np
from torchvision.datasets import CIFAR10
+from distillation_utils import get_teacherstudent_trainset
+from utils import json_file_to_pyobj, get_loaders
+from torch.utils.data import DataLoader

parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()

+SEED = 42

def load_one(path):
    """
@@ -56,9 +59,17 @@ def load_one(path):


def get_labels():
+    # Dataset
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    # Get the indices of the student set
+    student_indices = studentset.indices
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True)
-    return np.array(train_ds.targets)
+    # Access the original targets for the student set
+    student_targets = [train_ds.targets[i] for i in student_indices]
+    return np.array(student_targets)


def load_stats():
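
Why the label lookup above goes through studentset.indices (illustrative note, not part of the diff): random_split returns Subset objects, so position i of the student subset corresponds to original index studentset.indices[i] in the full CIFAR-10 training set, and the targets have to be gathered accordingly:

# Equivalent formulation of the lookup above, using numpy fancy indexing
full_targets = np.array(train_ds.targets)             # ordered as in the full train set
student_targets = full_targets[studentset.indices]    # ordered as the student subset is iterated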

@@ -22,9 +22,7 @@ import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
-
-
-
+from distillation_utils import get_teacherstudent_trainset
#privacy libraries
import opacus
from opacus.validators import ModuleValidator
@@ -37,6 +35,7 @@ import student_model
import warnings
warnings.filterwarnings("ignore")

+SEED = 42 #setting for testing

parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float)
@@ -113,7 +112,11 @@ def run(teacher, student):
    wandb.config.update(args)

    # Dataset
-    train_ds, test_ds = get_trainset()
+    #get specific student set
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
    # Compute the IN / OUT subset:
    # If we run each experiment independently then even after a lot of trials
    # there will still probably be some examples that were always included
@@ -215,7 +218,11 @@ def main():
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
    best_test_set_accuracy = 0
    dp_epsilon = 8
    dp_delta = 1e-5
@@ -224,14 +231,14 @@
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
-        data_loader=train_loader,
+        data_loader=trainloader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )

-    teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
+    teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True))
    teacher.to(device)
    teacher.eval()
    #instantiate student "shadow model"

lira-pytorch/utils.py  (new file, 62 lines added)
@@ -0,0 +1,62 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
    def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())

    def json2obj(data): return json.loads(data, object_hook=_json_object_hook)

    return json2obj(open(filename).read())


def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
    print(f"Train batch size: {train_batch_size}")

    if dataset == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)

        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)

    elif dataset == 'svhn':
        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))

        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)

        testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return trainloader, testloader

lira-pytorch/wresnet16-audit-cifar10.json  (new file, 11 lines added)
@@ -0,0 +1,11 @@
{
    "training":{
        "dataset": "CIFAR10",
        "wrn_depth": 16,
        "wrn_width": 1,
        "checkpoint": "True",
        "log": "True",
        "batch_size": 4096,
        "epochs": 200
    }
}
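
For reference (not part of the diff), utils.json_file_to_pyobj exposes this file as nested namedtuples, which is how the hunks above and below read it; a small sketch, assuming the JSON sits in the working directory:

# Illustrative sketch: reading the config shown above
from utils import json_file_to_pyobj

json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training
print(training_configurations.wrn_depth, training_configurations.wrn_width)  # 16 1
print(training_configurations.batch_size)                                    # 4096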

@@ -17,9 +17,12 @@ import student_model
import torch.optim as optim
import torch.nn.functional as F
import opacus
+from distillation_utils import get_teacherstudent_trainset


import warnings
warnings.filterwarnings("ignore")

+SEED = 42 #setting for testing

def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    # Dataset
@@ -116,7 +119,12 @@ def main():
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific student set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)

    best_test_set_accuracy = 0

    if args.epsilon is not None:
@@ -127,7 +135,7 @@ def main():
        teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=teacher,
            optimizer=optimizer,
-            data_loader=train_loader,
+            data_loader=teachertrainloader,
            epochs=epochs,
            target_epsilon=dp_epsilon,
            target_delta=dp_delta,
@@ -144,7 +152,7 @@ def main():
    train_knowledge_distillation(
        teacher=teacher,
        student=student,
-        train_dl=train_loader,
+        train_dl=studenttrainloader,
        epochs=args.epochs,
        learning_rate=0.001,
        T=2,
@@ -157,8 +165,8 @@ def main():
    torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")

    print("Testing student and teacher")
-    test_student = test(student, device, test_loader)
-    test_teacher = test(teacher, device, test_loader, True)
+    test_student = test(student, device, testloader)
+    test_teacher = test(teacher, device, testloader, True)
    print(f"Teacher accuracy: {test_teacher:.2f}%")
    print(f"Student accuracy: {test_student:.2f}%")

wresnet-pytorch/src/distillation_utils.py  (new file, 36 lines added)
@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F

def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)

    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)

    return teacher_set, student_set, testset

@@ -7,14 +7,18 @@ import torch.nn as nn
import numpy as np
import random
from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
from WideResNet import WideResNet
from tqdm import tqdm
import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
+from torch.utils.data import DataLoader

import warnings
warnings.filterwarnings("ignore")

+SEED = 0

def set_seed(seed=42):
    torch.backends.cudnn.deterministic = True
@@ -22,7 +26,9 @@ def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

+    #doing for convenience fix later
+    global SEED
+    SEED = seed

def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
    best_test_set_accuracy = 0
@@ -80,7 +86,6 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):

def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
    train_loader, test_loader = loaders

    dp_delta = 1e-5
    checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@@ -143,8 +148,11 @@ def train(args):
    logfile = ''

    checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
-    loaders = get_loaders(dataset, training_configurations.batch_size)

+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
    if torch.cuda.is_available() and args.cuda:
        device = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
@@ -165,7 +173,7 @@ def train(args):
    net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
    net = net.to(device)

-    epochs = training_configurations.epochs
+    epochs = 100
    best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)

    if log: