Wres: add distillation code
parent 424cb01a15
commit 0eb26f8979
2 changed files with 199 additions and 0 deletions
wresnet-pytorch/src/distillation_train.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet
from opacus.validators import ModuleValidator
import os
from pathlib import Path
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torchvision import models, transforms
import student_model
import torch.optim as optim
import torch.nn.functional as F
import opacus
import warnings

warnings.filterwarnings("ignore")


def train_knowledge_distillation(teacher, student, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    # CIFAR-10 training data with standard augmentation and per-channel normalization
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # teacher stays in evaluation mode
    student.train()  # student is trained

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher. No gradients are needed since the teacher's
            # weights are frozen; the WideResNet forward returns a tuple, of which only
            # the logits are used.
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Temperature-softened distributions: teacher probabilities and student log-probabilities
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Soft-target loss: batch-mean KL divergence between teacher and student,
            # scaled by T**2 as suggested in "Distilling the Knowledge in a Neural Network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Hard-label cross-entropy loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")


def test(model, device, teacher=False):
    # Deterministic evaluation transform: no augmentation, only tensor conversion and normalization
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    test_ds = CIFAR10(root=datadir, train=False, download=False, transform=transform)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            # The teacher (WideResNet) returns a tuple; the student returns plain logits
            if teacher:
                outputs, _, _, _ = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


def main():
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    epochs = 10

    print("Load the teacher model")
    # Instantiate the teacher and recreate the optimizer/data loader it was trained with,
    # so it can be wrapped by Opacus before the checkpoint is loaded (the saved state dict
    # presumably comes from the privacy-engine-wrapped model).
    strides = [1, 1, 2, 2]
    teacher = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
    teacher = ModuleValidator.fix(teacher)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem * epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    train_loader, test_loader = get_loaders(dataset, training_configurations.batch_size)
    best_test_set_accuracy = 0

    dp_epsilon = 8
    dp_delta = 1e-5
    norm = 1.0
    privacy_engine = opacus.PrivacyEngine()
    teacher, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=teacher,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=epochs,
        target_epsilon=dp_epsilon,
        target_delta=dp_delta,
        max_grad_norm=norm,
    )

    teacher.load_state_dict(torch.load(os.path.join("wrn-1733078278-8e-1e-05d-12.0n-dict.pt"), weights_only=True))
    teacher.to(device)
    teacher.eval()

    # Instantiate the student
    student = student_model.Model(num_classes=10).to(device)

    print("Training student")
    # Student distillation is currently left disabled; only the DP teacher is evaluated.
    # train_knowledge_distillation(teacher=teacher, student=student, epochs=100, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
    # test_student = test(student, device)
    test_teacher = test(teacher, device, True)
    print(f"Teacher accuracy: {test_teacher:.2f}%")
    # print(f"Student accuracy: {test_student:.2f}%")


if __name__ == "__main__":
    main()
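For reference: the hand-written soft-target term above is the batch-mean KL divergence between the temperature-softened teacher and student distributions, scaled by T**2. A minimal standalone sketch (random logits, independent of this repo) showing that it matches torch.nn.functional.kl_div:

import torch
import torch.nn.functional as F

T = 2.0
teacher_logits = torch.randn(128, 10)
student_logits = torch.randn(128, 10)

soft_targets = F.softmax(teacher_logits / T, dim=-1)
soft_prob = F.log_softmax(student_logits / T, dim=-1)

# Formulation used in distillation_train.py: per-batch mean of
# sum_c p_teacher(c) * (log p_teacher(c) - log p_student(c)), scaled by T**2
manual = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size(0) * (T ** 2)

# Equivalent built-in: KL divergence with "batchmean" reduction, scaled by T**2
builtin = F.kl_div(soft_prob, soft_targets, reduction="batchmean") * (T ** 2)

assert torch.allclose(manual, builtin, atol=1e-6)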
wresnet-pytorch/src/student_model.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn


# Lightweight CNN student model: two conv/pool blocks followed by a small MLP classifier.
# The forward pass returns only the logits (no intermediate features).
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            # 1024 = 16 channels * 8 * 8 spatial positions for 32x32 inputs
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x


Model = ModifiedLightNNCosine
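A quick shape sanity check for the student, since the first Linear layer hard-codes 1024 input features (assuming the file is importable as student_model, as it is in distillation_train.py):

import torch
import student_model

x = torch.randn(4, 3, 32, 32)                  # a CIFAR-10-sized batch of 4 images
student = student_model.Model(num_classes=10)
logits = student(x)
print(logits.shape)                            # expected: torch.Size([4, 10])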