"""One-training-run privacy audit of a WideResNet (or a distilled student) on CIFAR-10.

The script randomly includes/excludes ``m`` candidate training points, trains a
model (optionally with DP-SGD through Opacus), scores every candidate by how
much its loss dropped between the initial and the trained model, guesses
membership for the lowest/highest scoring candidates and turns the number of
correct guesses into an epsilon lower bound via ``get_eps_audit``.
"""
import argparse
import copy
import time
import warnings
from pathlib import Path

import numpy as np
import opacus
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

import equations
import student_model
from equations import get_eps_audit
from WideResNet import WideResNet

warnings.filterwarnings("ignore")

# Device every model/tensor is moved to; assigned in main().
DEVICE = None
# Set to True by main() once the audited model is the student. Kept for
# backward compatibility; the helpers below inspect the model output instead
# of relying on this flag.
STUDENTBOOL = False

# Per-channel normalisation statistics passed to transforms.Normalize.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _extract_logits(output):
    """Return the logits from a model output.

    The WideResNet used here returns a tuple whose first element is the logits
    (see the 4-way unpacking in the original distillation loop), whereas the
    student model returns the logits tensor directly.
    """
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2,
                                 soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    """Train a fresh student model by distilling ``teacher``.

    Args:
        teacher: trained model whose softened predictions are the soft targets.
        train_dl: iterable of ``(inputs, labels)`` batches.
        epochs: number of passes over ``train_dl``.
        device: device the models and batches are moved to.
        learning_rate: Adam learning rate for the student.
        T: softmax temperature applied to both teacher and student logits.
        soft_target_loss_weight: weight of the distillation (KL) term.
        ce_loss_weight: weight of the cross-entropy term on the true labels.

    Returns:
        ``(student_init, student)``: a deep copy of the student taken before
        any optimisation step, and the trained student.
    """
    # Instantiate the student.
    student = student_model.Model(num_classes=10).to(device)
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    student_init = copy.deepcopy(student)

    teacher.to(device)
    teacher.eval()   # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here
            # as we do not change the teacher's weights.
            with torch.no_grad():
                teacher_logits = _extract_logits(teacher(inputs))

            # Forward pass with the student model.
            student_logits = student(inputs)

            # Soften both distributions with the temperature; the student side
            # is kept in log-space.
            soft_targets = F.softmax(teacher_logits / T, dim=-1)
            soft_prob = F.log_softmax(student_logits / T, dim=-1)

            # Soft-target loss: KL(teacher || student) summed over classes and
            # averaged over the batch, scaled by T**2 as suggested by the
            # authors of "Distilling the knowledge in a neural network".
            # F.kl_div computes target * (log(target) - input) and treats
            # target == 0 as contributing 0, which avoids the NaN that
            # ``0 * log(0)`` produces in the hand-written form.
            soft_targets_loss = F.kl_div(soft_prob, soft_targets, reduction="batchmean") * (T ** 2)

            # True-label loss.
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses.
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if epoch % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

    return student_init, student


def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
    """Build the CIFAR-10 loaders and the membership ground truth for the audit.

    ``m`` randomly chosen training points are the audit candidates; each one is
    kept in or dropped from the training set by an independent coin flip.

    Args:
        m: number of candidate points.
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.

    Returns:
        ``(train_dl, test_dl, x_in, x_m, y_m, S_m)`` where ``x_in`` are the
        (transformed) images actually trained on, ``x_m``/``y_m`` are the
        candidate images/labels and ``S_m`` is the boolean inclusion vector of
        the candidates.

    Raises:
        ValueError: if ``m`` is negative or larger than the training set.
    """
    seed = np.random.randint(0, 10**9)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(
            lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()
        ),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    datadir = Path("./data")
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    n_train = len(train_ds)
    if not 0 <= m <= n_train:
        raise ValueError(f"m must be between 0 and {n_train}, got {m}")

    # Materialise the training set once. Indexing the dataset applies the
    # (random) transforms, so the augmentation is frozen from here on.
    # np.stack needs a sequence, not a generator.
    x = np.stack([train_ds[i][0].numpy() for i in range(n_train)])
    targets = np.array(train_ds.targets).astype(np.int64)

    # Random permutation used to scatter the m candidates over the dataset.
    p = np.random.permutation(n_train)

    # Inclusion vector: everything is in, except the first m slots which are
    # decided by a fair coin.
    S = np.full(n_train, True)
    S[:m] = np.random.choice([True, False], size=m)
    included = S[p]

    # Marks the m points which could have been included/excluded.
    mask = np.full(n_train, False)
    mask[:m] = True
    mask = mask[p]

    x_m = x[mask]            # These are the points being guessed at
    y_m = targets[mask]
    S_m = included[mask]     # Ground truth of inclusion/exclusion for x_m

    # Remove excluded points from the dataset.
    x_in = x[included]
    y_in = targets[included]

    td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return train_dl, test_dl, x_in, x_m, y_m, S_m


