# Based on https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html#prerequisites
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

# Keep only a small subset of each split to speed things up
from torch.utils.data import Subset
num_images_to_keep = 2000
train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))
test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))

# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),  # 32 channels * 8 * 8 spatial positions after two 2x2 poolings of 32x32 inputs
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),  # 16 channels * 8 * 8 spatial positions after two 2x2 poolings
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
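            # Added note (not part of the original tutorial): the line below is the batch-averaged
            # KL divergence KL(teacher || student) between the temperature-softened distributions.
            # An equivalent formulation would be
            # nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T**2).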
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

class ModifiedDeepNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())
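# Added sanity check (not part of the original tutorial): the teacher flattens to 32 * 8 * 8 = 2048
# features while the student flattens to 16 * 8 * 8 = 1024, so the avg_pool1d(..., 2) call in
# ModifiedDeepNNCosine halves the teacher's hidden representation to match the student's size
# for the cosine loss. The probe tensor below is for illustration only.
with torch.no_grad():
    _probe_input = torch.randn(4, 3, 32, 32).to(device)
    _, _teacher_hidden = modified_nn_deep(_probe_input)
    assert _teacher_hidden.shape[-1] == 1024  # same size as the student's flattened conv output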
# Initialize a modified lightweight network with the same seed as our other lightweight instances.
# This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())

# Create a sample input tensor
sample_input = torch.randn(128, 3, 32, 32).to(device)  # Batch size: 128, Channels: 3, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape)  # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape)  # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = modified_nn_deep(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape)  # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape)  # batch_size x hidden_representation_size

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. The target is a vector of ones, so minimizing the loss increases the cosine similarity between the two hidden representations.
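            # Added note (not part of the original tutorial): with a target of +1,
            # nn.CosineEmbeddingLoss computes 1 - cos(student_rep, teacher_rep), so driving the
            # loss toward 0 pushes the student's hidden representation toward the teacher's direction.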
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs)  # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Train and test the lightweight network with CE + cosine loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)

# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

class ModifiedDeepNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor (here a single 3x3 convolution) to map the student's 16-channel feature map to the teacher's 32 channels
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output
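# Added sanity check (not part of the original tutorial): the teacher returns its 32 x 8 x 8
# convolutional feature map, while the student's backbone produces a 16 x 8 x 8 map; the extra
# regressor convolution maps 16 channels to 32 so the two feature maps have identical shapes and
# can be compared with an MSE loss. The throwaway instances below are for illustration only.
with torch.no_grad():
    _probe_teacher = ModifiedDeepNNRegressor(num_classes=10).to(device)
    _probe_student = ModifiedLightNNRegressor(num_classes=10).to(device)
    _probe_input = torch.randn(2, 3, 32, 32).to(device)
    _, _teacher_map = _probe_teacher(_probe_input)
    _, _student_map = _probe_student(_probe_input)
    assert _teacher_map.shape == _student_map.shape  # both are batch x 32 x 8 x 8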
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train()  # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")

# For more information, see:
#
# - Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network.
#   In: Neural Information Processing Systems Deep Learning Workshop (2015). https://arxiv.org/abs/1503.02531
# - Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: FitNets: Hints for thin deep nets.
#   In: Proceedings of the International Conference on Learning Representations (2015). https://arxiv.org/abs/1412.6550