import torch
import torch.nn as nn


# Lightweight neural network class to be used as student:
class ModifiedLightNNRegressor(nn.Module):
    """Small CNN classifier that also returns a regressed feature map.

    The network has three parts:

    * ``features``: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
      stages, going from 3 to 16 channels and reducing each spatial
      dimension by a factor of 4.
    * ``regressor``: a single ``Conv2d`` (no activation, i.e. a linear map)
      lifting the 16-channel feature map to 32 channels at the same
      spatial size.
    * ``classifier``: an MLP over the flattened ``features`` output.

    The first ``Linear`` layer of the classifier expects 1024 input
    features, i.e. ``16 * (H / 4) * (W / 4) == 1024``; this holds for
    32x32 inputs, for example. Other input sizes will fail in the
    classifier with a shape mismatch.

    Args:
        num_classes: Number of output logits produced by the classifier.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor: a single convolution without an
        # activation, so it is a linear map from 16 to 32 channels.
        # NOTE(review): presumably 32 matches the channel count of the
        # teacher's feature map this output is compared with -- confirm.
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x: torch.Tensor):
        """Run the network and return ``(logits, regressor_output)``.

        Args:
            x: Image batch with 3 channels; the flattened ``features``
                output must have 1024 elements per sample.

        Returns:
            A tuple ``(logits, regressor_output)`` where ``logits`` comes
            from the classifier and ``regressor_output`` is the
            32-channel map produced by ``regressor``.
        """
        x = self.features(x)
        # The regressor reads the un-flattened feature map, so it has to
        # run before the flatten below.
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output


Model = ModifiedLightNNRegressor