import torch
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F


def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
    """Build the teacher/student halves of the CIFAR-10 train set plus the test set.

    The CIFAR-10 training split is loaded from ``~/data`` (it must already be
    present there; nothing is downloaded) and divided into two equal random
    halves using a generator seeded with ``seed_val``, so the same seed
    yields the same partition. Both halves share the augmenting training
    transform; the test set only gets tensor conversion and normalization.

    Args:
        train_batch_size: Only echoed to stdout; no DataLoader is built here.
        test_batch_size: Currently unused; kept for interface compatibility.
        seed_val: Seed for the generator that drives the 50/50 split.

    Returns:
        A tuple ``(teacher_set, student_set, testset)`` where the first two
        are the subsets produced by ``random_split`` and the last is the
        CIFAR-10 test dataset.
    """
    print(f"Train batch size: {train_batch_size}")

    # Per-channel mean / std handed to Normalize for both train and test data.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010),
    )

    train_transform = transforms.Compose([
        # RandomCrop does the 4-pixel reflect padding itself, so no Lambda is
        # needed. Keeping the pipeline free of lambdas keeps the dataset
        # picklable, which multi-process DataLoader workers require on
        # spawn-based platforms.
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='~/data', train=True, download=False, transform=train_transform
    )

    # Split the data in half for the teacher and student datasets; a manually
    # seeded generator keeps the partition consistent between calls.
    split_generator = torch.Generator().manual_seed(seed_val)
    teacher_set, student_set = random_split(
        trainset, [0.5, 0.5], generator=split_generator
    )

    testset = torchvision.datasets.CIFAR10(
        root='~/data', train=False, download=False, transform=test_transform
    )

    return teacher_set, student_set, testset