Torchlira: add wresnet-16
This commit is contained in:
parent
2b865a5f58
commit
944ff9d5cd
12 changed files with 473 additions and 0 deletions
34
wresnet-pytorch/README.md
Normal file
34
wresnet-pytorch/README.md
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# Wide Residual Networks in PyTorch
|
||||||
|
|
||||||
|
Implementation of Wide Residual Networks (WRNs) in PyTorch.
|
||||||
|
|
||||||
|
## How to train WRNs
|
||||||
|
|
||||||
|
At the moment the CIFAR10 and SVHN datasets are fully supported, with specific augmentations for CIFAR10 drawn from related literature and mean/std normalization for SVHN, and multistep learning rate scheduling for both cases. Training is executed through JSON configuration files, which you can modify or extend to support other configurations of WRNs and/or extend datasets etc.
|
||||||
|
|
||||||
|
### Example Runs
|
||||||
|
|
||||||
|
Train a WideResNet-16-1 on CIFAR10:
|
||||||
|
```
|
||||||
|
python train.py --config configs/WRN-16-1-scratch-CIFAR10.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Train a WideResNet-40-2 on SVHN:
|
||||||
|
```
|
||||||
|
python train.py --config configs/WRN-40-2-scratch-SVHN.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
This work has been tested with 4 variants of WRNs. When setting the seed generator equal to 0, you should expect a test-set accuracy performance close to the following values:
|
||||||
|
|
||||||
|
|Model | CIFAR10 | SVHN |
|
||||||
|
|:---------|:--------|:-------|
|
||||||
|
| WRN-16-1 |90.97% | 95.52% |
|
||||||
|
| WRN-16-2 |94.21% | 96.17% |
|
||||||
|
| WRN-40-1 |93.52% | 96.07% |
|
||||||
|
| WRN-40-2 |95.14% | 96.14% |
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
The motivation for originally implementing WRNs in PyTorch was [this](https://github.com/AlexandrosFerles/NIPS_2019_Reproducibilty_Challenge_Zero-shot_Knowledge_Transfer_via_Adversarial_Belief_Matching) NeurIPS reproducibility project, where WRNs were used as the main framework for few-shot and zero-shot knowledge transfer.
|
146
wresnet-pytorch/src/WideResNet.py
Normal file
146
wresnet-pytorch/src/WideResNet.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchsummary import summary
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualBlock1(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
|
||||||
|
super(IndividualBlock1, self).__init__()
|
||||||
|
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.batch_norm1 = nn.BatchNorm2d(input_features)
|
||||||
|
self.batch_norm2 = nn.BatchNorm2d(output_features)
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
|
||||||
|
self.subsample_input = subsample_input
|
||||||
|
self.increase_filters = increase_filters
|
||||||
|
if subsample_input:
|
||||||
|
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
|
||||||
|
elif increase_filters:
|
||||||
|
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if self.subsample_input or self.increase_filters:
|
||||||
|
x = self.batch_norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x1 = self.conv1(x)
|
||||||
|
else:
|
||||||
|
x1 = self.batch_norm1(x)
|
||||||
|
x1 = self.activation(x1)
|
||||||
|
x1 = self.conv1(x1)
|
||||||
|
x1 = self.batch_norm2(x1)
|
||||||
|
x1 = self.activation(x1)
|
||||||
|
x1 = self.conv2(x1)
|
||||||
|
|
||||||
|
if self.subsample_input or self.increase_filters:
|
||||||
|
return self.conv_inp(x) + x1
|
||||||
|
else:
|
||||||
|
return x + x1
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualBlockN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, input_features, output_features, stride):
|
||||||
|
super(IndividualBlockN, self).__init__()
|
||||||
|
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.batch_norm1 = nn.BatchNorm2d(input_features)
|
||||||
|
self.batch_norm2 = nn.BatchNorm2d(output_features)
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.batch_norm1(x)
|
||||||
|
x1 = self.activation(x1)
|
||||||
|
x1 = self.conv1(x1)
|
||||||
|
x1 = self.batch_norm2(x1)
|
||||||
|
x1 = self.activation(x1)
|
||||||
|
x1 = self.conv2(x1)
|
||||||
|
|
||||||
|
return x1 + x
|
||||||
|
|
||||||
|
|
||||||
|
class Nblock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
|
||||||
|
super(Nblock, self).__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
for i in range(N):
|
||||||
|
if i == 0:
|
||||||
|
layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
|
||||||
|
else:
|
||||||
|
layers.append(IndividualBlockN(output_features, output_features, stride=1))
|
||||||
|
|
||||||
|
self.nblockLayer = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.nblockLayer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class WideResNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d, k, n_classes, input_features, output_features, strides):
|
||||||
|
super(WideResNet, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
|
||||||
|
|
||||||
|
filters = [16 * k, 32 * k, 64 * k]
|
||||||
|
self.out_filters = filters[-1]
|
||||||
|
N = (d - 4) // 6
|
||||||
|
increase_filters = k > 1
|
||||||
|
self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
|
||||||
|
self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
|
||||||
|
self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
|
||||||
|
|
||||||
|
self.batch_norm = nn.BatchNorm2d(filters[-1])
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
self.avg_pool = nn.AvgPool2d(kernel_size=8)
|
||||||
|
self.fc = nn.Linear(filters[-1], n_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
attention1 = self.block1(x)
|
||||||
|
attention2 = self.block2(attention1)
|
||||||
|
attention3 = self.block3(attention2)
|
||||||
|
out = self.batch_norm(attention3)
|
||||||
|
out = self.activation(out)
|
||||||
|
out = self.avg_pool(out)
|
||||||
|
out = out.view(-1, self.out_filters)
|
||||||
|
|
||||||
|
return self.fc(out), attention1, attention2, attention3
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# change d and k if you want to check a model other than WRN-40-2
|
||||||
|
d = 40
|
||||||
|
k = 2
|
||||||
|
strides = [1, 1, 2, 2]
|
||||||
|
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
|
||||||
|
|
||||||
|
# verify that an output is produced
|
||||||
|
sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
|
||||||
|
net(sample_input)
|
||||||
|
|
||||||
|
# Summarize model
|
||||||
|
summary(net, input_size=(3, 32, 32))
|
10
wresnet-pytorch/src/configs/WRN-16-1-scratch-CIFAR10.json
Normal file
10
wresnet-pytorch/src/configs/WRN-16-1-scratch-CIFAR10.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "CIFAR10",
|
||||||
|
"wrn_depth": 16,
|
||||||
|
"wrn_width": 1,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-16-1-scratch-SVHN.json
Normal file
10
wresnet-pytorch/src/configs/WRN-16-1-scratch-SVHN.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "SVHN",
|
||||||
|
"wrn_depth": 16,
|
||||||
|
"wrn_width": 1,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-16-2-scratch-CIFAR10.json
Normal file
10
wresnet-pytorch/src/configs/WRN-16-2-scratch-CIFAR10.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "CIFAR10",
|
||||||
|
"wrn_depth": 16,
|
||||||
|
"wrn_width": 2,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-16-2-scratch-SVHN.json
Normal file
10
wresnet-pytorch/src/configs/WRN-16-2-scratch-SVHN.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "SVHN",
|
||||||
|
"wrn_depth": 16,
|
||||||
|
"wrn_width": 2,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-40-1-scratch-CIFAR10.json
Normal file
10
wresnet-pytorch/src/configs/WRN-40-1-scratch-CIFAR10.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "CIFAR10",
|
||||||
|
"wrn_depth": 40,
|
||||||
|
"wrn_width": 1,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-40-1-scratch-SVHN.json
Normal file
10
wresnet-pytorch/src/configs/WRN-40-1-scratch-SVHN.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "SVHN",
|
||||||
|
"wrn_depth": 40,
|
||||||
|
"wrn_width": 1,
|
||||||
|
"seeds": "0",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-40-2-scratch-CIFAR10.json
Normal file
10
wresnet-pytorch/src/configs/WRN-40-2-scratch-CIFAR10.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "CIFAR10",
|
||||||
|
"wrn_depth": 40,
|
||||||
|
"wrn_width": 2,
|
||||||
|
"seeds": "012",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
10
wresnet-pytorch/src/configs/WRN-40-2-scratch-SVHN.json
Normal file
10
wresnet-pytorch/src/configs/WRN-40-2-scratch-SVHN.json
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"training":{
|
||||||
|
"dataset": "SVHN",
|
||||||
|
"wrn_depth": 40,
|
||||||
|
"wrn_width": 2,
|
||||||
|
"seeds": "012",
|
||||||
|
"checkpoint": "True",
|
||||||
|
"log": "True"
|
||||||
|
}
|
||||||
|
}
|
151
wresnet-pytorch/src/train.py
Normal file
151
wresnet-pytorch/src/train.py
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch import optim
|
||||||
|
from torch.optim.lr_scheduler import MultiStepLR
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from utils import json_file_to_pyobj, get_loaders
|
||||||
|
from WideResNet import WideResNet
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed=42):
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _train_seed(net, loaders, device, dataset, log=False, checkpoint=False, logfile='', checkpointFile=''):
|
||||||
|
|
||||||
|
train_loader, test_loader = loaders
|
||||||
|
|
||||||
|
if dataset == 'svhn':
|
||||||
|
epochs = 100
|
||||||
|
else:
|
||||||
|
epochs = 200
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
|
||||||
|
scheduler = MultiStepLR(optimizer, milestones=[int(elem*epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
|
||||||
|
|
||||||
|
best_test_set_accuracy = 0
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
|
||||||
|
net.train()
|
||||||
|
for i, data in enumerate(train_loader, 0):
|
||||||
|
inputs, labels = data
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
wrn_outputs = net(inputs)
|
||||||
|
outputs = wrn_outputs[0]
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
net.eval()
|
||||||
|
for data in test_loader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
wrn_outputs = net(images)
|
||||||
|
outputs = wrn_outputs[0]
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
epoch_accuracy = correct / total
|
||||||
|
epoch_accuracy = round(100 * epoch_accuracy, 2)
|
||||||
|
|
||||||
|
if log:
|
||||||
|
with open(logfile, 'a') as temp:
|
||||||
|
temp.write('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy))
|
||||||
|
|
||||||
|
if epoch_accuracy > best_test_set_accuracy:
|
||||||
|
best_test_set_accuracy = epoch_accuracy
|
||||||
|
if checkpoint:
|
||||||
|
torch.save(net.state_dict(), checkpointFile)
|
||||||
|
|
||||||
|
return best_test_set_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
json_options = json_file_to_pyobj(args.config)
|
||||||
|
training_configurations = json_options.training
|
||||||
|
|
||||||
|
wrn_depth = training_configurations.wrn_depth
|
||||||
|
wrn_width = training_configurations.wrn_width
|
||||||
|
dataset = training_configurations.dataset.lower()
|
||||||
|
seeds = [int(seed) for seed in training_configurations.seeds]
|
||||||
|
log = True if training_configurations.log.lower() == 'true' else False
|
||||||
|
|
||||||
|
if log:
|
||||||
|
logfile = 'WideResNet-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset)
|
||||||
|
with open(logfile, 'w') as temp:
|
||||||
|
temp.write('WideResNet-{}-{} on {}\n'.format(wrn_depth, wrn_width, training_configurations.dataset))
|
||||||
|
else:
|
||||||
|
logfile = ''
|
||||||
|
|
||||||
|
checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
|
||||||
|
loaders = get_loaders(dataset)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device('cuda:0')
|
||||||
|
else:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
|
test_set_accuracies = []
|
||||||
|
|
||||||
|
for seed in seeds:
|
||||||
|
set_seed(seed)
|
||||||
|
|
||||||
|
if log:
|
||||||
|
with open(logfile, 'a') as temp:
|
||||||
|
temp.write('------------------- SEED {} -------------------\n'.format(seed))
|
||||||
|
|
||||||
|
strides = [1, 1, 2, 2]
|
||||||
|
net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
|
||||||
|
net = net.to(device)
|
||||||
|
|
||||||
|
checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
|
||||||
|
best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile)
|
||||||
|
|
||||||
|
if log:
|
||||||
|
with open(logfile, 'a') as temp:
|
||||||
|
temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))
|
||||||
|
|
||||||
|
test_set_accuracies.append(best_test_set_accuracy)
|
||||||
|
|
||||||
|
mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)
|
||||||
|
|
||||||
|
if log:
|
||||||
|
with open(logfile, 'a') as temp:
|
||||||
|
temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='WideResNet')
|
||||||
|
|
||||||
|
parser.add_argument('-config', '--config', help='Training Configurations', required=True)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
train(args)
|
62
wresnet-pytorch/src/utils.py
Normal file
62
wresnet-pytorch/src/utils.py
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
import json
|
||||||
|
import collections
|
||||||
|
import torchvision
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
|
||||||
|
def json_file_to_pyobj(filename):
|
||||||
|
def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
|
||||||
|
|
||||||
|
def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
|
||||||
|
|
||||||
|
return json2obj(open(filename).read())
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
|
||||||
|
|
||||||
|
if dataset == 'cifar10':
|
||||||
|
|
||||||
|
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
||||||
|
|
||||||
|
train_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
|
||||||
|
(4, 4, 4, 4), mode='reflect').squeeze()),
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
transforms.RandomCrop(32),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
|
||||||
|
test_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
normalize
|
||||||
|
])
|
||||||
|
|
||||||
|
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
|
||||||
|
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
|
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
|
||||||
|
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
|
elif dataset == 'svhn':
|
||||||
|
|
||||||
|
normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
|
||||||
|
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
normalize,
|
||||||
|
])
|
||||||
|
|
||||||
|
trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
|
||||||
|
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
|
testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)
|
||||||
|
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
|
return trainloader, testloader
|
||||||
|
|
Loading…
Reference in a new issue