# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/train.py
#
# author: Chenxiang Zhang (orientino)
#
# This variant loads a WideResNet "teacher" wrapped with Opacus and distils it
# into a student "shadow" model trained on a LiRA-style IN/OUT subset of
# CIFAR-10. The per-shadow keep mask and the student weights are saved to
# <savedir>/<shadow_id>/.

# random stuff
import os
import argparse
import time
from pathlib import Path

# torch stuff
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch.optim.lr_scheduler import MultiStepLR
import torch.optim as optim
import torch.nn.functional as F
import torchvision

# privacy libraries
import opacus
from opacus.validators import ModuleValidator

# custom modules
from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet
import student_model

# suppress warnings
import warnings

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--n_shadows", default=16, type=int)
parser.add_argument("--shadow_id", default=1, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--pkeep", default=0.5, type=float)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()


def _select_device():
    """Return the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Previously this fell back to "mps" unconditionally, which breaks on
    # machines that have neither CUDA nor MPS.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _select_device()


def _reflect_pad4(x):
    """Reflect-pad an image tensor (as produced by ToTensor) by 4 pixels per side.

    Defined at module level rather than as a lambda so the transform, and hence
    the dataset, can be pickled into DataLoader workers started with "spawn".
    """
    return F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode="reflect").squeeze(0)


def get_trainset(train_batch_size=128, test_batch_size=10):
    """Build the CIFAR-10 train and test datasets from ./data.

    Args:
        train_batch_size: only printed for logging; no loader is built here.
        test_batch_size: unused; kept for backward compatibility.

    Returns:
        (trainset, testset) torchvision CIFAR10 datasets. download=False, so
        the data must already exist under ./data.
    """
    print(f"Train batch size: {train_batch_size}")
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # Reflect-pad then random-crop back to 32x32 (standard CIFAR augmentation).
        transforms.Lambda(_reflect_pad4),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=test_transform)
    return trainset, testset


@torch.no_grad()
def test(model, test_dl, teacher=False):
    """Evaluate top-1 accuracy of `model` on `test_dl`.

    Args:
        model: network to evaluate; it is moved to DEVICE and put in eval mode.
        test_dl: iterable of (inputs, labels) batches.
        teacher: if True, the model returns a 4-tuple whose first element holds
            the logits (as the teacher's forward does in this script).

    Returns:
        Accuracy in percent (0.0 if the loader yields no samples).
    """
    device = DEVICE
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in test_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        if teacher:
            outputs, _, _, _ = model(inputs)
        else:
            outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


def run(teacher, student):
    """Distil `teacher` into `student` on this shadow's IN subset, then save it.

    Writes <args.savedir>/<args.shadow_id>/keep.npy (boolean IN mask over the
    training set) and model.pt (student state dict).

    Returns:
        The student's final test accuracy in percent.
    """
    device = DEVICE
    teacher.to(device)
    student.to(device)

    seed = np.random.randint(0, 1000000000)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    # NOTE(review): this unconditionally overrides the --debug CLI flag, so wandb
    # always runs in "disabled" mode. Kept to preserve current behaviour.
    args.debug = True
    wandb.init(project="lira", mode="disabled" if args.debug else "online")
    wandb.config.update(args)

    # Dataset
    train_ds, test_ds = get_trainset()

    # Compute the IN / OUT subset:
    # If we run each experiment independently then even after a lot of trials
    # there will still probably be some examples that were always included
    # or always excluded. So instead, with experiment IDs, we guarantee that
    # after `args.n_shadows` are done, each example is seen exactly half
    # of the time in train, and half of the time not in train.
    size = len(train_ds)
    np.random.seed(seed)
    if args.n_shadows is not None:
        # Fixed seed so every shadow_id sees the same assignment matrix.
        np.random.seed(0)
        keep = np.random.uniform(0, 1, size=(args.n_shadows, size))
        order = keep.argsort(0)
        keep = order < int(args.pkeep * args.n_shadows)
        keep = np.array(keep[args.shadow_id], dtype=bool)
        keep = keep.nonzero()[0]
    else:
        keep = np.random.choice(size, size=int(args.pkeep * size), replace=False)
        keep.sort()
    keep_bool = np.full((size), False)
    keep_bool[keep] = True

    train_ds = torch.utils.data.Subset(train_ds, keep)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    # Train (knowledge distillation)
    learning_rate = 0.001
    T = 2
    soft_target_loss_weight = 0.25
    ce_loss_weight = 0.75
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    for epoch in range(args.epochs):
        # Re-enter train mode each epoch so any evaluation cannot leave the
        # student stuck in eval mode.
        student.train()
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher forward without gradients: its weights are not updated.
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)

            student_logits = student(inputs)

            # Soft-target loss: KL(teacher || student) on temperature-softened
            # distributions, computed entirely in log space. The previous
            # softmax(...).log() produced NaN when a teacher probability
            # underflowed to 0. Scaled by T**2 as suggested in
            # "Distilling the knowledge in a neural network".
            teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)
            soft_targets_loss = F.kl_div(
                student_log_probs,
                teacher_log_probs,
                reduction="batchmean",
                log_target=True,
            ) * (T ** 2)

            # Hard-label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{args.epochs}, Loss: {running_loss / len(train_dl)}")

    accuracy = test(student, test_dl)

    # saving models
    print("saving model")
    savedir = os.path.join(args.savedir, str(args.shadow_id))
    os.makedirs(savedir, exist_ok=True)
    np.save(os.path.join(savedir, "keep.npy"), keep_bool)
    torch.save(student.state_dict(), os.path.join(savedir, "model.pt"))
    return accuracy


def main():
    """Load the Opacus-wrapped teacher checkpoint, build a student, and run distillation."""
    epochs = args.epochs
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()

    # Use the same device as run()/test(). Previously main() chose cuda:0/cpu
    # while the training loop used cuda/mps, causing device mismatches.
    device = DEVICE

    print("Load the teacher model")
    # instantiate teacher model
    strides = [1, 1, 2, 2]
    teacher = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
    teacher = ModuleValidator.fix(teacher)
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    train_loader, _ = get_loaders(dataset, training_configurations.batch_size)

    dp_epsilon = 8
    dp_delta = 1e-5
    norm = 1.0
    # NOTE(review): the teacher is wrapped with Opacus before loading, presumably
    # so the module structure / state-dict keys match how the checkpoint was
    # saved — confirm. The returned optimizer and loader are not used afterwards.
    privacy_engine = opacus.PrivacyEngine()
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )
    # map_location lets a checkpoint saved on one device load on another.
    state_dict = torch.load(
        "wrn-1733078278-8e-1e-05d-12.0n-dict.pt",
        map_location=device,
        weights_only=True,
    )
    teacher.load_state_dict(state_dict)
    teacher.to(device)
    teacher.eval()

    # instantiate student "shadow model"
    student = student_model.Model(num_classes=10).to(device)

    # Check norm of layer for both networks -- student should be smaller?
    print("Norm of 1st layer for teacher:", torch.norm(teacher.conv1.weight).item())
    print("Norm of 1st layer for student:", torch.norm(student.features[0].weight).item())

    # train student shadow model
    run(teacher=teacher, student=student)


if __name__ == "__main__":
    main()