changed to same data loaders as train.py and added saving student model
This commit is contained in:
parent
7208c16efc
commit
5be312bf18
1 changed files with 24 additions and 45 deletions
|
@ -1,3 +1,6 @@
|
||||||
|
from datetime import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
from utils import json_file_to_pyobj, get_loaders
|
from utils import json_file_to_pyobj, get_loaders
|
||||||
from WideResNet import WideResNet
|
from WideResNet import WideResNet
|
||||||
from opacus.validators import ModuleValidator
|
from opacus.validators import ModuleValidator
|
||||||
|
@ -18,21 +21,8 @@ import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
|
def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
|
||||||
# Dataset
|
# Dataset
|
||||||
transform = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.RandomCrop(32, padding=4),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
datadir = Path().home() / "opt/data/cifar"
|
|
||||||
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
|
|
||||||
train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
ce_loss = nn.CrossEntropyLoss()
|
ce_loss = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
|
@ -72,28 +62,14 @@ def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, sof
|
||||||
|
|
||||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test(model, device, test_dl, teacher=False):
|
||||||
def test(model, device, teacher=False):
|
|
||||||
transform = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.RandomCrop(32, padding=4),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
datadir = Path().home() / "opt/data/cifar"
|
|
||||||
test_ds = CIFAR10(root=datadir, train=True, download=False, transform=transform)
|
|
||||||
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4
|
|
||||||
)
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for inputs, labels in test_dl:
|
for inputs, labels in test_dl:
|
||||||
inputs, labels = inputs.to(device), labels.to(device)
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
if teacher:
|
if teacher:
|
||||||
|
@ -109,7 +85,6 @@ def test(model, device, teacher=False):
|
||||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
print(f"Test Accuracy: {accuracy:.2f}%")
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
|
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
|
||||||
training_configurations = json_options.training
|
training_configurations = json_options.training
|
||||||
|
@ -156,13 +131,17 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
print("Training student")
|
print("Training student")
|
||||||
#train_knowledge_distillation(teacher=teacher, student=student, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||||
#test_student = test(student, device)
|
print("Saving student")
|
||||||
test_teacher = test(teacher, device, True)
|
current_datetime = datetime.now()
|
||||||
|
filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
|
||||||
|
torch.save(student.state_dict(), filename)
|
||||||
|
print("Testing student and teacher")
|
||||||
|
test_student = test(student, device, test_loader,)
|
||||||
|
test_teacher = test(teacher, device, test_loader, True)
|
||||||
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
||||||
#print(f"Student accuracy: {test_student:.2f}%")
|
print(f"Student accuracy: {test_student:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue