O1: update student to use pure train dataloader

This commit is contained in:
Akemi Izuko 2024-12-04 23:36:36 -07:00
parent 1200907c31
commit 3c07d6fe6b

View file

@ -121,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
return train_dl, test_dl, x_in, x_m, y_m, S_m
return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
def evaluate_on(model, dataloader):
@ -297,7 +298,7 @@ def main():
hp['norm'],
))
train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
print(f"len train: {len(train_dl)}")
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
print(f"Got x_in: {x_in.shape}")
@ -319,7 +320,7 @@ def main():
print("Training Student Model")
model_init, model_trained = train_knowledge_distillation(
teacher=teacher_trained,
train_dl=train_dl,
train_dl=pure_train_dl,
epochs=100,
device=DEVICE,
learning_rate=0.001,