diff --git a/wresnet-pytorch/README.md b/wresnet-pytorch/README.md new file mode 100644 index 0000000..3ca72a0 --- /dev/null +++ b/wresnet-pytorch/README.md @@ -0,0 +1,34 @@ +# Wide Residual Networks in PyTorch + +Implementation of Wide Residual Networks (WRNs) in PyTorch. + +## How to train WRNs + +At the moment the CIFAR10 and SVHN datasets are fully supported, with specific augmentations for CIFAR10 drawn from related literature and mean/std normalization for SVHN, and multistep learning rate scheduling for both cases. Training is executed through JSON configuration files, which you can modify or extend to support other configurations of WRNs and/or extend datasets etc. + +### Example Runs + +Train a WideResNet-16-1 on CIFAR10: +``` +python train.py --config configs/WRN-16-1-scratch-CIFAR10.json +``` + +Train a WideResNet-40-2 on SVHN: +``` +python train.py --config configs/WRN-40-2-scratch-SVHN.json +``` + +## Results + +This work has been tested with 4 variants of WRNs. When setting the seed generator equal to 0, you should expect a test-set accuracy performance close to the following values: + +|Model | CIFAR10 | SVHN | +|:---------|:--------|:-------| +| WRN-16-1 |90.97% | 95.52% | +| WRN-16-2 |94.21% | 96.17% | +| WRN-40-1 |93.52% | 96.07% | +| WRN-40-2 |95.14% | 96.14% | + +## Notes + +The motivation for originally implementing WRNs in PyTorch was [this](https://github.com/AlexandrosFerles/NIPS_2019_Reproducibilty_Challenge_Zero-shot_Knowledge_Transfer_via_Adversarial_Belief_Matching) NeurIPS reproducibility project, where WRNs were used as the main framework for few-shot and zero-shot knowledge transfer. \ No newline at end of file diff --git a/wresnet-pytorch/src/WideResNet.py b/wresnet-pytorch/src/WideResNet.py new file mode 100644 index 0000000..7b29c4e --- /dev/null +++ b/wresnet-pytorch/src/WideResNet.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +from torchsummary import summary +import math + + +class IndividualBlock1(nn.Module): + + def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True): + super(IndividualBlock1, self).__init__() + + self.activation = nn.ReLU(inplace=True) + + self.batch_norm1 = nn.BatchNorm2d(input_features) + self.batch_norm2 = nn.BatchNorm2d(output_features) + + self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False) + + self.subsample_input = subsample_input + self.increase_filters = increase_filters + if subsample_input: + self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False) + elif increase_filters: + self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False) + + def forward(self, x): + + if self.subsample_input or self.increase_filters: + x = self.batch_norm1(x) + x = self.activation(x) + x1 = self.conv1(x) + else: + x1 = self.batch_norm1(x) + x1 = self.activation(x1) + x1 = self.conv1(x1) + x1 = self.batch_norm2(x1) + x1 = self.activation(x1) + x1 = self.conv2(x1) + + if self.subsample_input or self.increase_filters: + return self.conv_inp(x) + x1 + else: + return x + x1 + + +class IndividualBlockN(nn.Module): + + def __init__(self, input_features, output_features, stride): + super(IndividualBlockN, self).__init__() + + self.activation = nn.ReLU(inplace=True) + + self.batch_norm1 = nn.BatchNorm2d(input_features) + self.batch_norm2 = nn.BatchNorm2d(output_features) + + self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False) + + def forward(self, x): + x1 = self.batch_norm1(x) + x1 = self.activation(x1) + x1 = self.conv1(x1) + x1 = self.batch_norm2(x1) + x1 = self.activation(x1) + x1 = self.conv2(x1) + + return x1 + x + + +class Nblock(nn.Module): + + def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True): + super(Nblock, self).__init__() + + layers = [] + for i in range(N): + if i == 0: + layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)) + else: + layers.append(IndividualBlockN(output_features, output_features, stride=1)) + + self.nblockLayer = nn.Sequential(*layers) + + def forward(self, x): + return self.nblockLayer(x) + + +class WideResNet(nn.Module): + + def __init__(self, d, k, n_classes, input_features, output_features, strides): + super(WideResNet, self).__init__() + + self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False) + + filters = [16 * k, 32 * k, 64 * k] + self.out_filters = filters[-1] + N = (d - 4) // 6 + increase_filters = k > 1 + self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters) + self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2]) + self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3]) + + self.batch_norm = nn.BatchNorm2d(filters[-1]) + self.activation = nn.ReLU(inplace=True) + self.avg_pool = nn.AvgPool2d(kernel_size=8) + self.fc = nn.Linear(filters[-1], n_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def forward(self, x): + + x = self.conv1(x) + attention1 = self.block1(x) + attention2 = self.block2(attention1) + attention3 = self.block3(attention2) + out = self.batch_norm(attention3) + out = self.activation(out) + out = self.avg_pool(out) + out = out.view(-1, self.out_filters) + + return self.fc(out), attention1, attention2, attention3 + + +if __name__ == '__main__': + + # change d and k if you want to check a model other than WRN-40-2 + d = 40 + k = 2 + strides = [1, 1, 2, 2] + net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides) + + # verify that an output is produced + sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False) + net(sample_input) + + # Summarize model + summary(net, input_size=(3, 32, 32)) \ No newline at end of file diff --git a/wresnet-pytorch/src/configs/WRN-16-1-scratch-CIFAR10.json b/wresnet-pytorch/src/configs/WRN-16-1-scratch-CIFAR10.json new file mode 100644 index 0000000..a7d8f54 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-16-1-scratch-CIFAR10.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "CIFAR10", + "wrn_depth": 16, + "wrn_width": 1, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-16-1-scratch-SVHN.json b/wresnet-pytorch/src/configs/WRN-16-1-scratch-SVHN.json new file mode 100644 index 0000000..ce7c55b --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-16-1-scratch-SVHN.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "SVHN", + "wrn_depth": 16, + "wrn_width": 1, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-16-2-scratch-CIFAR10.json b/wresnet-pytorch/src/configs/WRN-16-2-scratch-CIFAR10.json new file mode 100644 index 0000000..f8918b6 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-16-2-scratch-CIFAR10.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "CIFAR10", + "wrn_depth": 16, + "wrn_width": 2, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-16-2-scratch-SVHN.json b/wresnet-pytorch/src/configs/WRN-16-2-scratch-SVHN.json new file mode 100644 index 0000000..afecf68 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-16-2-scratch-SVHN.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "SVHN", + "wrn_depth": 16, + "wrn_width": 2, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-40-1-scratch-CIFAR10.json b/wresnet-pytorch/src/configs/WRN-40-1-scratch-CIFAR10.json new file mode 100644 index 0000000..bc0a9c2 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-40-1-scratch-CIFAR10.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "CIFAR10", + "wrn_depth": 40, + "wrn_width": 1, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-40-1-scratch-SVHN.json b/wresnet-pytorch/src/configs/WRN-40-1-scratch-SVHN.json new file mode 100644 index 0000000..a9098b5 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-40-1-scratch-SVHN.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "SVHN", + "wrn_depth": 40, + "wrn_width": 1, + "seeds": "0", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-40-2-scratch-CIFAR10.json b/wresnet-pytorch/src/configs/WRN-40-2-scratch-CIFAR10.json new file mode 100644 index 0000000..b156a3d --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-40-2-scratch-CIFAR10.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "CIFAR10", + "wrn_depth": 40, + "wrn_width": 2, + "seeds": "012", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/configs/WRN-40-2-scratch-SVHN.json b/wresnet-pytorch/src/configs/WRN-40-2-scratch-SVHN.json new file mode 100644 index 0000000..db9b327 --- /dev/null +++ b/wresnet-pytorch/src/configs/WRN-40-2-scratch-SVHN.json @@ -0,0 +1,10 @@ +{ + "training":{ + "dataset": "SVHN", + "wrn_depth": 40, + "wrn_width": 2, + "seeds": "012", + "checkpoint": "True", + "log": "True" + } +} diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py new file mode 100644 index 0000000..8f5e7db --- /dev/null +++ b/wresnet-pytorch/src/train.py @@ -0,0 +1,151 @@ +import os +import torch +from torch import optim +from torch.optim.lr_scheduler import MultiStepLR +import torch.nn as nn +import numpy as np +import random +from utils import json_file_to_pyobj, get_loaders +from WideResNet import WideResNet + + +def set_seed(seed=42): + torch.backends.cudnn.deterministic = True + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile=''): + + train_loader, test_loader = loaders + + if dataset == 'svhn': + epochs = 100 + else: + epochs = 200 + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4) + scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2) + + best_test_set_accuracy = 0 + + for epoch in range(epochs): + + net.train() + for i, data in enumerate(train_loader, 0): + inputs, labels = data + inputs = inputs.to(device) + labels = labels.to(device) + + optimizer.zero_grad() + + wrn_outputs = net(inputs) + outputs = wrn_outputs[0] + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + scheduler.step() + + with torch.no_grad(): + + correct = 0 + total = 0 + + net.eval() + for data in test_loader: + images, labels = data + images = images.to(device) + labels = labels.to(device) + + wrn_outputs = net(images) + outputs = wrn_outputs[0] + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + epoch_accuracy = correct / total + epoch_accuracy = round(100 * epoch_accuracy, 2) + + if log: + with open(logfile, 'a') as temp: + temp.write('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy)) + + if epoch_accuracy > best_test_set_accuracy: + best_test_set_accuracy = epoch_accuracy + if checkpoint: + torch.save(net.state_dict(), checkpointFile) + + return best_test_set_accuracy + + +def train(args): + json_options = json_file_to_pyobj(args.config) + training_configurations = json_options.training + + wrn_depth = training_configurations.wrn_depth + wrn_width = training_configurations.wrn_width + dataset = training_configurations.dataset.lower() + seeds = [int(seed) for seed in training_configurations.seeds] + log = True if training_configurations.log.lower() == 'true' else False + + if log: + logfile = 'WideResNet-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset) + with open(logfile, 'w') as temp: + temp.write('WideResNet-{}-{} on {}\n'.format(wrn_depth, wrn_width, training_configurations.dataset)) + else: + logfile = '' + + checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False + loaders = get_loaders(dataset) + + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + device = torch.device('cpu') + + test_set_accuracies = [] + + for seed in seeds: + set_seed(seed) + + if log: + with open(logfile, 'a') as temp: + temp.write('------------------- SEED {} -------------------\n'.format(seed)) + + strides = [1, 1, 2, 2] + net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides) + net = net.to(device) + + checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, dataset, seed) if checkpoint else '' + best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile) + + if log: + with open(logfile, 'a') as temp: + temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy)) + + test_set_accuracies.append(best_test_set_accuracy) + + mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies) + + if log: + with open(logfile, 'a') as temp: + temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy)) + + +if __name__ == '__main__': + import argparse + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3" + + parser = argparse.ArgumentParser(description='WideResNet') + + parser.add_argument('-config', '--config', help='Training Configurations', required=True) + + args = parser.parse_args() + + train(args) diff --git a/wresnet-pytorch/src/utils.py b/wresnet-pytorch/src/utils.py new file mode 100644 index 0000000..8773f7f --- /dev/null +++ b/wresnet-pytorch/src/utils.py @@ -0,0 +1,62 @@ +import json +import collections +import torchvision +from torchvision import transforms +from torch.utils.data import DataLoader +import torch.nn.functional as F + + +# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks +def json_file_to_pyobj(filename): + def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values()) + + def json2obj(data): return json.loads(data, object_hook=_json_object_hook) + + return json2obj(open(filename).read()) + + +def get_loaders(dataset, train_batch_size=128, test_batch_size=10): + + if dataset == 'cifar10': + + normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + + train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), + (4, 4, 4, 4), mode='reflect').squeeze()), + transforms.ToPILImage(), + transforms.RandomCrop(32), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]) + + test_transform = transforms.Compose([ + transforms.ToTensor(), + normalize + ]) + + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) + trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4) + + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform) + testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4) + + elif dataset == 'svhn': + + normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)) + + transform = transforms.Compose([ + transforms.ToTensor(), + normalize, + ]) + + trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform) + trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4) + + testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform) + testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4) + + return trainloader, testloader +