diff --git a/wresnet-pytorch/src/distillation_train.py b/wresnet-pytorch/src/distillation_train.py index 5be5453..18b28b0 100644 --- a/wresnet-pytorch/src/distillation_train.py +++ b/wresnet-pytorch/src/distillation_train.py @@ -1,6 +1,6 @@ from datetime import datetime import time - +import argparse from utils import json_file_to_pyobj, get_loaders from WideResNet import WideResNet from opacus.validators import ModuleValidator @@ -13,7 +13,7 @@ import os import torch import torch.nn as nn from torchvision import models, transforms -import student_model +import student_model import torch.optim as optim import torch.nn.functional as F import opacus @@ -86,6 +86,14 @@ def test(model, device, test_dl, teacher=False): return accuracy def main(): + parser = argparse.ArgumentParser(description='Student trainer') + parser.add_argument('--teacher', type=Path, help='path to saved teacher .pt', required=True) + parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True) + parser.add_argument('--cuda', type=int, help='gpu index', required=False) + parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None) + parser.add_argument('--epochs', type=int, help='student epochs', required=True) + args = parser.parse_args() + json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json") training_configurations = json_options.training @@ -93,7 +101,9 @@ def main(): wrn_width = training_configurations.wrn_width dataset = training_configurations.dataset.lower() - if torch.cuda.is_available(): + if args.cuda is not None: + device = torch.device(f'cuda:{args.cuda}') + elif torch.cuda.is_available(): device = torch.device('cuda:0') else: device = torch.device('cpu') @@ -101,19 +111,21 @@ def main(): print("Load the teacher model") # instantiate teacher model - strides = [1, 1, 2, 2] + strides = [1, 1, 2, 2] teacher = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides) - teacher = ModuleValidator.fix(teacher) + teacher = ModuleValidator.fix(teacher) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size) best_test_set_accuracy = 0 - dp_epsilon = 8 - dp_delta = 1e-5 - norm = 1.0 - privacy_engine = opacus.PrivacyEngine() - teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon( + + if args.epsilon is not None: + dp_epsilon = args.epsilon + dp_delta = 1e-5 + norm = args.norm + privacy_engine = opacus.PrivacyEngine() + teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon( module=teacher, optimizer=optimizer, data_loader=train_loader, @@ -123,19 +135,28 @@ def main(): max_grad_norm=norm, ) - teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True)) + teacher.load_state_dict(torch.load(args.teacher, weights_only=True)) teacher.to(device) teacher.eval() #instantiate istudent student = student_model.Model(num_classes=10).to(device) - print("Training student") - train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device) - print("Saving student") - current_datetime = datetime.now() - filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt" - torch.save(student.state_dict(), filename) + train_knowledge_distillation( + teacher=teacher, + student=student, + train_dl=train_loader, + epochs=args.epochs, + learning_rate=0.001, + T=2, + soft_target_loss_weight=0.25, + ce_loss_weight=0.75, + device=device + ) + print(f"Saving student model for time {int(time.time())}") + Path('students').mkdir(exist_ok=True) + torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt") + print("Testing student and teacher") test_student = test(student, device, test_loader,) test_teacher = test(teacher, device, test_loader, True) @@ -144,4 +165,5 @@ def main(): if __name__ == "__main__": + main()