Torchlira: add wresnet-16

This commit is contained in:
Akemi Izuko 2024-11-30 13:35:38 -07:00
parent 2b865a5f58
commit 944ff9d5cd
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
12 changed files with 473 additions and 0 deletions

34
wresnet-pytorch/README.md Normal file
View file

@ -0,0 +1,34 @@
# Wide Residual Networks in PyTorch
Implementation of Wide Residual Networks (WRNs) in PyTorch.
## How to train WRNs
At the moment the CIFAR10 and SVHN datasets are fully supported, with specific augmentations for CIFAR10 drawn from related literature and mean/std normalization for SVHN, and multistep learning rate scheduling for both cases. Training is executed through JSON configuration files, which you can modify or extend to support other configurations of WRNs and/or extend datasets etc.
### Example Runs
Train a WideResNet-16-1 on CIFAR10:
```
python train.py --config configs/WRN-16-1-scratch-CIFAR10.json
```
Train a WideResNet-40-2 on SVHN:
```
python train.py --config configs/WRN-40-2-scratch-SVHN.json
```
## Results
This work has been tested with 4 variants of WRNs. When setting the seed generator equal to 0, you should expect a test-set accuracy performance close to the following values:
|Model | CIFAR10 | SVHN |
|:---------|:--------|:-------|
| WRN-16-1 |90.97% | 95.52% |
| WRN-16-2 |94.21% | 96.17% |
| WRN-40-1 |93.52% | 96.07% |
| WRN-40-2 |95.14% | 96.14% |
## Notes
The motivation for originally implementing WRNs in PyTorch was [this](https://github.com/AlexandrosFerles/NIPS_2019_Reproducibilty_Challenge_Zero-shot_Knowledge_Transfer_via_Adversarial_Belief_Matching) NeurIPS reproducibility project, where WRNs were used as the main framework for few-shot and zero-shot knowledge transfer.

View file

@ -0,0 +1,146 @@
import torch
import torch.nn as nn
from torchsummary import summary
import math
class IndividualBlock1(nn.Module):
def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(IndividualBlock1, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
self.subsample_input = subsample_input
self.increase_filters = increase_filters
if subsample_input:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
elif increase_filters:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
if self.subsample_input or self.increase_filters:
x = self.batch_norm1(x)
x = self.activation(x)
x1 = self.conv1(x)
else:
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
if self.subsample_input or self.increase_filters:
return self.conv_inp(x) + x1
else:
return x + x1
class IndividualBlockN(nn.Module):
def __init__(self, input_features, output_features, stride):
super(IndividualBlockN, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
def forward(self, x):
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
return x1 + x
class Nblock(nn.Module):
def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(Nblock, self).__init__()
layers = []
for i in range(N):
if i == 0:
layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
else:
layers.append(IndividualBlockN(output_features, output_features, stride=1))
self.nblockLayer = nn.Sequential(*layers)
def forward(self, x):
return self.nblockLayer(x)
class WideResNet(nn.Module):
def __init__(self, d, k, n_classes, input_features, output_features, strides):
super(WideResNet, self).__init__()
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
filters = [16 * k, 32 * k, 64 * k]
self.out_filters = filters[-1]
N = (d - 4) // 6
increase_filters = k > 1
self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
self.batch_norm = nn.BatchNorm2d(filters[-1])
self.activation = nn.ReLU(inplace=True)
self.avg_pool = nn.AvgPool2d(kernel_size=8)
self.fc = nn.Linear(filters[-1], n_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
x = self.conv1(x)
attention1 = self.block1(x)
attention2 = self.block2(attention1)
attention3 = self.block3(attention2)
out = self.batch_norm(attention3)
out = self.activation(out)
out = self.avg_pool(out)
out = out.view(-1, self.out_filters)
return self.fc(out), attention1, attention2, attention3
if __name__ == '__main__':
# change d and k if you want to check a model other than WRN-40-2
d = 40
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
# verify that an output is produced
sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
net(sample_input)
# Summarize model
summary(net, input_size=(3, 32, 32))

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 16,
"wrn_width": 1,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "SVHN",
"wrn_depth": 16,
"wrn_width": 1,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 16,
"wrn_width": 2,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "SVHN",
"wrn_depth": 16,
"wrn_width": 2,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 40,
"wrn_width": 1,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "SVHN",
"wrn_depth": 40,
"wrn_width": 1,
"seeds": "0",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "CIFAR10",
"wrn_depth": 40,
"wrn_width": 2,
"seeds": "012",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,10 @@
{
"training":{
"dataset": "SVHN",
"wrn_depth": 40,
"wrn_width": 2,
"seeds": "012",
"checkpoint": "True",
"log": "True"
}
}

View file

@ -0,0 +1,151 @@
import os
import torch
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn as nn
import numpy as np
import random
from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet
def set_seed(seed=42):
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile=''):
train_loader, test_loader = loaders
if dataset == 'svhn':
epochs = 100
else:
epochs = 200
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
best_test_set_accuracy = 0
for epoch in range(epochs):
net.train()
for i, data in enumerate(train_loader, 0):
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
wrn_outputs = net(inputs)
outputs = wrn_outputs[0]
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
with torch.no_grad():
correct = 0
total = 0
net.eval()
for data in test_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
wrn_outputs = net(images)
outputs = wrn_outputs[0]
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
epoch_accuracy = correct / total
epoch_accuracy = round(100 * epoch_accuracy, 2)
if log:
with open(logfile, 'a') as temp:
temp.write('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy))
if epoch_accuracy > best_test_set_accuracy:
best_test_set_accuracy = epoch_accuracy
if checkpoint:
torch.save(net.state_dict(), checkpointFile)
return best_test_set_accuracy
def train(args):
json_options = json_file_to_pyobj(args.config)
training_configurations = json_options.training
wrn_depth = training_configurations.wrn_depth
wrn_width = training_configurations.wrn_width
dataset = training_configurations.dataset.lower()
seeds = [int(seed) for seed in training_configurations.seeds]
log = True if training_configurations.log.lower() == 'true' else False
if log:
logfile = 'WideResNet-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset)
with open(logfile, 'w') as temp:
temp.write('WideResNet-{}-{} on {}\n'.format(wrn_depth, wrn_width, training_configurations.dataset))
else:
logfile = ''
checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
loaders = get_loaders(dataset)
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
test_set_accuracies = []
for seed in seeds:
set_seed(seed)
if log:
with open(logfile, 'a') as temp:
temp.write('------------------- SEED {} -------------------\n'.format(seed))
strides = [1, 1, 2, 2]
net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
net = net.to(device)
checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile)
if log:
with open(logfile, 'a') as temp:
temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))
test_set_accuracies.append(best_test_set_accuracy)
mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)
if log:
with open(logfile, 'a') as temp:
temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))
if __name__ == '__main__':
import argparse
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
parser = argparse.ArgumentParser(description='WideResNet')
parser.add_argument('-config', '--config', help='Training Configurations', required=True)
args = parser.parse_args()
train(args)

View file

@ -0,0 +1,62 @@
import json
import collections
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
def json_file_to_pyobj(filename):
def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
return json2obj(open(filename).read())
def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
if dataset == 'cifar10':
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
(4, 4, 4, 4), mode='reflect').squeeze()),
transforms.ToPILImage(),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([
transforms.ToTensor(),
normalize
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
elif dataset == 'svhn':
normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
transform = transforms.Compose([
transforms.ToTensor(),
normalize,
])
trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
return trainloader, testloader