Wres: clean up student a bit

This commit is contained in:
Akemi Izuko 2024-12-02 18:45:04 -07:00
parent 5531212b11
commit 3ea041ef65
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -63,7 +63,7 @@ def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_ra
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
@torch.no_grad()
def test(model, device, test_dl, teacher=False):
def test(model, device, test_dl, is_teacher=False):
model.to(device)
model.eval()
@ -72,7 +72,7 @@ def test(model, device, test_dl, teacher=False):
for inputs, labels in test_dl:
inputs, labels = inputs.to(device), labels.to(device)
if teacher:
if is_teacher:
outputs, _, _, _ = model(inputs)
else:
outputs = model(inputs)
@ -82,7 +82,6 @@ def test(model, device, test_dl, teacher=False):
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
def main():
@ -158,7 +157,7 @@ def main():
torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
print("Testing student and teacher")
test_student = test(student, device, test_loader,)
test_student = test(student, device, test_loader)
test_teacher = test(teacher, device, test_loader, True)
print(f"Teacher accuracy: {test_teacher:.2f}%")
print(f"Student accuracy: {test_student:.2f}%")