O1: update student to use pure train dataloader
This commit is contained in:
parent
1200907c31
commit
3c07d6fe6b
1 changed files with 4 additions and 3 deletions
|
@ -121,9 +121,10 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||||
|
|
||||||
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
|
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
|
||||||
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
|
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
return train_dl, test_dl, x_in, x_m, y_m, S_m
|
return train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m
|
||||||
|
|
||||||
|
|
||||||
def evaluate_on(model, dataloader):
|
def evaluate_on(model, dataloader):
|
||||||
|
@ -297,7 +298,7 @@ def main():
|
||||||
hp['norm'],
|
hp['norm'],
|
||||||
))
|
))
|
||||||
|
|
||||||
train_dl, test_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
|
train_dl, test_dl, pure_train_dl, x_in, x_m, y_m, S_m = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||||
print(f"len train: {len(train_dl)}")
|
print(f"len train: {len(train_dl)}")
|
||||||
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
|
print(f"Got vector Sm: {S_m.shape}, sum={np.sum(S_m)}")
|
||||||
print(f"Got x_in: {x_in.shape}")
|
print(f"Got x_in: {x_in.shape}")
|
||||||
|
@ -319,7 +320,7 @@ def main():
|
||||||
print("Training Student Model")
|
print("Training Student Model")
|
||||||
model_init, model_trained = train_knowledge_distillation(
|
model_init, model_trained = train_knowledge_distillation(
|
||||||
teacher=teacher_trained,
|
teacher=teacher_trained,
|
||||||
train_dl=train_dl,
|
train_dl=pure_train_dl,
|
||||||
epochs=100,
|
epochs=100,
|
||||||
device=DEVICE,
|
device=DEVICE,
|
||||||
learning_rate=0.001,
|
learning_rate=0.001,
|
||||||
|
|
Loading…
Reference in a new issue