From 3c07d6fe6b324da8bc5494dda50f02b7e880507e Mon Sep 17 00:00:00 2001 From: akemi Date: Wed, 4 Dec 2024 23:36:36 -0700 Subject: [PATCH] O1: update student to use pure train dataloader --- one_run_audit/audit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index d25de1a..5ae3ca0 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -121,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long()) train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4) + pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4) test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4) - return train_dl, test_dl, x_in, x_m, y_m, S_m + return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m def evaluate_on(model, dataloader): @@ -297,7 +298,7 @@ def main(): hp['norm'], )) - train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size']) + train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size']) print(f"len train: {len(train_dl)}") print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}") print(f"Got x_in: {x_in.shape}") @@ -319,7 +320,7 @@ def main(): print("Training Student Model") model_init, model_trained = train_knowledge_distillation( teacher=teacher_trained, - train_dl=train_dl, + train_dl=pure_train_dl, epochs=100, device=DEVICE, learning_rate=0.001,