From 9196532b6afd3b9a0b99619ce56f60789af713a6 Mon Sep 17 00:00:00 2001
From: ARVP
Date: Sat, 23 Nov 2024 17:26:51 -0700
Subject: [PATCH] Distill: add student model

---
 .gitignore                          |   3 +
 cifar10-fast-simple/distillation.py | 495 ++++++++++++++++++++++++++++
 cifar10-fast-simple/studentmodel.py |  36 ++
 cifar10-fast-simple/train.py        | 106 +++++-
 4 files changed, 636 insertions(+), 4 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 cifar10-fast-simple/distillation.py
 create mode 100644 cifar10-fast-simple/studentmodel.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..12ec3df
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+cifar10-fast-simple/teachermodels
+cifar10-fast-simple/data
+distillation/distilled_models
diff --git a/cifar10-fast-simple/distillation.py b/cifar10-fast-simple/distillation.py
new file mode 100644
index 0000000..6eb97f4
--- /dev/null
+++ b/cifar10-fast-simple/distillation.py
@@ -0,0 +1,495 @@
+# Based on https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html#prerequisites
+import torchvision
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+from torch.utils.data import Subset
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+transforms_cifar = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# Loading the CIFAR-10 dataset (train split for training, test split for evaluation):
+train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
+test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
+
+# Dataloaders
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
+
+
+# Deeper neural network class to be used as teacher:
+class DeepNN(nn.Module):
+    def __init__(self, num_classes=10):
+        super(DeepNN, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 128, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(2048, 512),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(512, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+# Lightweight convolutional neural network class to be used as student:
+class LightNN(nn.Module):
+    def __init__(self, num_classes=10):
+        super(LightNN, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(16, 16, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(1024, 256),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+def train(model, train_loader, epochs, learning_rate, device):
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    model.train()
+
+    for
epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_loader: + # inputs: A collection of batch_size images + # labels: A vector of dimensionality batch_size with integers denoting class of each image + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + + # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes + # labels: The actual labels of the images. Vector of dimensionality batch_size + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}") + +def test(model, test_loader, device): + model.to(device) + model.eval() + + correct = 0 + total = 0 + + with torch.no_grad(): + for inputs, labels in test_loader: + inputs, labels = inputs.to(device), labels.to(device) + + outputs = model(inputs) + _, predicted = torch.max(outputs.data, 1) + + total += labels.size(0) + correct += (predicted == labels).sum().item() + + accuracy = 100 * correct / total + print(f"Test Accuracy: {accuracy:.2f}%") + return accuracy + +#train teacher model +torch.manual_seed(42) +nn_deep = DeepNN(num_classes=10).to(device) +train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device) +test_accuracy_deep = test(nn_deep, test_loader, device) + +# Instantiate the lightweight network: +torch.manual_seed(42) +nn_light = LightNN(num_classes=10).to(device) + + +torch.manual_seed(42) +new_nn_light = LightNN(num_classes=10).to(device) + +# Print the norm of the first layer of the initial lightweight model +print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item()) +# Print the norm of the first layer of the new lightweight model +print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item()) + + +total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters())) +print(f"DeepNN parameters: {total_params_deep}") +total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters())) +print(f"LightNN parameters: {total_params_light}") + +train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device) +test_accuracy_light_ce = test(nn_light, test_loader, device) + +print(f"Teacher accuracy: {test_accuracy_deep:.2f}%") +print(f"Student accuracy: {test_accuracy_light_ce:.2f}%") + + +def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device): + ce_loss = nn.CrossEntropyLoss() + optimizer = optim.Adam(student.parameters(), lr=learning_rate) + + teacher.eval() # Teacher set to evaluation mode + student.train() # Student to train mode + + for epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights + with torch.no_grad(): + teacher_logits = teacher(inputs) + + # Forward pass with the student model + student_logits = student(inputs) + + #Soften the student logits by applying softmax first and log() second + soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1) + soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1) + + # Calculate the soft targets loss. 
Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network" + soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2) + + # Calculate the true label loss + label_loss = ce_loss(student_logits, labels) + + # Weighted sum of the two losses + loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss + + loss.backward() + optimizer.step() + + running_loss += loss.item() + + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}") + +train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device) +test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device) + +# Compare the student test accuracy with and without the teacher, after distillation +print(f"Teacher accuracy: {test_accuracy_deep:.2f}%") +print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%") +print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%") + + +class ModifiedDeepNNCosine(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedDeepNNCosine, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Sequential( + nn.Linear(2048, 512), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(512, num_classes) + ) + + def forward(self, x): + x = self.features(x) + flattened_conv_output = torch.flatten(x, 1) + x = self.classifier(flattened_conv_output) + flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2) + return x, flattened_conv_output_after_pooling + +# Create a similar student class where we return a tuple. We do not apply pooling after flattening. +class ModifiedLightNNCosine(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedLightNNCosine, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(16, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Sequential( + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(256, num_classes) + ) + + def forward(self, x): + x = self.features(x) + flattened_conv_output = torch.flatten(x, 1) + x = self.classifier(flattened_conv_output) + return x, flattened_conv_output + +# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance +modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device) +modified_nn_deep.load_state_dict(nn_deep.state_dict()) + +# Once again ensure the norm of the first layer is the same for both networks +print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item()) +print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item()) + +# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization. 
+torch.manual_seed(42) +modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device) +print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item()) + +# Create a sample input tensor +sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Filters: 3, Image size: 32x32 + +# Pass the input through the student +logits, hidden_representation = modified_nn_light(sample_input) + +# Print the shapes of the tensors +print("Student logits shape:", logits.shape) # batch_size x total_classes +print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size + +# Pass the input through the teacher +logits, hidden_representation = modified_nn_deep(sample_input) + +# Print the shapes of the tensors +print("Teacher logits shape:", logits.shape) # batch_size x total_classes +print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size + +def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device): + ce_loss = nn.CrossEntropyLoss() + cosine_loss = nn.CosineEmbeddingLoss() + optimizer = optim.Adam(student.parameters(), lr=learning_rate) + + teacher.to(device) + student.to(device) + teacher.eval() # Teacher set to evaluation mode + student.train() # Student to train mode + + for epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + # Forward pass with the teacher model and keep only the hidden representation + with torch.no_grad(): + _, teacher_hidden_representation = teacher(inputs) + + # Forward pass with the student model + student_logits, student_hidden_representation = student(inputs) + + # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase. 
+ hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device)) + + # Calculate the true label loss + label_loss = ce_loss(student_logits, labels) + + # Weighted sum of the two losses + loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss + + loss.backward() + optimizer.step() + + running_loss += loss.item() + + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}") + +def test_multiple_outputs(model, test_loader, device): + model.to(device) + model.eval() + + correct = 0 + total = 0 + + with torch.no_grad(): + for inputs, labels in test_loader: + inputs, labels = inputs.to(device), labels.to(device) + + outputs, _ = model(inputs) # Disregard the second tensor of the tuple + _, predicted = torch.max(outputs.data, 1) + + total += labels.size(0) + correct += (predicted == labels).sum().item() + + accuracy = 100 * correct / total + print(f"Test Accuracy: {accuracy:.2f}%") + return accuracy + +# Train and test the lightweight network with cross entropy loss +train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device) +test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device) + + +# Pass the sample input only from the convolutional feature extractor +convolutional_fe_output_student = nn_light.features(sample_input) +convolutional_fe_output_teacher = nn_deep.features(sample_input) + +# Print their shapes +print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape) +print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape) + +class ModifiedDeepNNRegressor(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedDeepNNRegressor, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Sequential( + nn.Linear(2048, 512), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(512, num_classes) + ) + + def forward(self, x): + x = self.features(x) + conv_feature_map = x + x = torch.flatten(x, 1) + x = self.classifier(x) + return x, conv_feature_map + +class ModifiedLightNNRegressor(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedLightNNRegressor, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(16, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + # Include an extra regressor (in our case linear) + self.regressor = nn.Sequential( + nn.Conv2d(16, 32, kernel_size=3, padding=1) + ) + self.classifier = nn.Sequential( + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(256, num_classes) + ) + + def forward(self, x): + x = self.features(x) + regressor_output = self.regressor(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x, regressor_output + + +def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device): + ce_loss = nn.CrossEntropyLoss() + mse_loss = 
nn.MSELoss() + optimizer = optim.Adam(student.parameters(), lr=learning_rate) + + teacher.to(device) + student.to(device) + teacher.eval() # Teacher set to evaluation mode + student.train() # Student to train mode + + for epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + # Again ignore teacher logits + with torch.no_grad(): + _, teacher_feature_map = teacher(inputs) + + # Forward pass with the student model + student_logits, regressor_feature_map = student(inputs) + + # Calculate the loss + hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map) + + # Calculate the true label loss + label_loss = ce_loss(student_logits, labels) + + # Weighted sum of the two losses + loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss + + loss.backward() + optimizer.step() + + running_loss += loss.item() + + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}") + + +# Initialize a ModifiedLightNNRegressor +torch.manual_seed(42) +modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device) + +# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance +modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device) +modified_nn_deep_reg.load_state_dict(nn_deep.state_dict()) + +# Train and test once again +train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device) +test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device) + +print(f"Teacher accuracy: {test_accuracy_deep:.2f}%") +print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%") +print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%") +print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%") +print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%") + +# +# For more information, see: +# +# - [Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a +# neural network. In: Neural Information Processing System Deep +# Learning Workshop (2015)](https://arxiv.org/abs/1503.02531) +# - [Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., +# Bengio, Y.: Fitnets: Hints for thin deep nets. 
In: Proceedings of
+#   the International Conference on Learning
+#   Representations (2015)](https://arxiv.org/abs/1412.6550)
+#
diff --git a/cifar10-fast-simple/studentmodel.py b/cifar10-fast-simple/studentmodel.py
new file mode 100644
index 0000000..071ffff
--- /dev/null
+++ b/cifar10-fast-simple/studentmodel.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+
+# Lightweight neural network class to be used as student:
+class ModifiedLightNNRegressor(nn.Module):
+    def __init__(self, num_classes=10):
+        super(ModifiedLightNNRegressor, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(16, 16, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+        )
+        # Extra regressor (a single conv layer) that maps the student feature map to the teacher's shape
+        self.regressor = nn.Sequential(
+            nn.Conv2d(16, 32, kernel_size=3, padding=1)
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(1024, 256),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        regressor_output = self.regressor(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x, regressor_output
+
+
+Model = ModifiedLightNNRegressor
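For reference, a minimal usage sketch of the student module above (a hypothetical snippet, not part of the patch; shapes assume CIFAR-10-sized 3x32x32 inputs):

import torch
import studentmodel

student = studentmodel.Model(num_classes=10)
x = torch.randn(8, 3, 32, 32)               # dummy batch of 8 CIFAR-10-sized images
logits, regressor_feature_map = student(x)  # forward returns a (logits, feature_map) tuple
print(logits.shape)                  # expected: torch.Size([8, 10])
print(regressor_feature_map.shape)   # expected: torch.Size([8, 32, 8, 8])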
diff --git a/cifar10-fast-simple/train.py b/cifar10-fast-simple/train.py
index 0dba02a..7b15ca7 100644
--- a/cifar10-fast-simple/train.py
+++ b/cifar10-fast-simple/train.py
@@ -4,6 +4,13 @@ import torch
 import torch.nn as nn
 import torchvision
 import model
+import studentmodel
+import torch.optim as optim
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+
+
+from datetime import datetime
 
 
 def train(seed=0):
@@ -45,7 +52,7 @@ def train(seed=0):
     torch.backends.cudnn.benchmark = True
 
     # Load dataset
-    train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
+    train_data, train_targets, valid_data, valid_targets, train_loader, test_loader = load_cifar10(device, dtype)
 
     # Compute special weights for first layer
     weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
@@ -169,7 +176,7 @@ def train(seed=0):
 
         print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
 
-    return valid_acc
+    return valid_acc, valid_model, train_loader, test_loader, device
 
 def preprocess_data(data, device, dtype):
     # Convert to torch float16 tensor
@@ -196,10 +203,85 @@ def load_cifar10(device, dtype, data_dir="~/data"):
     train_targets = torch.tensor(train.targets).to(device)
     valid_targets = torch.tensor(valid.targets).to(device)
 
+    # TODO: this duplicates the CIFAR-10 loading in distillation.py just to provide distillation loaders; refactor so the data is prepared once.
+    # Loading the CIFAR-10 dataset:
+    transforms_cifar = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+    train_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transforms_cifar)
+    test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transforms_cifar)
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
+
     # Pad 32x32 to 40x40
     train_data = nn.ReflectionPad2d(4)(train_data)
 
-    return train_data, train_targets, valid_data, valid_targets
+    return train_data, train_targets, valid_data, valid_targets, train_loader, test_loader
+
+# Used for testing models
+def test_multiple_outputs(model, test_loader, device):
+    model.to(device)
+    model.eval()
+
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for inputs, labels in test_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            outputs, _ = model(inputs)  # Disregard the second tensor of the tuple
+            _, predicted = torch.max(outputs.data, 1)
+
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    accuracy = 100 * correct / total
+    print(f"Test Accuracy: {accuracy:.2f}%")
+    return accuracy
+
+# Used for distillation training
+def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
+    ce_loss = nn.CrossEntropyLoss()
+    mse_loss = nn.MSELoss()
+    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
+
+    teacher.to(device)
+    student.to(device)
+    teacher.eval()  # Teacher set to evaluation mode
+    student.train()  # Student to train mode
+
+    for epoch in range(epochs):
+        running_loss = 0.0
+        for inputs, labels in train_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            optimizer.zero_grad()
+
+            # Again ignore teacher logits
+            with torch.no_grad():
+                _, teacher_feature_map = teacher(inputs)
+
+            # Forward pass with the student model
+            student_logits, regressor_feature_map = student(inputs)
+
+            # Calculate the feature map loss
+            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
+
+            # Calculate the true label loss
+            label_loss = ce_loss(student_logits, labels)
+
+            # Weighted sum of the two losses
+            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
+
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+
+        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
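train_mse_loss assumes both networks return a (logits, feature_map) tuple with feature maps of identical shape. A minimal sanity check before training, assuming a pair shaped like ModifiedDeepNNRegressor and ModifiedLightNNRegressor from distillation.py, both already moved to device, and CIFAR-10-sized 3x32x32 inputs (hypothetical snippet):

import torch

def check_feature_map_shapes(teacher, student, device):
    # Dummy CIFAR-10-sized batch; both models are expected to return (logits, feature_map)
    dummy = torch.randn(4, 3, 32, 32).to(device)
    with torch.no_grad():
        _, teacher_fmap = teacher(dummy)
        _, student_fmap = student(dummy)
    # For the regressor pair above both shapes should be [4, 32, 8, 8]
    assert teacher_fmap.shape == student_fmap.shape, (teacher_fmap.shape, student_fmap.shape)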
 
 
 def update_ema(train_model, valid_model, rho):
@@ -253,11 +335,11 @@ def print_info():
 
 def main():
     print_info()
 
     accuracies = []
     threshold = 0.94
     for run in range(100):
-        valid_acc = train(seed=run)
+        valid_acc, teacher_model, train_loader, test_loader, device = train(seed=run)
         accuracies.append(valid_acc)
 
         # Print accumulated results
@@ -272,7 +354,23 @@
         print(f"Max accuracy: {max(accuracies)}")
         print(f"Mean accuracy: {mean} +- {std}")
         print()
+
+        # Save the teacher model
+        current_datetime = datetime.now()
+        filename = f"teachermodels/{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
+        #torch.save(teacher_model, filename)
+
+        # Distillation training for student models
+        distilled_model = studentmodel.Model(num_classes=10).to(device)
+        #train_mse_loss(teacher=teacher_model, student=distilled_model, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
+        #test_student_accuracy = test_multiple_outputs(distilled_model, test_loader, device)
+        #print(f"Student accuracy with CE + RegressorMSE: {test_student_accuracy:.2f}%")
+
+        # Save the student model
+        #current_datetime = datetime.now()
+        #filename = f"studentmodels/{current_datetime.strftime('%Y%m%d_%H%M%S')}.pt"
+        #torch.save(distilled_model, filename)
 
 
 if __name__ == "__main__":
     main()
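One way the commented-out saving step above could be fleshed out: a minimal sketch, assuming the distilled_model, device, and studentmodels/ naming from main(); saving the state_dict keeps the checkpoint loadable without pickling the whole module, and the target directory has to exist before torch.save is called.

import os
from datetime import datetime

import torch
import studentmodel

def save_student(distilled_model, directory="studentmodels"):
    # torch.save does not create missing directories, so create it first
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, datetime.now().strftime("%Y%m%d_%H%M%S") + ".pt")
    torch.save(distilled_model.state_dict(), path)
    return path

def load_student(path, device):
    # Rebuild the architecture, then restore the saved weights for evaluation
    student = studentmodel.Model(num_classes=10).to(device)
    student.load_state_dict(torch.load(path, map_location=device))
    student.eval()
    return student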