From c6f91352d9bf50a4a7b4d0da1121073019173aca Mon Sep 17 00:00:00 2001 From: ARVP Date: Wed, 4 Dec 2024 07:45:06 -0700 Subject: [PATCH] added if statement to either keep the datapoint in our out --- wresnet-pytorch/src/distillation_utils.py | 13 ++++++++++++- wresnet-pytorch/src/train.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/wresnet-pytorch/src/distillation_utils.py b/wresnet-pytorch/src/distillation_utils.py index 302a34a..c761232 100644 --- a/wresnet-pytorch/src/distillation_utils.py +++ b/wresnet-pytorch/src/distillation_utils.py @@ -4,8 +4,9 @@ import torchvision from torchvision import transforms from torchvision.datasets import CIFAR10 import torch.nn.functional as F +from torch.utils.data import Subset -def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42): +def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_val=42, teacher_datapt_out=False): print(f"Train batch size: {train_batch_size}") normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) train_transform = transforms.Compose([ @@ -30,6 +31,16 @@ def get_teacherstudent_trainset(train_batch_size=128, test_batch_size=10, seed_v seed = torch.Generator().manual_seed(seed_val) subsets = random_split(trainset, [0.5, 0.5], generator=seed) teacher_set = subsets[0] + #if removing datapoint from teacher: + if teacher_datapt_out: + teacher_indices = teacher_set.indices + size = len(teacher_set) + index_to_remove = torch.randint(0, size, (1,)).item() # Randomly select one index + keep_bool = torch.ones(size, dtype=torch.bool) + keep_bool[index_to_remove] = False + keep_indices = torch.tensor(teacher_indices)[keep_bool] + teacher_set = Subset(trainset, keep_indices.tolist()) + student_set = subsets[1] testset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=False, transform=test_transform) diff --git a/wresnet-pytorch/src/train.py b/wresnet-pytorch/src/train.py index 4a60f53..5746a7b 100644 --- a/wresnet-pytorch/src/train.py +++ b/wresnet-pytorch/src/train.py @@ -149,7 +149,7 @@ def train(args): checkpoint = True if training_configurations.checkpoint.lower() == 'true' else False #get specific teacher set - teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED) + teacherset, studentset, testset = get_teacherstudent_trainset(training_configurations.batch_size, 10, SEED, True) trainloader = DataLoader(teacherset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4) testloader = DataLoader(testset, batch_size=training_configurations.batch_size, shuffle=True, num_workers=4) loaders = trainloader, testloader