diff --git a/wresnet-pytorch/src/distillation_train.py b/wresnet-pytorch/src/distillation_train.py
index 7ad9cc4..5be5453 100644
--- a/wresnet-pytorch/src/distillation_train.py
+++ b/wresnet-pytorch/src/distillation_train.py
@@ -1,3 +1,6 @@
+from datetime import datetime
+import time
+
 from utils import json_file_to_pyobj, get_loaders
 from WideResNet import WideResNet
 from opacus.validators import ModuleValidator
@@ -18,21 +21,8 @@ import warnings
 
 warnings.filterwarnings("ignore")
 
-def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
+def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
     # Dataset
-    transform = transforms.Compose(
-        [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(32, padding=4),
-            transforms.ToTensor(),
-            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
-        ]
-    )
-    datadir = Path().home() / "opt/data/cifar"
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
-    train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)
-
-
     ce_loss = nn.CrossEntropyLoss()
     optimizer = optim.Adam(student.parameters(), lr=learning_rate)
 
@@ -72,44 +62,29 @@ def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, sof
         print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
 
-
-
-def test(model, device, teacher=False):
-    transform = transforms.Compose(
-        [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(32, padding=4),
-            transforms.ToTensor(),
-            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
-        ]
-    )
-    datadir = Path().home() / "opt/data/cifar"
-    test_ds = CIFAR10(root=datadir, train=True, download=False, transform=transform)
-    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4
-    )
+@torch.no_grad()
+def test(model, device, test_dl, teacher=False):
     model.to(device)
     model.eval()
 
     correct = 0
     total = 0
 
-    with torch.no_grad():
-        for inputs, labels in test_dl:
-            inputs, labels = inputs.to(device), labels.to(device)
-            if teacher:
-                outputs, _, _, _ = model(inputs)
-            else:
-                outputs = model(inputs)
-            _, predicted = torch.max(outputs.data, 1)
+    for inputs, labels in test_dl:
+        inputs, labels = inputs.to(device), labels.to(device)
+        if teacher:
+            outputs, _, _, _ = model(inputs)
+        else:
+            outputs = model(inputs)
+        _, predicted = torch.max(outputs.data, 1)
 
-            total += labels.size(0)
-            correct += (predicted == labels).sum().item()
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
 
     accuracy = 100 * correct / total
     print(f"Test Accuracy: {accuracy:.2f}%")
     return accuracy
 
 
-
 def main():
     json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
     training_configurations = json_options.training
@@ -156,13 +131,17 @@ def main():
 
 
     print("Training student")
-    #train_knowledge_distillation(teacher=teacher, student=student, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
-    #test_student = test(student, device)
-    test_teacher = test(teacher, device, True)
+    train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
+    print("Saving student")
+    current_datetime = datetime.now()
+    filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
+    torch.save(student.state_dict(), filename)
+    print("Testing student and teacher")
+    test_student = test(student, device, test_loader)
+    test_teacher = test(teacher, device, test_loader, True)
     print(f"Teacher accuracy: {test_teacher:.2f}%")
-    #print(f"Student accuracy: {test_student:.2f}%")
+    print(f"Student accuracy: {test_student:.2f}%")
 
 
 if __name__ == "__main__":
     main()
-
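The body of the distillation loop sits between the hunks above and is unchanged by this patch, so the diff elides it. For readers without the full file, here is a minimal sketch of a standard Hinton-style knowledge-distillation step consistent with the parameter names in the new signature (`T`, `soft_target_loss_weight`, `ce_loss_weight`) and with the teacher's four-tuple output seen in `test()`; `kd_step` is a hypothetical helper, not the repository's actual loop, which may differ.

```python
# Hypothetical sketch of one distillation step; the real loop in
# distillation_train.py is elided from the diff and may differ.
import torch
import torch.nn.functional as F

def kd_step(teacher, student, inputs, labels, ce_loss, optimizer,
            T, soft_target_loss_weight, ce_loss_weight):
    teacher.eval()
    with torch.no_grad():
        # The teacher returns a 4-tuple; only the logits are needed here,
        # matching the `outputs, _, _, _ = model(inputs)` unpacking in test().
        teacher_logits, _, _, _ = teacher(inputs)
    student_logits = student(inputs)

    # Soften both distributions with temperature T; the T**2 factor keeps
    # gradient magnitudes comparable across temperatures.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    log_soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = F.kl_div(log_soft_student, soft_targets,
                         reduction="batchmean") * (T ** 2)

    # Weighted sum of the distillation loss and ordinary cross-entropy.
    hard_loss = ce_loss(student_logits, labels)
    loss = soft_target_loss_weight * soft_loss + ce_loss_weight * hard_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

One caveat on the last hunk: `torch.save(student.state_dict(), filename)` assumes a `students/` directory already exists in the working directory; an `os.makedirs("students", exist_ok=True)` before saving would make the script self-contained.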