Use student/teacher subsets that each hold half of the training data

ARVP 2024-12-04 07:25:56 -07:00
parent 369249ce69
commit 3fda6dd727
10 changed files with 350 additions and 22 deletions

lira-pytorch/WideResNet.py

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math


class IndividualBlock1(nn.Module):
    # First residual block of a group: may downsample and/or widen the input.
    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)
        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
        elif increase_filters:
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        if self.subsample_input or self.increase_filters:
            x = self.batch_norm1(x)
            x = self.activation(x)
            x1 = self.conv1(x)
        else:
            x1 = self.batch_norm1(x)
            x1 = self.activation(x1)
            x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)
        # Shortcut: 1x1 projection when the shape changes, identity otherwise.
        if self.subsample_input or self.increase_filters:
            return self.conv_inp(x) + x1
        else:
            return x + x1


class IndividualBlockN(nn.Module):
    # Remaining blocks of a group: shape-preserving pre-activation residual block.
    def __init__(self, input_features, output_features, stride):
        super(IndividualBlockN, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)
        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)

    def forward(self, x):
        x1 = self.batch_norm1(x)
        x1 = self.activation(x1)
        x1 = self.conv1(x1)
        x1 = self.batch_norm2(x1)
        x1 = self.activation(x1)
        x1 = self.conv2(x1)
        return x1 + x


class Nblock(nn.Module):
    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(Nblock, self).__init__()
        layers = []
        for i in range(N):
            if i == 0:
                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
            else:
                layers.append(IndividualBlockN(output_features, output_features, stride=1))
        self.nblockLayer = nn.Sequential(*layers)

    def forward(self, x):
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()
        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6  # blocks per group for a depth-d WRN
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        # He initialization for convs; standard BatchNorm/Linear initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        out = out.view(-1, self.out_filters)
        # Returns logits plus the three intermediate activation maps.
        return self.fc(out), attention1, attention2, attention3


if __name__ == '__main__':
    # change d and k if you want to check a model other than WRN-40-2
    d = 40
    k = 2
    strides = [1, 1, 2, 2]
    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)

    # verify that an output is produced
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    net(sample_input)

    # Summarize model
    summary(net, input_size=(3, 32, 32))

distillation_utils.py

@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F


def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset
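For orientation, this helper is what every script touched below calls. A minimal usage sketch; the DataLoader settings here are illustrative, not taken from this file:

from torch.utils.data import DataLoader
from distillation_utils import get_teacherstudent_trainset

# Same seed -> same split, so teacher and student scripts see the same
# disjoint halves of CIFAR-10 across separate processes.
teacher_set, student_set, testset = get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42)
teacher_dl = DataLoader(teacher_set, batch_size=128, shuffle=True, num_workers=4)
student_dl = DataLoader(student_set, batch_size=128, shuffle=True, num_workers=4)
assert set(teacher_set.indices).isdisjoint(student_set.indices)  # halves never overlap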


@@ -17,13 +17,14 @@ from tqdm import tqdm
 import student_model
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset

 parser = argparse.ArgumentParser()
 parser.add_argument("--n_queries", default=2, type=int)
 parser.add_argument("--model", default="resnet18", type=str)
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
+SEED = 42

 @torch.no_grad()
 def run():
@@ -31,7 +32,12 @@ def run():
     dataset = "cifar10"

     # Dataset
-    train_dl, test_dl = get_loaders(dataset, 4096)
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

     # Infer the logits with multiple queries
     for path in os.listdir(args.savedir):


@@ -23,11 +23,14 @@ from pathlib import Path
 import numpy as np
 from torchvision.datasets import CIFAR10
+from distillation_utils import get_teacherstudent_trainset
+from utils import json_file_to_pyobj, get_loaders
+from torch.utils.data import DataLoader

 parser = argparse.ArgumentParser()
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
+SEED = 42

 def load_one(path):
     """
@@ -56,9 +59,17 @@ def load_one(path):

 def get_labels():
+    # Dataset
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    # Get the indices of the student set
+    student_indices = studentset.indices
     datadir = Path().home() / "opt/data/cifar"
     train_ds = CIFAR10(root=datadir, train=True, download=True)
-    return np.array(train_ds.targets)
+    # Access the original targets for the student set
+    student_targets = [train_ds.targets[i] for i in student_indices]
+    return np.array(student_targets)

 def load_stats():
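The .indices lookup above works because random_split returns torch.utils.data.Subset objects that keep positions into the base dataset. A self-contained toy illustration of the same label-recovery pattern (toy tensors, not the CIFAR pipeline):

import torch
from torch.utils.data import TensorDataset, random_split

base = TensorDataset(torch.arange(10), torch.arange(10) * 2)  # (x, label) pairs
gen = torch.Generator().manual_seed(42)
half_a, half_b = random_split(base, [0.5, 0.5], generator=gen)
# Subset.indices maps each subset position back into the base dataset, so
# labels can be recovered without iterating the (possibly transformed) subset.
labels_b = [base.tensors[1][i].item() for i in half_b.indices]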


@@ -22,9 +22,7 @@ import torch.optim as optim
 import torch.nn.functional as F
 import torchvision
 from torchvision import transforms
+from distillation_utils import get_teacherstudent_trainset
 #privacy libraries
 import opacus
 from opacus.validators import ModuleValidator
@@ -37,6 +35,7 @@ import student_model
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 42  #setting for testing

 parser = argparse.ArgumentParser()
 parser.add_argument("--lr", default=0.1, type=float)
@@ -113,7 +112,11 @@ def run(teacher, student):
     wandb.config.update(args)

     # Dataset
-    train_ds, test_ds = get_trainset()
+    #get specific student set
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
     # Compute the IN / OUT subset:
     # If we run each experiment independently then even after a lot of trials
     # there will still probably be some examples that were always included
@@ -215,7 +218,11 @@ def main():
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
     best_test_set_accuracy = 0
     dp_epsilon = 8
     dp_delta = 1e-5
@@ -224,14 +231,14 @@ def main():
         teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
             module=teacher,
             optimizer=optimizer,
-            data_loader=train_loader,
+            data_loader=trainloader,
             epochs=epochs,
             target_epsilon=dp_epsilon,
             target_delta=dp_delta,
             max_grad_norm=norm,
         )
-        teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
+        teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True))
         teacher.to(device)
         teacher.eval()

 #instantiate student "shadow model"
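A note on the checkpoint load above: make_private_with_epsilon wraps the module in Opacus's GradSampleModule, which prefixes parameter names with "_module.", so a state dict saved from a wrapped model only matches another wrapped model. A runnable toy sketch of that ordering, with a stand-in model, loader, and hypothetical filename rather than this repo's teacher:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Toy stand-ins so the wrapping order is visible end to end.
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))), batch_size=16)

engine = PrivacyEngine()
model, optimizer, loader = engine.make_private_with_epsilon(
    module=model, optimizer=optimizer, data_loader=loader,
    epochs=1, target_epsilon=8.0, target_delta=1e-5, max_grad_norm=12.0,
)
# The wrapped model's parameter names now carry a "_module." prefix, so a
# checkpoint saved here only loads cleanly into another wrapped model.
print(next(iter(model.state_dict())))  # e.g. "_module.weight"
torch.save(model.state_dict(), "dp_model-dict.pt")  # hypothetical path
model.load_state_dict(torch.load("dp_model-dict.pt", weights_only=True))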

lira-pytorch/utils.py

@@ -0,0 +1,62 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
    # Recursively turn JSON objects into namedtuples so fields can be
    # accessed as attributes (e.g. config.training.batch_size).
    def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
    def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
    return json2obj(open(filename).read())


def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
    print(f"Train batch size: {train_batch_size}")
    if dataset == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # Reflection-pad by 4 pixels before the random crop back to 32x32.
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
    elif dataset == 'svhn':
        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
        testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
    return trainloader, testloader

wresnet16-audit-cifar10.json

@@ -0,0 +1,11 @@
{
    "training": {
        "dataset": "CIFAR10",
        "wrn_depth": 16,
        "wrn_width": 1,
        "checkpoint": "True",
        "log": "True",
        "batch_size": 4096,
        "epochs": 200
    }
}
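Taken together with json_file_to_pyobj from utils.py above, this config is consumed as attribute access on nested namedtuples. A short sketch; the printed values mirror the JSON, and the string comparison mirrors the training script below:

from utils import json_file_to_pyobj

json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training = json_options.training
print(training.wrn_depth, training.batch_size, training.epochs)  # 16 4096 200
# JSON strings stay strings, so the "True"/"False" flags are compared textually:
checkpoint = True if training.checkpoint.lower() == 'true' else False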


@@ -17,9 +17,12 @@ import student_model
 import torch.optim as optim
 import torch.nn.functional as F
 import opacus
+from distillation_utils import get_teacherstudent_trainset

 import warnings
 warnings.filterwarnings("ignore")
+SEED = 42  #setting for testing

 def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
     # Dataset
@@ -116,7 +119,12 @@ def main():
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific student set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)

     best_test_set_accuracy = 0

     if args.epsilon is not None:
@@ -127,7 +135,7 @@ def main():
         teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
             module=teacher,
             optimizer=optimizer,
-            data_loader=train_loader,
+            data_loader=teachertrainloader,
             epochs=epochs,
             target_epsilon=dp_epsilon,
             target_delta=dp_delta,
@@ -144,7 +152,7 @@ def main():
     train_knowledge_distillation(
         teacher=teacher,
         student=student,
-        train_dl=train_loader,
+        train_dl=studenttrainloader,
         epochs=args.epochs,
         learning_rate=0.001,
         T=2,
@@ -157,8 +165,8 @@ def main():
     torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")

     print("Testing student and teacher")
-    test_student = test(student, device, test_loader)
-    test_teacher = test(teacher, device, test_loader, True)
+    test_student = test(student, device, testloader)
+    test_teacher = test(teacher, device, testloader, True)
     print(f"Teacher accuracy: {test_teacher:.2f}%")
     print(f"Student accuracy: {test_student:.2f}%")

distillation_utils.py

@@ -0,0 +1,36 @@
import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F


def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
    # splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
    seed = torch.Generator().manual_seed(seed_val)
    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
    teacher_set = subsets[0]
    student_set = subsets[1]
    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
    return teacher_set, student_set, testset


@@ -7,14 +7,18 @@ import torch.nn as nn
 import numpy as np
 import random
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
 from WideResNet import WideResNet
 from tqdm import tqdm
 import opacus
 from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
+from torch.utils.data import DataLoader

 import warnings
 warnings.filterwarnings("ignore")
+SEED = 0

 def set_seed(seed=42):
     torch.backends.cudnn.deterministic = True
@@ -22,7 +26,9 @@ def set_seed(seed=42):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    #doing for convenience fix later
+    global SEED
+    SEED = seed

 def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
     best_test_set_accuracy = 0
@@ -80,7 +86,6 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
 def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
     train_loader, test_loader = loaders

     dp_delta = 1e-5
     checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@@ -143,8 +148,11 @@ def train(args):
         logfile = ''

     checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
-    loaders = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader

     if torch.cuda.is_available() and args.cuda:
         device = torch.device(f'cuda:{args.cuda}')
     elif torch.cuda.is_available():
@@ -165,7 +173,7 @@ def train(args):
         net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
         net = net.to(device)

-        epochs = training_configurations.epochs
+        epochs = 100
         best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)

         if log: