Wres: clean up student a bit
This commit is contained in:
parent
524576db31
commit
ad666283c5
1 changed files with 3 additions and 4 deletions
|
@ -63,7 +63,7 @@ def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_ra
|
||||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test(model, device, test_dl, teacher=False):
|
def test(model, device, test_dl, is_teacher=False):
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ def test(model, device, test_dl, teacher=False):
|
||||||
|
|
||||||
for inputs, labels in test_dl:
|
for inputs, labels in test_dl:
|
||||||
inputs, labels = inputs.to(device), labels.to(device)
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
if teacher:
|
if is_teacher:
|
||||||
outputs, _, _, _ = model(inputs)
|
outputs, _, _, _ = model(inputs)
|
||||||
else:
|
else:
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
|
@ -82,7 +82,6 @@ def test(model, device, test_dl, teacher=False):
|
||||||
correct += (predicted == labels).sum().item()
|
correct += (predicted == labels).sum().item()
|
||||||
|
|
||||||
accuracy = 100 * correct / total
|
accuracy = 100 * correct / total
|
||||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -158,7 +157,7 @@ def main():
|
||||||
torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
|
torch.save(student.state_dict(), f"students/studentmodel-{int(time.time())}.pt")
|
||||||
|
|
||||||
print("Testing student and teacher")
|
print("Testing student and teacher")
|
||||||
test_student = test(student, device, test_loader,)
|
test_student = test(student, device, test_loader)
|
||||||
test_teacher = test(teacher, device, test_loader, True)
|
test_teacher = test(teacher, device, test_loader, True)
|
||||||
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
||||||
print(f"Student accuracy: {test_student:.2f}%")
|
print(f"Student accuracy: {test_student:.2f}%")
|
||||||
|
|
Loading…
Reference in a new issue