diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index e3ceee5..1083156 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -545,7 +545,7 @@ def main(): model_init, model_trained = train_knowledge_distillation( teacher=teacher_trained, train_dl=train_dl, - epochs=100, + epochs=hp['epochs'], device=DEVICE, learning_rate=0.001, T=2,