O1: update student to use pure train dataloader
This commit is contained in:
parent
1200907c31
commit
3c07d6fe6b
1 changed files with 4 additions and 3 deletions
|
@ -121,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
|
||||
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
|
||||
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||
|
||||
return train_dl, test_dl, x_in, x_m, y_m, S_m
|
||||
return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
|
||||
|
||||
|
||||
def evaluate_on(model, dataloader):
|
||||
|
@ -297,7 +298,7 @@ def main():
|
|||
hp['norm'],
|
||||
))
|
||||
|
||||
train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||
train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||
print(f"len train: {len(train_dl)}")
|
||||
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
|
||||
print(f"Got x_in: {x_in.shape}")
|
||||
|
@ -319,7 +320,7 @@ def main():
|
|||
print("Training Student Model")
|
||||
model_init, model_trained = train_knowledge_distillation(
|
||||
teacher=teacher_trained,
|
||||
train_dl=train_dl,
|
||||
train_dl=pure_train_dl,
|
||||
epochs=100,
|
||||
device=DEVICE,
|
||||
learning_rate=0.001,
|
||||
|
|
Loading…
Reference in a new issue