35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
import torch
import torch.nn as nn
|
|
|
|
|
|
class ModifiedLightNNRegressor(nn.Module):
    """Lightweight CNN used as the student network, plus an extra regressor head.

    ``features`` applies two conv -> ReLU -> 2x2 max-pool stages
    (3 -> 16 -> 16 channels); the padded 3x3 convs keep the spatial size and
    each pool halves it. ``regressor`` is a single padded 3x3 convolution with
    no activation (i.e. a linear map) lifting those 16 channels to 32 —
    presumably so they can be matched against a teacher's intermediate
    feature map of the same shape; confirm against the training code.
    ``classifier`` maps the flattened features to ``num_classes`` logits.

    The first ``nn.Linear`` expects exactly 1024 flattened features, i.e.
    ``16 * (H // 4) * (W // 4) == 1024`` — satisfied by 3x32x32 inputs.

    Args:
        num_classes: Number of logits produced by the classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Extra regressor: one 3x3 conv with no activation (a linear map),
        # 16 -> 32 channels, same spatial size as the feature map.
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
        )

        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run the student network.

        Args:
            x: Image batch shaped (N, 3, H, W); the classifier only accepts
                inputs whose feature map flattens to 1024 values
                (e.g. H = W = 32).

        Returns:
            A ``(logits, regressor_output)`` tuple: logits shaped
            (N, num_classes) and the regressor feature map shaped
            (N, 32, H // 4, W // 4).
        """
        features = self.features(x)
        # The regressor consumes the un-flattened feature map; the classifier
        # branch below is computed independently from the same features.
        regressor_output = self.regressor(features)
        logits = self.classifier(torch.flatten(features, 1))
        return logits, regressor_output
|
|
|
|
|
|
# Module-level alias for the student network class; presumably looked up by
# the name "Model" by whatever imports this file — verify against the caller.
Model = ModifiedLightNNRegressor
|