from datetime import datetime
import time
from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet
from opacus.validators import ModuleValidator
import os
from pathlib import Path
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torchvision import models, transforms
import student_model
import torch.optim as optim
import torch.nn.functional as F
import opacus
import warnings

warnings.filterwarnings("ignore")


def _soft_target_loss(student_logits, teacher_logits, T):
    """Return the temperature-scaled KL(teacher || student) distillation loss.

    Mathematically this is ``sum(p_t * (log p_t - log p_s)) / batch_size * T**2``,
    where ``p_t`` / ``p_s`` are the teacher / student softmax distributions at
    temperature ``T``. It is computed entirely in log space because a
    ``softmax(...).log()`` formulation yields ``0 * -inf = nan`` as soon as any
    teacher probability underflows to zero.
    """
    teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    # "batchmean" divides the summed KL by the batch size. Scaling by T**2 follows
    # "Distilling the Knowledge in a Neural Network" (Hinton et al.) so gradient
    # magnitudes stay comparable across temperatures.
    kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
    return kl * (T ** 2)


def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T,
                                 soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on ``train_dl`` by distilling from a frozen ``teacher``.

    The per-batch loss is
    ``soft_target_loss_weight * KD(teacher, student; T) + ce_loss_weight * CE(student, labels)``,
    optimized with Adam. The teacher is run under ``torch.no_grad`` and never updated.

    Args:
        teacher: model whose forward returns a 4-tuple whose first element is the logits.
        student: model whose forward returns logits directly; trained in place.
        train_dl: iterable of ``(inputs, labels)`` batches.
        epochs: number of passes over ``train_dl``.
        learning_rate: Adam learning rate for the student.
        T: softmax temperature for the distillation term.
        soft_target_loss_weight: weight of the distillation (soft-target) term.
        ce_loss_weight: weight of the hard-label cross-entropy term.
        device: device the batches are moved to.
    """
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_dl:
            # Poisson-sampled loaders (such as the Opacus DP loader passed in by
            # main) may yield empty batches; the batch-mean losses would be nan.
            if labels.numel() == 0:
                continue
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # No gradients for the teacher: its weights are never changed here.
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)

            student_logits = student(inputs)

            soft_targets_loss = _soft_target_loss(student_logits, teacher_logits, T)
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually processed (guarding an empty epoch).
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(num_batches, 1)}")


@torch.no_grad()
def test(model, device, test_dl, teacher=False):
    """Compute and print the top-1 accuracy (in percent) of ``model`` on ``test_dl``.

    Args:
        model: model to evaluate; moved to ``device`` and switched to eval mode.
        device: device used for evaluation.
        test_dl: iterable of ``(inputs, labels)`` batches.
        teacher: when True, ``model`` returns a 4-tuple whose first element is
            the logits (the teacher's output format); otherwise it returns logits.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: if ``test_dl`` yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in test_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        if teacher:
            outputs, _, _, _ = model(inputs)
        else:
            outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_dl yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


def main():
    """Load the WideResNet teacher checkpoint, distill it into a student, save and evaluate both."""
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # Only used by make_private_with_epsilon below (privacy accounting), not for training.
    epochs = 10

    print("Load the teacher model")
    # instantiate teacher model
    strides = [1, 1, 2, 2]
    teacher = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3,
                         output_features=16, strides=strides)
    # ModuleValidator.fix replaces layers Opacus does not support; presumably the
    # checkpoint was saved from the same fixed architecture (TODO confirm).
    teacher = ModuleValidator.fix(teacher)
    # make_private_with_epsilon requires an optimizer; it is never stepped here.
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)

    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)

    dp_epsilon = 8
    dp_delta = 1e-5
    norm = 1.0

    # NOTE(review): the teacher appears to be wrapped by Opacus only so that its
    # state_dict keys match the checkpoint (presumably saved from the wrapped
    # module) — confirm. This call also rebinds train_loader to Opacus' DP
    # (Poisson-sampling) loader, which the student is then trained on; verify that
    # this is intended rather than the original loader.
    privacy_engine = opacus.PrivacyEngine()
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )

    # Load onto CPU first so a checkpoint saved on GPU also loads on CPU-only hosts;
    # the model is moved to the target device afterwards.
    state_dict = torch.load("wrn-1733078278-8e-1e-05d-12.0n-dict.pt", map_location="cpu", weights_only=True)
    teacher.load_state_dict(state_dict)
    teacher.to(device)
    teacher.eval()

    # instantiate student
    student = student_model.Model(num_classes=10).to(device)

    print("Training student")
    train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100,
                                 learning_rate=0.001, T=2, soft_target_loss_weight=0.25,
                                 ce_loss_weight=0.75, device=device)

    print("Saving student")
    current_datetime = datetime.now()
    # torch.save does not create missing directories.
    os.makedirs("students", exist_ok=True)
    filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
    torch.save(student.state_dict(), filename)

    print("Testing student and teacher")
    test_student = test(student, device, test_loader,)
    test_teacher = test(teacher, device, test_loader, True)
    print(f"Teacher accuracy: {test_teacher:.2f}%")
    print(f"Student accuracy: {test_student:.2f}%")


if __name__ == "__main__":
    main()