added if statement to either keep the datapoint in our out

This commit is contained in:
ARVP 2024-12-04 07:45:06 -07:00
parent 3fda6dd727
commit c6f91352d9
2 changed files with 13 additions and 2 deletions

View file

@ -4,8 +4,9 @@ import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
from torch.utils.data import Subset
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
print(f"Train batch size: {train_batch_size}")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_transform = transforms.Compose([
@ -30,6 +31,16 @@ def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_v
seed = torch.Generator().manual_seed(seed_val)
subsets = random_split(trainset, [0.5, 0.5], generator=seed)
teacher_set = subsets[0]
#if removing datapoint from teacher:
if teacher_datapt_out:
teacher_indices = teacher_set.indices
size = len(teacher_set)
index_to_remove = torch.randint(0, size, (1,)).item() # Randomly select one index
keep_bool = torch.ones(size, dtype=torch.bool)
keep_bool[index_to_remove] = False
keep_indices = torch.tensor(teacher_indices)[keep_bool]
teacher_set = Subset(trainset, keep_indices.tolist())
student_set = subsets[1]
testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)

View file

@ -149,7 +149,7 @@ def train(args):
checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
#get specific teacher set
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
loaders = trainloader, testloader