def evaluate_on(model, dataloader):
    """Count correct top-1 predictions of ``model`` over ``dataloader``.

    Puts ``model`` in eval mode (and leaves it there).

    Returns:
        ``(correct, total)``: number of correct predictions and of samples seen.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = _extract_logits(model(images))
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
    """Run the plain training loop for ``hp['epochs']`` epochs.

    The model is evaluated on ``test_dl`` every 10th epoch and after the last
    one; ``scheduler.step()`` is called once per epoch.

    Returns:
        The best test accuracy (in percent, rounded to 2 decimals) among the
        epochs that were evaluated, or 0 if none was.
    """
    best_test_set_accuracy = 0
    last_epoch = hp['epochs'] - 1

    for epoch in range(hp['epochs']):
        model.train()  # evaluate_on() switches to eval mode, so reset each epoch
        for inputs, labels in train_dl:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = _extract_logits(model(inputs))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 10 == 0 or epoch == last_epoch:
            correct, total = evaluate_on(model, test_dl)
            epoch_accuracy = round(100 * correct / total, 2)
            best_test_set_accuracy = max(best_test_set_accuracy, epoch_accuracy)
            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")

    return best_test_set_accuracy


def train(hp, train_dl, test_dl):
    """Train a WideResNet, with DP-SGD when ``hp['epsilon']`` is not None.

    Args:
        hp: hyper-parameter dict; reads ``wrn_depth``, ``wrn_width``,
            ``epochs``, ``epsilon``, ``delta`` and ``norm``.
        train_dl: training loader.
        test_dl: loader used for the periodic evaluation.

    Returns:
        ``(model_init, model)``: a deep copy of the model before training and
        the trained model (the Opacus-wrapped module in the DP case).
    """
    model = WideResNet(
        d=hp["wrn_depth"],
        k=hp["wrn_width"],
        n_classes=10,
        input_features=3,
        output_features=16,
        strides=[1, 1, 2, 2],
    )
    model = model.to(DEVICE)
    # Replace layers Opacus cannot handle. The extra .to(DEVICE) is a no-op
    # when the fixed model already lives there - NOTE(review): it guards
    # against replacement layers being created on another device.
    model = ModuleValidator.fix(model).to(DEVICE)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4
    )
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    print(f"Training with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model_init, model


def _membership_scores(model_init, model_trained, x_m, y_m, S_m):
    """Score each candidate by ``loss(init) - loss(trained)``.

    Returns:
        A list of ``(score, is_in)`` tuples, one per candidate, in input order.
    """
    criterion = nn.CrossEntropyLoss()
    # Both models must be in eval mode; do not rely on an earlier
    # evaluate_on() call having left them there.
    model_init.eval()
    model_trained.eval()

    scores = []
    with torch.no_grad():
        x_m = torch.from_numpy(x_m).to(DEVICE)
        y_m = torch.from_numpy(y_m).long().to(DEVICE)
        for x_point, y_point, is_in in zip(x_m, y_m, S_m):
            x_point = x_point.unsqueeze(0)
            y_point = y_point.unsqueeze(0)
            init_loss = criterion(_extract_logits(model_init(x_point)), y_point)
            trained_loss = criterion(_extract_logits(model_trained(x_point)), y_point)
            scores.append(((init_loss - trained_loss).item(), is_in))
    return scores


def _count_correct_guesses(membership, k_minus, k_plus):
    """Count correct guesses given membership flags sorted by ascending score.

    The ``k_minus`` lowest-scoring candidates are guessed "out" and the
    ``k_plus`` highest-scoring ones are guessed "in".
    """
    membership = np.asarray(membership, dtype=bool)
    guessed_out = membership[:k_minus]
    # Explicit start index: ``membership[-k_plus:]`` would select the whole
    # array when k_plus == 0.
    guessed_in = membership[len(membership) - k_plus:]
    return int(np.sum(~guessed_out) + np.sum(guessed_in))


def main():
    """Parse the command line, train the audited model and print the audit."""
    global DEVICE, STUDENTBOOL

    parser = argparse.ArgumentParser(description='WideResNet O1 audit')
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
    parser.add_argument('--m', type=int, help='number of target points', required=True)
    parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher")
    args = parser.parse_args()

    if torch.cuda.is_available() and args.cuda is not None:
        DEVICE = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
    else:
        DEVICE = torch.device('cpu')

    hp = {
        "target_points": args.m,
        "wrn_depth": 16,
        "wrn_width": 1,
        "epsilon": args.epsilon,
        "delta": 1e-5,
        "norm": args.norm,
        "batch_size": 4096,
        "epochs": 100,
        "k+": 200,
        "k-": 200,
        "p_value": 0.05,
    }

    # Fail before the (long) training run: with fewer candidates than guesses
    # the "in" and "out" guess windows overlap and points get counted twice.
    n_guesses = hp['k+'] + hp['k-']
    if args.m < n_guesses:
        parser.error(f"--m must be at least k+ + k- = {n_guesses}, got {args.m}")

    hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
        int(time.time()),
        hp['wrn_depth'],
        hp['wrn_width'],
        hp['batch_size'],
        hp['epochs'],
        hp['epsilon'],
        hp['delta'],
        hp['norm'],
    ))

    train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
    print(f"len train: {len(train_dl)}")
    print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
    print(f"Got x_in: {x_in.shape}")
    print(f"Got x_m: {x_m.shape}")
    print(f"Got y_m: {y_m.shape}")

    if args.auditmodel == "student":
        teacher_init, teacher_trained = train(hp, train_dl, test_dl)
        STUDENTBOOL = True

        # Train the student model on the teacher's softened predictions.
        print("Training Student Model")
        model_init, model_trained = train_knowledge_distillation(
            teacher=teacher_trained,
            train_dl=train_dl,
            epochs=100,
            device=DEVICE,
            learning_rate=0.001,
            T=2,
            soft_target_loss_weight=0.25,
            ce_loss_weight=0.75,
        )
        stcorrect, sttotal = evaluate_on(model_trained, test_dl)
        stacc = stcorrect/sttotal*100
        print(f"Student Accuracy: {stacc}%")
    else:
        model_init, model_trained = train(hp, train_dl, test_dl)

    # To persist the models, e.g.:
    #   torch.save(model_init.state_dict(), "data/init_model.pt")
    #   torch.save(model_trained.state_dict(), "data/trained_model.pt")

    scores = _membership_scores(model_init, model_trained, x_m, y_m, S_m)

    # Sort candidates by ascending score and keep only the membership flags.
    scores = sorted(scores, key=lambda x: x[0])
    scores = np.array([x[1] for x in scores], dtype=bool)
    print(scores[:10])

    correct = _count_correct_guesses(scores, hp['k-'], hp['k+'])
    # Accuracy is reported over the guesses actually made, not over all m
    # candidates (the candidates in between are abstentions).
    total = n_guesses

    eps_lb = get_eps_audit(
        hp['target_points'],
        n_guesses,
        correct,
        hp['delta'],
        hp['p_value']
    )

    print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
    print(f"p[ε < {eps_lb}] < {hp['p_value']}")

    correct, total = evaluate_on(model_init, train_dl)
    print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")

    correct, total = evaluate_on(model_trained, test_dl)
    print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")


if __name__ == '__main__':
    main()