63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
|
import json
|
||
|
import collections
|
||
|
import torchvision
|
||
|
from torchvision import transforms
|
||
|
from torch.utils.data import DataLoader
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
# Borrowed from https://github.com/ozan-oktay/Attention-Gated-Networks
|
||
|
def json_file_to_pyobj(filename):
    """Load a JSON file and return its contents as nested namedtuples.

    Every JSON object in the file is converted into a
    ``collections.namedtuple`` (type name ``X``) whose fields are the
    object's keys, so values can be read with attribute access
    (``cfg.model.depth``) instead of subscripting. JSON arrays and scalars
    keep the Python types the ``json`` module normally produces.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded document. When the top-level JSON value is an object,
        this is a namedtuple; nested objects are namedtuples as well.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON, or if an object key is
            not usable as a namedtuple field name.
    """
    def _json_object_hook(d):
        # The decoder calls this once per JSON object with the already
        # decoded dict; keys become field names, values keep their order.
        return collections.namedtuple('X', d.keys())(*d.values())

    # Use a context manager so the file handle is closed deterministically,
    # even when decoding raises.
    with open(filename) as json_file:
        return json.load(json_file, object_hook=_json_object_hook)
|
||
|
|
||
|
|
||
|
def _reflect_pad_4(img):
    """Reflect-pad an image tensor by 4 pixels on each side of its last two dims.

    Defined at module level (instead of as an inline lambda) so that the
    transform pipeline can be pickled when DataLoader worker processes are
    started with the ``spawn`` method.

    NOTE(review): assumes ``img`` is the (C, H, W) tensor produced by
    ``transforms.ToTensor()`` - confirm if reused elsewhere.
    """
    # A batch dimension is added for the padding call and removed afterwards.
    return F.pad(img.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze(0)


def get_loaders(dataset, train_batch_size=128, test_batch_size=10):
    """Build the train and test ``DataLoader`` objects for a named dataset.

    The data is downloaded into ``./data`` if it is not already there.

    * ``'cifar10'``: training images are reflect-padded by 4 pixels, randomly
      cropped back to 32x32, randomly flipped horizontally and normalized;
      test images are only normalized.
    * ``'svhn'``: train and test images are only normalized.

    Args:
        dataset: Dataset name, either ``'cifar10'`` or ``'svhn'``.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.

    Returns:
        A ``(trainloader, testloader)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast with a clear message; previously an unknown name fell through
    # both branches and crashed at ``return`` with an UnboundLocalError.
    if dataset not in ('cifar10', 'svhn'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'cifar10' or 'svhn'."
        )

    print(f"Train batch size: {train_batch_size}")

    if dataset == 'cifar10':
        # Per-channel (mean, std) passed to Normalize; presumably the
        # CIFAR-10 training-set statistics.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        # Augmentation: pad on the tensor, convert back to a PIL image so the
        # random crop / flip operate on it, then convert to a tensor again.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(_reflect_pad_4),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=test_transform)
    else:  # dataset == 'svhn'
        # Per-channel (mean, std) passed to Normalize; presumably the SVHN
        # training-set statistics.
        normalize = transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))

        # No augmentation for SVHN: the same transform serves both splits.
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        trainset = torchvision.datasets.SVHN(
            root='./data', split='train', download=True, transform=transform)
        testset = torchvision.datasets.SVHN(
            root='./data', split='test', download=True, transform=transform)

    trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
    # NOTE(review): the test loader is shuffled as well; kept as in the
    # original - confirm whether any caller relies on a fixed order.
    testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=True, num_workers=4)

    return trainloader, testloader
|
||
|
|
||
|
|