diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index 5c35ca2..4aefae6 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -2,6 +2,7 @@ import argparse import equations import numpy as np import time +import copy import torch import torch.nn as nn from torch import optim @@ -66,12 +67,33 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): return train_dl, test_dl, x_in, x_m, S[p] -def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler): +def evaluate_on(model, dataloader): + correct = 0 + total = 0 + + with torch.no_grad(): + model.eval() + + for data in dataloader: + images, labels = data + images = images.to(DEVICE) + labels = labels.to(DEVICE) + + wrn_outputs = model(images) + outputs = wrn_outputs[0] + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + return correct, total + + +def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler): best_test_set_accuracy = 0 for epoch in range(hp['epochs']): model.train() - for i, data in enumerate(train_loader, 0): + for i, data in enumerate(train_dl, 0): inputs, labels = data inputs = inputs.to(DEVICE) labels = labels.to(DEVICE) @@ -87,30 +109,14 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch scheduler.step() if epoch % 10 == 0 or epoch == hp['epochs'] - 1: - with torch.no_grad(): - correct = 0 - total = 0 - - model.eval() - for data in test_loader: - images, labels = data - images = images.to(DEVICE) - labels = labels.to(DEVICE) - - wrn_outputs = model(images) - outputs = wrn_outputs[0] - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels).sum().item() - - epoch_accuracy = correct / total - epoch_accuracy = round(100 * epoch_accuracy, 2) - print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%") + correct, total = evaluate_on(model, test_dl) + epoch_accuracy = round(100 * correct / total, 2) + print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%") return best_test_set_accuracy -def train(hp): +def train(hp, train_dl, test_dl): model = WideResNet( d=hp["wrn_depth"], k=hp["wrn_width"], @@ -123,6 +129,8 @@ def train(hp): model = ModuleValidator.fix(model) ModuleValidator.validate(model, strict=True) + model_init = copy.deepcopy(model) + criterion = nn.CrossEntropyLoss() optimizer = optim.SGD( model.parameters(), @@ -137,12 +145,6 @@ def train(hp): gamma=0.2 ) - train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size']) - print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}") - print(f"Got x_in: {x_in.shape}") - print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}") - print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}") - print(f"Got train dataloader: {len(train_dl)}") print(f"Training with {hp['epochs']} epochs") if hp['epsilon'] is not None: @@ -186,7 +188,7 @@ def train(hp): scheduler, ) - return model + return model_init, model def main(): @@ -206,7 +208,7 @@ def main(): else: DEVICE = torch.device('cpu') - hyperparams = { + hp = { "target_points": args.m, "wrn_depth": 16, "wrn_width": 1, @@ -214,22 +216,34 @@ def main(): "delta": 1e-5, "norm": args.norm, "batch_size": 4096, - "epochs": 200, + "epochs": 20, } - hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format( + hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format( int(time.time()), - hyperparams['wrn_depth'], - hyperparams['wrn_width'], - hyperparams['batch_size'], - hyperparams['epochs'], - hyperparams['epsilon'], - hyperparams['delta'], - hyperparams['norm'], + hp['wrn_depth'], + hp['wrn_width'], + hp['batch_size'], + hp['epochs'], + hp['epsilon'], + hp['delta'], + hp['norm'], )) - model = train(hyperparams) - torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt')) + train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size']) + print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}") + print(f"Got x_in: {x_in.shape}") + print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}") + print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}") + print(f"Got train dataloader: {len(train_dl)}") + model_init, model_trained = train(hp, train_dl, test_dl) + + correct, total = evaluate_on(model_init, train_dl) + print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") + correct, total = evaluate_on(model_trained, test_dl) + print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") + + torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt')) if __name__ == '__main__': main()