diff --git a/wresnet-pytorch/src/WideResNet.py b/wresnet-pytorch/src/WideResNet.py index 7b29c4e..887547d 100644 --- a/wresnet-pytorch/src/WideResNet.py +++ b/wresnet-pytorch/src/WideResNet.py @@ -45,7 +45,6 @@ class IndividualBlock1(nn.Module): class IndividualBlockN(nn.Module): - def __init__(self, input_features, output_features, stride): super(IndividualBlockN, self).__init__() @@ -117,7 +116,6 @@ class WideResNet(nn.Module): m.bias.data.zero_() def forward(self, x): - x = self.conv1(x) attention1 = self.block1(x) attention2 = self.block2(attention1) @@ -143,4 +141,4 @@ if __name__ == '__main__': net(sample_input) # Summarize model - summary(net, input_size=(3, 32, 32)) \ No newline at end of file + summary(net, input_size=(3, 32, 32)) diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py index 8f5e7db..0977f72 100644 --- a/wresnet-pytorch/src/train.py +++ b/wresnet-pytorch/src/train.py @@ -7,6 +7,7 @@ import numpy as np import random from utils import json_file_to_pyobj, get_loaders from WideResNet import WideResNet +from tqdm import tqdm def set_seed(seed=42): @@ -17,25 +18,19 @@ def set_seed(seed=42): torch.cuda.manual_seed(seed) -def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile=''): - +def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile='', epochs=200): train_loader, test_loader = loaders - if dataset == 'svhn': - epochs = 100 - else: - epochs = 200 - criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) best_test_set_accuracy = 0 - for epoch in range(epochs): - + print(f"Training with {epochs} epochs") + for epoch in tqdm(range(epochs)): net.train() - for i, data in enumerate(train_loader, 0): + for i, data in tqdm(enumerate(train_loader, 0), leave=False): inputs, labels = data inputs = inputs.to(device) labels = labels.to(device) @@ -89,18 +84,19 @@ def train(args): wrn_depth = training_configurations.wrn_depth wrn_width = training_configurations.wrn_width dataset = training_configurations.dataset.lower() - seeds = [int(seed) for seed in training_configurations.seeds] + #seeds = [int(seed) for seed in training_configurations.seeds] + seeds = [int.from_bytes(os.urandom(8), byteorder='big')] log = True if training_configurations.log.lower() == 'true' else False if log: - logfile = 'WideResNet-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset) + logfile = 'WideResNet-{}-{}-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset, training_configurations.batch_size, training_configurations.epochs) with open(logfile, 'w') as temp: - temp.write('WideResNet-{}-{} on {}\n'.format(wrn_depth, wrn_width, training_configurations.dataset)) + temp.write('WideResNet-{}-{} on {} {}batch for {} epochs\n'.format(wrn_depth, wrn_width, training_configurations.dataset, training_configurations.batch_size, training_configurations.epochs)) else: logfile = '' checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False - loaders = get_loaders(dataset) + loaders = get_loaders(dataset, training_configurations.batch_size) if torch.cuda.is_available(): device = torch.device('cuda:0') @@ -121,7 +117,8 @@ def train(args): net = net.to(device) checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, dataset, seed) if checkpoint else '' - best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile) + epochs = training_configurations.epochs + best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile, epochs) if log: with open(logfile, 'a') as temp: diff --git a/wresnet-pytorch/src/utils.py b/wresnet-pytorch/src/utils.py index 8773f7f..9279551 100644 --- a/wresnet-pytorch/src/utils.py +++ b/wresnet-pytorch/src/utils.py @@ -16,9 +16,9 @@ def json_file_to_pyobj(filename): def get_loaders(dataset, train_batch_size=128, test_batch_size=10): + print(f"Train batch size: {train_batch_size}") if dataset == 'cifar10': - normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) train_transform = transforms.Compose([ @@ -44,7 +44,6 @@ def get_loaders(dataset, train_batch_size=128, test_batch_size=10): testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4) elif dataset == 'svhn': - normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)) transform = transforms.Compose([ diff --git a/wresnet-pytorch/src/wresnet16-audit-cifar10.json b/wresnet-pytorch/src/wresnet16-audit-cifar10.json new file mode 100644 index 0000000..f78b650 --- /dev/null +++ b/wresnet-pytorch/src/wresnet16-audit-cifar10.json @@ -0,0 +1,11 @@ +{ + "training":{ + "dataset": "CIFAR10", + "wrn_depth": 16, + "wrn_width": 1, + "checkpoint": "True", + "log": "True", + "batch_size": 4096, + "epochs": 200 + } +}