36 lines
1.4 KiB
Python
36 lines
1.4 KiB
Python
import torch
|
|
from torch.utils.data import random_split
|
|
import torchvision
|
|
from torchvision import transforms
|
|
from torchvision.datasets import CIFAR10
|
|
import torch.nn.functional as F
|
|
|
|
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42,
                                *, root='~/data', download=False):
    """Build the CIFAR-10 teacher/student training subsets and the test set.

    The CIFAR-10 training split is divided into two equal random halves
    (one for the teacher, one for the student) using a generator seeded with
    ``seed_val``, so the same seed always yields the same partition.

    Both halves are views over the same underlying training dataset and
    therefore share the same augmentation pipeline (reflect-pad by 4 pixels,
    random 32x32 crop, random horizontal flip, normalisation). The test set
    is only converted to a tensor and normalised.

    Args:
        train_batch_size: Only printed for logging; this function returns
            datasets, not data loaders, so no batching happens here.
        test_batch_size: Currently unused; kept so existing callers that
            pass it keep working.
        seed_val: Seed for the generator that drives the teacher/student
            split.
        root: Directory that holds (or will receive) the CIFAR-10 files.
        download: If True, let torchvision download the dataset into
            ``root`` when it is missing. With the default False the dataset
            must already be present.

    Returns:
        A ``(teacher_set, student_set, testset)`` tuple: two subsets of the
        training data and the full CIFAR-10 test dataset.
    """
    print(f"Train batch size: {train_batch_size}")

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # RandomCrop's built-in reflect padding replaces the former
    # ToTensor -> Lambda(F.pad) -> ToPILImage detour. It is the same
    # augmentation, but without a lambda in the pipeline the transform can be
    # pickled, which DataLoader workers need under the "spawn" start method.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=download, transform=train_transform)

    # Split the training data in half for the teacher and student datasets;
    # a manually seeded generator keeps the partition reproducible.
    split_generator = torch.Generator().manual_seed(seed_val)
    teacher_set, student_set = random_split(
        trainset, [0.5, 0.5], generator=split_generator)

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=download, transform=test_transform)

    return teacher_set, student_set, testset
|