Added an if statement to either keep the datapoint in or out
This commit is contained in:
parent
3fda6dd727
commit
c6f91352d9
2 changed files with 13 additions and 2 deletions
|
@ -4,8 +4,9 @@ import torchvision
|
|||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Subset
|
||||
|
||||
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
|
||||
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
|
||||
print(f"Train batch size: {train_batch_size}")
|
||||
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
||||
train_transform = transforms.Compose([
|
||||
|
@ -30,6 +31,16 @@ def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_v
|
|||
seed = torch.Generator().manual_seed(seed_val)
|
||||
subsets = random_split(trainset, [0.5, 0.5], generator=seed)
|
||||
teacher_set = subsets[0]
|
||||
#if removing datapoint from teacher:
|
||||
if teacher_datapt_out:
|
||||
teacher_indices = teacher_set.indices
|
||||
size = len(teacher_set)
|
||||
index_to_remove = torch.randint(0, size, (1,)).item() # Randomly select one index
|
||||
keep_bool = torch.ones(size, dtype=torch.bool)
|
||||
keep_bool[index_to_remove] = False
|
||||
keep_indices = torch.tensor(teacher_indices)[keep_bool]
|
||||
teacher_set = Subset(trainset, keep_indices.tolist())
|
||||
|
||||
student_set = subsets[1]
|
||||
testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
|
||||
|
||||
|
|
|
@ -149,7 +149,7 @@ def train(args):
|
|||
|
||||
checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
|
||||
#get specific teacher set
|
||||
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
|
||||
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
|
||||
trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
|
||||
testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
|
||||
loaders = trainloader, testloader
|
||||
|
|
Loading…
Reference in a new issue