"""Distill a DP-trained WideResNet teacher into a student model on CIFAR-10."""
import os
import warnings
from pathlib import Path

import opacus
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from opacus.validators import ModuleValidator
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10

import student_model
from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet

warnings.filterwarnings("ignore")

# Per-channel normalization statistics shared by training and evaluation.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2470, 0.2435, 0.2616]
# Local dataset cache; CIFAR10(download=True) fetches the data here if missing.
CIFAR10_DIR = Path.home() / "opt/data/cifar"


def _cifar10_loader(train, batch_size=128, shuffle=False, num_workers=4):
    """Build a CIFAR-10 DataLoader for the training (``train=True``) or test split.

    Random flip/crop augmentation is applied only to the training split, so
    evaluation on the test split is deterministic.
    """
    steps = [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
    if train:
        steps = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
        ] + steps
    dataset = CIFAR10(
        root=CIFAR10_DIR,
        train=train,
        download=True,
        transform=transforms.Compose(steps),
    )
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )


def train_knowledge_distillation(teacher, student, epochs, learning_rate, T,
                                 soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` to mimic ``teacher`` on the CIFAR-10 training split.

    The per-batch loss is
    ``soft_target_loss_weight * KL(teacher_T || student_T) * T**2
    + ce_loss_weight * CrossEntropy(student_logits, labels)``,
    where ``*_T`` are temperature-softened distributions (Hinton et al.,
    "Distilling the Knowledge in a Neural Network").

    Args:
        teacher: frozen model; its forward output is unpacked as a 4-tuple and
            the first element is used as the logits.
        student: model being trained; its forward output is used as logits.
        epochs: number of passes over the training set.
        learning_rate: Adam learning rate for the student.
        T: softmax temperature; both models' logits are divided by it.
        soft_target_loss_weight: weight of the distillation (KL) term.
        ce_loss_weight: weight of the hard-label cross-entropy term.
        device: torch device to train on.
    """
    # Shuffle: iterating the training set in a fixed order every epoch hurts SGD.
    train_dl = _cifar10_loader(train=True, shuffle=True)

    teacher.to(device)
    student.to(device)
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # teacher is only used for inference
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Teacher weights are never updated, so don't build its autograd graph.
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)
            student_logits = student(inputs)

            # Keep both distributions in log space: log(softmax(x)) yields -inf
            # when a probability underflows, and 0 * -inf = NaN in the KL sum.
            teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)
            # "batchmean" = total KL / batch size. The T**2 factor keeps the
            # soft-target gradients on the same scale as the CE term.
            soft_targets_loss = F.kl_div(
                student_log_probs,
                teacher_log_probs,
                reduction="batchmean",
                log_target=True,
            ) * (T ** 2)
            label_loss = ce_loss(student_logits, labels)
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")


def test(model, device, teacher=False):
    """Print and return ``model``'s top-1 accuracy (percent) on the CIFAR-10 test split.

    Args:
        model: classifier to evaluate; moved to ``device`` and put in eval mode.
        device: torch device to run on.
        teacher: if True, the forward output is unpacked as a 4-tuple and its
            first element is used as logits (the teacher's output format as
            consumed in this file); otherwise the output is used directly.

    Returns:
        Accuracy in percent as a float.
    """
    # Held-out split without augmentation, so the number is a real test score.
    test_dl = _cifar10_loader(train=False, shuffle=False)

    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            if teacher:
                outputs, _, _, _ = model(inputs)
            else:
                outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


def main(train_student=False):
    """Load the DP-trained teacher checkpoint and report its test accuracy.

    Args:
        train_student: if True, also distill a ``student_model.Model`` from the
            teacher and report the student's test accuracy. Defaults to False
            (teacher evaluation only).
    """
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Load the teacher model")
    strides = [1, 1, 2, 2]
    teacher = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3,
                         output_features=16, strides=strides)
    # Swap layers Opacus does not support; presumably mirrors how the
    # checkpoint's model was built — TODO confirm against the training script.
    teacher = ModuleValidator.fix(teacher)

    # NOTE(review): the model is wrapped by the privacy engine only so that its
    # module structure / state_dict keys match the checkpoint, which appears to
    # have been saved from an Opacus-wrapped model. epochs and the DP settings
    # only feed make_private_with_epsilon's noise calibration; no training
    # happens here and the evaluated weights come from the checkpoint.
    epochs = 10
    dp_epsilon = 8
    dp_delta = 1e-5
    norm = 1.0
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9,
                          nesterov=True, weight_decay=5e-4)
    train_loader, _ = get_loaders(dataset, training_configurations.batch_size)
    privacy_engine = opacus.PrivacyEngine()
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )

    checkpoint_path = "wrn-1733078278-8e-1e-05d-12.0n-dict.pt"
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    teacher.load_state_dict(state_dict)
    teacher.to(device)
    teacher.eval()

    test_student = None
    if train_student:
        student = student_model.Model(num_classes=10).to(device)
        print("Training student")
        train_knowledge_distillation(teacher=teacher, student=student, epochs=100,
                                     learning_rate=0.001, T=2,
                                     soft_target_loss_weight=0.25,
                                     ce_loss_weight=0.75, device=device)
        test_student = test(student, device)

    test_teacher = test(teacher, device, True)
    print(f"Teacher accuracy: {test_teacher:.2f}%")
    if test_student is not None:
        print(f"Student accuracy: {test_student:.2f}%")


if __name__ == "__main__":
    main()