changed to same data loaders as train.py and added saving student model

This commit is contained in:
Ruby 2024-12-01 15:33:03 -07:00
parent 7208c16efc
commit 5be312bf18

View file

@ -1,3 +1,6 @@
from datetime import datetime
import time
from utils import json_file_to_pyobj, get_loaders from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet from WideResNet import WideResNet
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
@ -18,21 +21,8 @@ import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device): def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
# Dataset # Dataset
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
]
)
datadir = Path().home() / "opt/data/cifar"
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)
ce_loss = nn.CrossEntropyLoss() ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate) optimizer = optim.Adam(student.parameters(), lr=learning_rate)
@ -72,28 +62,14 @@ def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, sof
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}") print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
@torch.no_grad()
def test(model, device, test_dl, teacher=False):
def test(model, device, teacher=False):
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
]
)
datadir = Path().home() / "opt/data/cifar"
test_ds = CIFAR10(root=datadir, train=True, download=False, transform=transform)
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4
)
model.to(device) model.to(device)
model.eval() model.eval()
correct = 0 correct = 0
total = 0 total = 0
with torch.no_grad():
for inputs, labels in test_dl: for inputs, labels in test_dl:
inputs, labels = inputs.to(device), labels.to(device) inputs, labels = inputs.to(device), labels.to(device)
if teacher: if teacher:
@ -109,7 +85,6 @@ def test(model, device, teacher=False):
print(f"Test Accuracy: {accuracy:.2f}%") print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy return accuracy
def main(): def main():
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json") json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
training_configurations = json_options.training training_configurations = json_options.training
@ -156,13 +131,17 @@ def main():
print("Training student") print("Training student")
#train_knowledge_distillation(teacher=teacher, student=student, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device) train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
#test_student = test(student, device) print("Saving student")
test_teacher = test(teacher, device, True) current_datetime = datetime.now()
filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
torch.save(student.state_dict(), filename)
print("Testing student and teacher")
test_student = test(student, device, test_loader,)
test_teacher = test(teacher, device, test_loader, True)
print(f"Teacher accuracy: {test_teacher:.2f}%") print(f"Teacher accuracy: {test_teacher:.2f}%")
#print(f"Student accuracy: {test_student:.2f}%") print(f"Student accuracy: {test_student:.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
main() main()