31 lines
965 B
Python
31 lines
965 B
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
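# Create a similar student class. Its forward pass returns the classifier logits.
# No pooling is applied after flattening: the flattened conv output
# (16 * 8 * 8 = 1024 values) already matches the classifier's input size.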
|
||
|
class ModifiedLightNNCosine(nn.Module):
    """Small CNN student network, intended for cosine-similarity distillation.

    Two conv/ReLU/max-pool stages (3 -> 16 -> 16 channels, each 2x2 max-pool
    halving height and width) produce a feature map. That map is flattened
    and fed to a two-layer MLP classifier with dropout.

    The first ``nn.Linear`` expects exactly 1024 flattened features, i.e.
    16 channels times an 8x8 spatial map after both pools (for example a
    3x32x32 input). No pooling is applied after flattening, so the flattened
    conv output is exactly what enters the classifier.

    Args:
        num_classes: Number of output logits produced by the classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            # 1024 = 16 channels * 8 * 8 spatial positions after two pools.
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, *, return_features=False):
        """Run the network on a batch of images.

        Args:
            x: Input batch, presumably shaped (N, 3, H, W), where H and W are
                such that the two 2x2 max-pools leave an 8x8 map (e.g. 32x32).
                TODO confirm against callers.
            return_features: If True, also return the flattened convolutional
                features that feed the classifier. These are the hidden
                representation a cosine-similarity loss would compare, as the
                class name suggests. Defaults to False, which keeps the
                single-tensor return value.

        Returns:
            The classifier logits, shaped (N, num_classes). When
            ``return_features`` is True, a tuple ``(logits, features)`` is
            returned instead, where ``features`` is shaped (N, 1024).
        """
        x = self.features(x)
        # Keep the batch dimension; collapse channel and spatial dimensions.
        flattened_conv_output = torch.flatten(x, 1)
        logits = self.classifier(flattened_conv_output)
        if return_features:
            return logits, flattened_conv_output
        return logits
|
||
|
|
||
|
# Expose the student network under the generic name `Model`. Presumably this is
# the name an external loader/harness imports; confirm with callers.
Model = ModifiedLightNNCosine
|
||
|
|