diff --git a/lira-pytorch/WideResNet.py b/lira-pytorch/WideResNet.py
new file mode 100644
index 0000000..024818e
--- /dev/null
+++ b/lira-pytorch/WideResNet.py
@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+from torchsummary import summary
+import math
+
+
+class IndividualBlock1(nn.Module):
+    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
+        super(IndividualBlock1, self).__init__()
+
+        self.activation = nn.ReLU(inplace=True)
+
+        self.batch_norm1 = nn.BatchNorm2d(input_features)
+        self.batch_norm2 = nn.BatchNorm2d(output_features)
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.subsample_input = subsample_input
+        self.increase_filters = increase_filters
+        if subsample_input:
+            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
+        elif increase_filters:
+            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)
+
+    def forward(self, x):
+
+        if self.subsample_input or self.increase_filters:
+            x = self.batch_norm1(x)
+            x = self.activation(x)
+            x1 = self.conv1(x)
+        else:
+            x1 = self.batch_norm1(x)
+            x1 = self.activation(x1)
+            x1 = self.conv1(x1)
+        x1 = self.batch_norm2(x1)
+        x1 = self.activation(x1)
+        x1 = self.conv2(x1)
+
+        if self.subsample_input or self.increase_filters:
+            return self.conv_inp(x) + x1
+        else:
+            return x + x1
+
+
+class IndividualBlockN(nn.Module):
+    def __init__(self, input_features, output_features, stride):
+        super(IndividualBlockN, self).__init__()
+
+        self.activation = nn.ReLU(inplace=True)
+
+        self.batch_norm1 = nn.BatchNorm2d(input_features)
+        self.batch_norm2 = nn.BatchNorm2d(output_features)
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
+
+    def forward(self, x):
+        x1 = self.batch_norm1(x)
+        x1 = self.activation(x1)
+        x1 = self.conv1(x1)
+        x1 = self.batch_norm2(x1)
+        x1 = self.activation(x1)
+        x1 = self.conv2(x1)
+
+        return x1 + x
+
+
+class Nblock(nn.Module):
+
+    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
+        super(Nblock, self).__init__()
+
+        layers = []
+        for i in range(N):
+            if i == 0:
+                layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
+            else:
+                layers.append(IndividualBlockN(output_features, output_features, stride=1))
+
+        self.nblockLayer = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.nblockLayer(x)
+
+
+class WideResNet(nn.Module):
+
+    def __init__(self, d, k, n_classes, input_features, output_features, strides):
+        super(WideResNet, self).__init__()
+
+        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
+
+        filters = [16 * k, 32 * k, 64 * k]
+        self.out_filters = filters[-1]
+        N = (d - 4) // 6
+        increase_filters = k > 1
+        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
+        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
+        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
+
+        self.batch_norm = nn.BatchNorm2d(filters[-1])
+        self.activation = nn.ReLU(inplace=True)
+        self.avg_pool = nn.AvgPool2d(kernel_size=8)
+        self.fc = nn.Linear(filters[-1], n_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        attention1 = self.block1(x)
+        attention2 = self.block2(attention1)
+        attention3 = self.block3(attention2)
+        out = self.batch_norm(attention3)
+        out = self.activation(out)
+        out = self.avg_pool(out)
+        out = out.view(-1, self.out_filters)
+
+        return self.fc(out), attention1, attention2, attention3
+
+
+if __name__ == '__main__':
+
+    # change d and k if you want to check a model other than WRN-40-2
+    d = 40
+    k = 2
+    strides = [1, 1, 2, 2]
+    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
+
+    # verify that an output is produced
+    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
+    net(sample_input)
+
+    # Summarize model
+    summary(net, input_size=(3, 32, 32))
diff --git a/lira-pytorch/distillation_utils.py b/lira-pytorch/distillation_utils.py
new file mode 100644
index 0000000..302a34a
--- /dev/null
+++ b/lira-pytorch/distillation_utils.py
@@ -0,0 +1,36 @@
+import torch
+from torch.utils.data import random_split
+import torchvision
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+import torch.nn.functional as F
+
+def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
+    print(f"Train batch size: {train_batch_size}")
+    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+    train_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
+                                          (4, 4, 4, 4), mode='reflect').squeeze()),
+        transforms.ToPILImage(),
+        transforms.RandomCrop(32),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
+
+    #splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
+    seed = torch.Generator().manual_seed(seed_val)
+    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
+    teacher_set = subsets[0]
+    student_set = subsets[1]
+    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
+
+    return teacher_set, student_set, testset
diff --git a/lira-pytorch/inference.py b/lira-pytorch/inference.py
index 0afb0e0..9f4b34d 100644
--- a/lira-pytorch/inference.py
+++ b/lira-pytorch/inference.py
@@ -17,13 +17,14 @@ from tqdm import tqdm
 
 import student_model
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--n_queries", default=2, type=int)
 parser.add_argument("--model", default="resnet18", type=str)
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
-
+SEED = 42
 
 @torch.no_grad()
 def run():
@@ -31,7 +32,12 @@ def run():
     dataset = "cifar10"
 
     # Dataset
-    train_dl, test_dl = get_loaders(dataset, 4096)
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
 
     # Infer the logits with multiple queries
     for path in os.listdir(args.savedir):
diff --git a/lira-pytorch/score.py b/lira-pytorch/score.py
index 68933fa..48b872b 100644
--- a/lira-pytorch/score.py
+++ b/lira-pytorch/score.py
@@ -23,11 +23,14 @@ from pathlib import Path
 
 import numpy as np
 from torchvision.datasets import CIFAR10
+from distillation_utils import get_teacherstudent_trainset
+from utils import json_file_to_pyobj, get_loaders
+from torch.utils.data import DataLoader
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--savedir", default="exp/cifar10", type=str)
 args = parser.parse_args()
-
+SEED = 42
 
 def load_one(path):
     """
@@ -56,9 +59,17 @@
 
 
 def get_labels():
+    # Dataset
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    # Get the indices of the student set
+    student_indices = studentset.indices
     datadir = Path().home() / "opt/data/cifar"
     train_ds = CIFAR10(root=datadir, train=True, download=True)
-    return np.array(train_ds.targets)
+    # Access the original targets for the student set
+    student_targets = [train_ds.targets[i] for i in student_indices]
+    return np.array(student_targets)
 
 
 def load_stats():
diff --git a/lira-pytorch/student_shadow_train.py b/lira-pytorch/student_shadow_train.py
index 5495bc7..c7fd360 100644
--- a/lira-pytorch/student_shadow_train.py
+++ b/lira-pytorch/student_shadow_train.py
@@ -22,9 +22,7 @@ import torch.optim as optim
 import torch.nn.functional as F
 import torchvision
 from torchvision import transforms
-
-
-
+from distillation_utils import get_teacherstudent_trainset
 #privacy libraries
 import opacus
 from opacus.validators import ModuleValidator
@@ -37,6 +35,7 @@ import student_model
 
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 42
 #setting for testing
 parser = argparse.ArgumentParser()
 parser.add_argument("--lr", default=0.1, type=float)
@@ -113,7 +112,11 @@ def run(teacher, student):
     wandb.config.update(args)
 
     # Dataset
-    train_ds, test_ds = get_trainset()
+    #get specific student set
+    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
+    training_configurations = json_options.training
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    train_ds, test_ds = studentset, testset
     # Compute the IN / OUT subset:
     # If we run each experiment independently then even after a lot of trials
     # there will still probably be some examples that were always included
@@ -215,7 +218,11 @@ def main():
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
-    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
     best_test_set_accuracy = 0
     dp_epsilon = 8
     dp_delta = 1e-5
@@ -224,14 +231,14 @@
     teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
         module=teacher,
         optimizer=optimizer,
-        data_loader=train_loader,
+        data_loader=trainloader,
         epochs=epochs,
         target_epsilon=dp_epsilon,
         target_delta=dp_delta,
         max_grad_norm=norm,
     )
 
-    teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
+    teacher.load_state_dict(torch.load(os.path.join("teachers_in/wrn-1733273741-8.0e-1e-05d-12.0n-dict.pt"), weights_only=True))
     teacher.to(device)
     teacher.eval()
     #instantiate student "shadow model"
diff --git a/lira-pytorch/utils.py b/lira-pytorch/utils.py
new file mode 100644
index 0000000..bc01f92
--- /dev/null
+++ b/lira-pytorch/utils.py
@@ -0,0 +1,62 @@
+import json
+import collections
+import torchvision
+from torchvision import transforms
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+
+
+# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
+def json_file_to_pyobj(filename):
+    def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
+
+    def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
+
+    return json2obj(open(filename).read())
+
+
+def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
+    print(f"Train batch size: {train_batch_size}")
+
+    if dataset == 'cifar10':
+        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+
+        train_transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
+                                              (4, 4, 4, 4), mode='reflect').squeeze()),
+            transforms.ToPILImage(),
+            transforms.RandomCrop(32),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ])
+
+        test_transform = transforms.Compose([
+            transforms.ToTensor(),
+            normalize
+        ])
+
+        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
+        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
+
+        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
+        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
+
+    elif dataset == 'svhn':
+        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
+
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            normalize,
+        ])
+
+        trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
+        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
+
+        testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
+        testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
+
+    return trainloader, testloader
+
+
diff --git a/lira-pytorch/wresnet16-audit-cifar10.json b/lira-pytorch/wresnet16-audit-cifar10.json
new file mode 100644
index 0000000..f78b650
--- /dev/null
+++ b/lira-pytorch/wresnet16-audit-cifar10.json
@@ -0,0 +1,11 @@
+{
+    "training":{
"dataset": "CIFAR10", + "wrn_depth": 16, + "wrn_width": 1, + "checkpoint": "True", + "log": "True", + "batch_size": 4096, + "epochs": 200 + } +} diff --git a/wresnet-pytorch/src/distillation_train.py b/wresnet-pytorch/src/distillation_train.py index 0ea1daf..ec4ec26 100644 --- a/wresnet-pytorch/src/distillation_train.py +++ b/wresnet-pytorch/src/distillation_train.py @@ -17,9 +17,12 @@ import student_model import torch.optim as optim import torch.nn.functional as F import opacus +from distillation_utils import get_teacherstudent_trainset + + import warnings warnings.filterwarnings("ignore") - +SEED = 42 #setting for testing def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device): # Dataset @@ -116,7 +119,12 @@ def main(): criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) - train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size) + #get specific student set + teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED) + teachertrainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4) + studenttrainloader = DataLoader(studentset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4) + testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4) + best_test_set_accuracy = 0 if args.epsilon is not None: @@ -127,7 +135,7 @@ def main(): teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon( module=teacher, optimizer=optimizer, - data_loader=train_loader, + data_loader=teachertrainloader, epochs=epochs, target_epsilon=dp_epsilon, target_delta=dp_delta, @@ -144,7 +152,7 @@ def main(): train_knowledge_distillation( teacher=teacher, student=student, - train_dl=train_loader, + train_dl=studenttrainloader, epochs=args.epochs, learning_rate=0.001, T=2, @@ -157,8 +165,8 @@ def main(): torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt") print("Testing student and teacher") - test_student = test(student, device, test_loader) - test_teacher = test(teacher, device, test_loader, True) + test_student = test(student, device, testloader) + test_teacher = test(teacher, device, testloader, True) print(f"Teacher accuracy: {test_teacher:.2f}%") print(f"Student accuracy: {test_student:.2f}%") diff --git a/wresnet-pytorch/src/distillation_utils.py b/wresnet-pytorch/src/distillation_utils.py new file mode 100644 index 0000000..302a34a --- /dev/null +++ b/wresnet-pytorch/src/distillation_utils.py @@ -0,0 +1,36 @@ +import torch +from torch.utils.data import random_split +import torchvision +from torchvision import transforms +from torchvision.datasets import CIFAR10 +import torch.nn.functional as F + +def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42): + print(f"Train batch size: {train_batch_size}") + normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), + (4, 4, 4, 4), mode='reflect').squeeze()), + transforms.ToPILImage(), + transforms.RandomCrop(32), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, 
+    ])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    trainset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=False, transform=train_transform)
+
+    #splitting data in half for teacher dataset and student dataset (use manual seed for consistency)
+    seed = torch.Generator().manual_seed(seed_val)
+    subsets = random_split(trainset, [0.5, 0.5], generator=seed)
+    teacher_set = subsets[0]
+    student_set = subsets[1]
+    testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
+
+    return teacher_set, student_set, testset
diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py
index cff770f..4a60f53 100644
--- a/wresnet-pytorch/src/train.py
+++ b/wresnet-pytorch/src/train.py
@@ -7,14 +7,18 @@ import torch.nn as nn
 import numpy as np
 import random
 from utils import json_file_to_pyobj, get_loaders
+from distillation_utils import get_teacherstudent_trainset
 from WideResNet import WideResNet
 from tqdm import tqdm
 import opacus
 from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
+from torch.utils.data import DataLoader
+
 
 import warnings
 warnings.filterwarnings("ignore")
+SEED = 0
 
 def set_seed(seed=42):
     torch.backends.cudnn.deterministic = True
@@ -22,7 +26,9 @@ def set_seed(seed=42):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    #keep the module-level SEED in sync for convenience; fix later
+    global SEED
+    SEED = seed
 
 def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
     best_test_set_accuracy = 0
@@ -80,7 +86,6 @@ def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, schedul
 
 
 def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0, dp_epsilon=None):
     train_loader, test_loader = loaders
-
     dp_delta = 1e-5
     checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)
@@ -143,8 +148,11 @@ def train(args):
     logfile = ''
 
     checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
-    loaders = get_loaders(dataset, training_configurations.batch_size)
-
+    #get specific teacher set
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
+    loaders = trainloader, testloader
     if torch.cuda.is_available() and args.cuda:
         device = torch.device(f'cuda:{args.cuda}')
     elif torch.cuda.is_available():
@@ -165,7 +173,7 @@
         net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
         net = net.to(device)
 
-        epochs = training_configurations.epochs
+        epochs = 100
         best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm, args.epsilon)
 
         if log:
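Note (not part of the patch): score.py only lines up with the trained student shadow models if get_teacherstudent_trainset() reproduces the same teacher/student split in every script, since score.py maps studentset.indices back into the full CIFAR-10 target list. Below is a minimal sanity check for that assumption; it presumes the new distillation_utils.py is importable and CIFAR-10 is already downloaded under ~/data, and the file name sanity_check_split.py is hypothetical, not something this diff adds.

    # sanity_check_split.py -- hypothetical helper, not introduced by this patch.
    # Checks that the CIFAR-10 teacher/student split is deterministic for a fixed seed,
    # which score.py relies on when it maps studentset.indices back to train_ds.targets.
    from distillation_utils import get_teacherstudent_trainset

    SEED = 42  # same constant hard-coded in inference.py, score.py and the training scripts

    teacher_a, student_a, _ = get_teacherstudent_trainset(128, 10, SEED)
    teacher_b, student_b, _ = get_teacherstudent_trainset(128, 10, SEED)

    # random_split returns Subset objects; .indices are positions in the full 50k train set
    assert student_a.indices == student_b.indices, "student split is not reproducible"
    assert teacher_a.indices == teacher_b.indices, "teacher split is not reproducible"
    assert set(teacher_a.indices).isdisjoint(student_a.indices), "teacher/student halves overlap"
    print(f"teacher: {len(teacher_a)} examples, student: {len(student_a)} examples")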