changed to same data loaders as train.py and added saving student model
This commit is contained in:
parent
7208c16efc
commit
5be312bf18
1 changed files with 24 additions and 45 deletions
|
@ -1,3 +1,6 @@
|
|||
from datetime import datetime
|
||||
import time
|
||||
|
||||
from utils import json_file_to_pyobj, get_loaders
|
||||
from WideResNet import WideResNet
|
||||
from opacus.validators import ModuleValidator
|
||||
|
@ -18,21 +21,8 @@ import warnings
|
|||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
|
||||
def train_knowledge_distillation(teacher, student, train_dl, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
|
||||
# Dataset
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
|
||||
]
|
||||
)
|
||||
datadir = Path().home() / "opt/data/cifar"
|
||||
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
|
||||
train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)
|
||||
|
||||
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||
|
||||
|
@ -72,44 +62,29 @@ def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, sof
|
|||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||
|
||||
|
||||
|
||||
def test(model, device, teacher=False):
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
|
||||
]
|
||||
)
|
||||
datadir = Path().home() / "opt/data/cifar"
|
||||
test_ds = CIFAR10(root=datadir, train=True, download=False, transform=transform)
|
||||
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test(model, device, test_dl, teacher=False):
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, labels in test_dl:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
if teacher:
|
||||
outputs, _, _, _ = model(inputs)
|
||||
else:
|
||||
outputs = model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
for inputs, labels in test_dl:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
if teacher:
|
||||
outputs, _, _, _ = model(inputs)
|
||||
else:
|
||||
outputs = model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
||||
return accuracy
|
||||
|
||||
|
||||
def main():
|
||||
json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
|
||||
training_configurations = json_options.training
|
||||
|
@ -156,13 +131,17 @@ def main():
|
|||
|
||||
|
||||
print("Training student")
|
||||
#train_knowledge_distillation(teacher=teacher, student=student, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||
#test_student = test(student, device)
|
||||
test_teacher = test(teacher, device, True)
|
||||
train_knowledge_distillation(teacher=teacher, student=student, train_dl=train_loader, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||
print("Saving student")
|
||||
current_datetime = datetime.now()
|
||||
filename = f"students/studentmodel{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
|
||||
torch.save(student.state_dict(), filename)
|
||||
print("Testing student and teacher")
|
||||
test_student = test(student, device, test_loader,)
|
||||
test_teacher = test(teacher, device, test_loader, True)
|
||||
print(f"Teacher accuracy: {test_teacher:.2f}%")
|
||||
#print(f"Student accuracy: {test_student:.2f}%")
|
||||
print(f"Student accuracy: {test_student:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
Loading…
Reference in a new issue