# Wide Residual Network (WideResNet) model definition in PyTorch.
import torch
|
|
import torch.nn as nn
|
|
from torchsummary import summary
|
|
import math
|
|
|
|
|
|
class IndividualBlock1(nn.Module):
    """First residual block of a group (pre-activation: BN -> ReLU -> conv, twice).

    The main path is ``conv2(relu(bn2(conv1(relu(bn1(x))))))``. When either
    ``subsample_input`` or ``increase_filters`` is set, the shortcut is a 1x1
    convolution applied to the *activated* input ``relu(bn1(x))``; otherwise
    the shortcut is the raw input ``x``.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (``conv1``). The shortcut
            convolution uses the same stride so both paths always produce the
            same spatial size.
        subsample_input: if True, build a 1x1 shortcut convolution.
        increase_filters: if True (and ``subsample_input`` is False), build a
            1x1 shortcut convolution as well.

    Note:
        With both flags False the shortcut is the identity, which only works
        when ``input_features == output_features`` and ``stride == 1``.
    """

    def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super(IndividualBlock1, self).__init__()

        # Registration order is kept stable: it determines state_dict key
        # order and the order in which an enclosing model iterates modules.
        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters

        if subsample_input or increase_filters:
            # The shortcut must downsample exactly as conv1 does. Previously
            # the stride was hard-coded (2 when subsampling, 1 otherwise),
            # which produced mismatched shapes at the residual addition for
            # any other `stride`.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0,
                                      bias=False)

    def forward(self, x):
        """Return ``shortcut + residual`` for input ``x``."""
        # batch_norm1 returns a new tensor, so the in-place ReLU never
        # modifies the caller's `x`.
        pre_activated = self.activation(self.batch_norm1(x))

        if self.subsample_input or self.increase_filters:
            # Projection shortcut operates on the activated input.
            shortcut = self.conv_inp(pre_activated)
        else:
            # Identity shortcut uses the untouched input.
            shortcut = x

        residual = self.conv1(pre_activated)
        residual = self.activation(self.batch_norm2(residual))
        residual = self.conv2(residual)

        return shortcut + residual
|
|
|
|
|
|
class IndividualBlockN(nn.Module):
    """Residual block with an identity shortcut (pre-activation ordering).

    Computes ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + x``.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by both convolutions.
        stride: stride applied to *both* 3x3 convolutions.

    NOTE(review): the shortcut is the raw input, so the addition presumably
    only lines up when ``stride == 1`` and
    ``input_features == output_features`` -- confirm before using other values.
    """

    def __init__(self, input_features, output_features, stride):
        super().__init__()

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and module iteration order stay unchanged.
        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        # Both convolutions share everything except their input width.
        shared_conv_args = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(input_features, output_features, **shared_conv_args)
        self.conv2 = nn.Conv2d(output_features, output_features, **shared_conv_args)

    def forward(self, x):
        """Return the residual branch output added to the unmodified input."""
        # Each batch norm yields a fresh tensor, so the in-place ReLU leaves
        # `x` intact for the shortcut below.
        branch = self.conv1(self.activation(self.batch_norm1(x)))
        branch = self.conv2(self.activation(self.batch_norm2(branch)))
        return branch + x
|
|
|
|
|
|
class Nblock(nn.Module):
    """A group of ``N`` residual blocks chained in an ``nn.Sequential``.

    The first block is an ``IndividualBlock1`` built from the given
    ``stride`` / ``subsample_input`` / ``increase_filters`` and maps
    ``input_features`` to ``output_features``. The remaining ``N - 1`` blocks
    are ``IndividualBlockN`` instances with ``stride=1`` that keep
    ``output_features`` channels. For ``N <= 0`` the group is an empty
    ``nn.Sequential``.
    """

    def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
        super().__init__()

        stack = []
        if N > 0:
            # Only the leading block may change width / resolution.
            stack.append(
                IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters)
            )
            # Followers are constructed lazily, one after another, after the
            # leading block -- same construction order as a plain loop.
            stack.extend(
                IndividualBlockN(output_features, output_features, stride=1) for _ in range(N - 1)
            )

        self.nblockLayer = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every block of the group in order."""
        return self.nblockLayer(x)
|
|
|
|
|
|
class WideResNet(nn.Module):
    """Wide residual network that also returns the output of each block group.

    Layout: stem 3x3 convolution -> three ``Nblock`` groups with
    ``16 * k``, ``32 * k`` and ``64 * k`` channels -> BN -> ReLU ->
    8x8 average pooling -> linear classifier.

    Args:
        d: depth parameter; each group holds ``(d - 4) // 6`` blocks.
        k: widening factor multiplying the base widths 16 / 32 / 64.
        n_classes: number of outputs of the final linear layer.
        input_features: channel count of the network input.
        output_features: channel count produced by the stem convolution.
        strides: indexable with at least four entries -- stride of the stem
            convolution followed by the strides of the three groups.

    ``forward`` returns ``(logits, attention1, attention2, attention3)`` where
    the ``attention*`` tensors are the raw outputs of the three groups.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super(WideResNet, self).__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1,
                               bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N = (d - 4) // 6  # blocks per group
        # The first group never subsamples; it only needs a projection
        # shortcut when the network is widened.
        increase_filters = k > 1
        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1],
                             subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        # Fixed pooling window: the classifier below expects this to reduce
        # the final feature map to 1x1 (e.g. an 8x8 map).
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise conv, batch-norm and linear-bias parameters in module order."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel_h * kernel_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the linear weight keeps its default init.
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return ``(logits, attention1, attention2, attention3)`` for input ``x``."""
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)
        out = self.batch_norm(attention3)
        out = self.activation(out)
        out = self.avg_pool(out)
        # Flatten per sample. `view(-1, out_filters)` would silently move any
        # leftover spatial positions into the batch dimension when the pooled
        # map is larger than 1x1; flattening from dim 1 keeps the batch size
        # fixed, so such a mismatch raises in `fc` instead.
        out = out.flatten(1)

        return self.fc(out), attention1, attention2, attention3
|
|
|
|
|
|
if __name__ == '__main__':

    # change d and k if you want to check a model other than WRN-40-2
    d = 40
    k = 2
    strides = [1, 1, 2, 2]
    net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)

    # torchsummary's summary() feeds the model a CUDA tensor whenever CUDA is
    # available (its `device` argument defaults to "cuda"), so the model must
    # live on the matching device or the summary call fails on GPU machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    # verify that an output is produced
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False, device=device)
    net(sample_input)

    # Summarize model
    summary(net, input_size=(3, 32, 32))
|