added if statement to either keep the datapoint in or out
parent 3fda6dd727
commit c6f91352d9
2 changed files with 13 additions and 2 deletions
@@ -4,8 +4,9 @@ import torchvision
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 import torch.nn.functional as F
+from torch.utils.data import Subset
 
-def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42):
+def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False):
     print(f"Train batch size: {train_batch_size}")
     normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     train_transform = transforms.Compose([
@@ -30,6 +31,16 @@ def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_v
     seed = torch.Generator().manual_seed(seed_val)
     subsets = random_split(trainset, [0.5, 0.5], generator=seed)
     teacher_set = subsets[0]
+    #if removing datapoint from teacher:
+    if teacher_datapt_out:
+        teacher_indices = teacher_set.indices
+        size = len(teacher_set)
+        index_to_remove = torch.randint(0, size, (1,)).item() # Randomly select one index
+        keep_bool = torch.ones(size, dtype=torch.bool)
+        keep_bool[index_to_remove] = False
+        keep_indices = torch.tensor(teacher_indices)[keep_bool]
+        teacher_set = Subset(trainset, keep_indices.tolist())
+
     student_set = subsets[1]
     testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform)
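For reference, here is a minimal, self-contained sketch of the removal step added above, with a small TensorDataset standing in for CIFAR-10 (the dataset, sizes, and seed below are placeholders, not the repository's actual values). It splits the data in half, drops one randomly chosen index from the teacher half via a boolean mask, and rebuilds the teacher Subset against the full dataset, mirroring the new if block:

import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Placeholder dataset: 100 feature/label pairs instead of CIFAR-10.
trainset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))

seed = torch.Generator().manual_seed(42)
subsets = random_split(trainset, [0.5, 0.5], generator=seed)
teacher_set = subsets[0]

# Drop one randomly chosen datapoint from the teacher half.
teacher_indices = teacher_set.indices                  # indices into the full trainset
size = len(teacher_set)
index_to_remove = torch.randint(0, size, (1,)).item()  # random position within the subset
keep_bool = torch.ones(size, dtype=torch.bool)
keep_bool[index_to_remove] = False
keep_indices = torch.tensor(teacher_indices)[keep_bool]

# Rebuild the teacher subset against the full dataset, minus the held-out point.
teacher_set = Subset(trainset, keep_indices.tolist())
print(len(teacher_set))  # 49: one less than the original 50-point teacher split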
@@ -149,7 +149,7 @@ def train(args):
 
     checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False
     #get specific teacher set
-    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED)
+    teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True)
     trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
     testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4)
     loaders = trainloader, testloader
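A quick way to sanity-check the new flag (not part of this commit, and assuming get_teacherstudent_trainset is importable and CIFAR-10 is already present at the path it expects) is to build the teacher set with the flag off and on under the same seed and compare the index lists; SEED below is a stand-in for the training script's own seed:

SEED = 42  # stand-in for the script's seed

full_teacher, _, _ = get_teacherstudent_trainset(128, 10, SEED, False)
held_out_teacher, _, _ = get_teacherstudent_trainset(128, 10, SEED, True)

# Both calls produce the same seeded split, so exactly one index should disappear.
assert len(full_teacher) - len(held_out_teacher) == 1
missing = set(full_teacher.indices) - set(held_out_teacher.indices)
print(f"held-out trainset index: {missing}")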