From ad666283c52682ad610d7620810c1d70b221cdbc Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Mon, 2 Dec 2024 18:45:04 -0700 Subject: [PATCH] Wres: clean up student a bit --- wresnet-pytorch/src/distillation_train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/wresnet-pytorch/src/distillation_train.py b/wresnet-pytorch/src/distillation_train.py index 18b28b0..0ea1daf 100644 --- a/wresnet-pytorch/src/distillation_train.py +++ b/wresnet-pytorch/src/distillation_train.py @@ -63,7 +63,7 @@ def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_ra print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}") @torch.no_grad() -def test(model, device, test_dl, teacher=False): +def test(model, device, test_dl, is_teacher=False): model.to(device) model.eval() @@ -72,7 +72,7 @@ def test(model, device, test_dl, teacher=False): for inputs, labels in test_dl: inputs, labels = inputs.to(device), labels.to(device) - if teacher: + if is_teacher: outputs, _, _, _ = model(inputs) else: outputs = model(inputs) @@ -82,7 +82,6 @@ def test(model, device, test_dl, teacher=False): correct += (predicted == labels).sum().item() accuracy = 100 * correct / total - print(f"Test Accuracy: {accuracy:.2f}%") return accuracy def main(): @@ -158,7 +157,7 @@ def main(): torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt") print("Testing student and teacher") - test_student = test(student, device, test_loader,) + test_student = test(student, device, test_loader) test_teacher = test(teacher, device, test_loader, True) print(f"Teacher accuracy: {test_teacher:.2f}%") print(f"Student accuracy: {test_student:.2f}%")