Wres: clean up student a bit
This commit is contained in:
parent
5531212b11
commit
3ea041ef65
1 changed files with 3 additions and 4 deletions
|
@ -63,7 +63,7 @@ def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_ra
|
|||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def test(model, device, test_dl, teacher=False):
|
||||
def test(model, device, test_dl, is_teacher=False):
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
|
@ -72,7 +72,7 @@ def test(model, device, test_dl, teacher=False):
|
|||
|
||||
for inputs, labels in test_dl:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
if teacher:
|
||||
if is_teacher:
|
||||
outputs, _, _, _ = model(inputs)
|
||||
else:
|
||||
outputs = model(inputs)
|
||||
|
@ -82,7 +82,6 @@ def test(model, device, test_dl, teacher=False):
|
|||
correct += (predicted == labels).sum().item()
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
||||
return accuracy
|
||||
|
||||
def main():
|
||||
|
@ -158,7 +157,7 @@ def main():
|
|||
torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
|
||||
|
||||
print("Testing student and teacher")
|
||||
test_student = test(student, device, test_loader,)
|
||||
test_student = test(student, device, test_loader)
|
||||
test_teacher = test(teacher, device, test_loader, True)
|
||||
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
||||
print(f"Student accuracy: {test_student:.2f}%")
|
||||
|
|
Loading…
Reference in a new issue