mia_on_model_distillation/wresnet-pytorch/src/WideResNet.py

146 lines
5.1 KiB
Python
Raw Normal View History

2024-11-30 13:35:38 -07:00
import torch
import torch.nn as nn
from torchsummary import summary
import math
class IndividualBlock1(nn.Module):
def __init__(self, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(IndividualBlock1, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)
self.subsample_input = subsample_input
self.increase_filters = increase_filters
if subsample_input:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=2, padding=0, bias=False)
elif increase_filters:
self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
if self.subsample_input or self.increase_filters:
x = self.batch_norm1(x)
x = self.activation(x)
x1 = self.conv1(x)
else:
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
if self.subsample_input or self.increase_filters:
return self.conv_inp(x) + x1
else:
return x + x1
class IndividualBlockN(nn.Module):
def __init__(self, input_features, output_features, stride):
super(IndividualBlockN, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.batch_norm1 = nn.BatchNorm2d(input_features)
self.batch_norm2 = nn.BatchNorm2d(output_features)
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
def forward(self, x):
x1 = self.batch_norm1(x)
x1 = self.activation(x1)
x1 = self.conv1(x1)
x1 = self.batch_norm2(x1)
x1 = self.activation(x1)
x1 = self.conv2(x1)
return x1 + x
class Nblock(nn.Module):
def __init__(self, N, input_features, output_features, stride, subsample_input=True, increase_filters=True):
super(Nblock, self).__init__()
layers = []
for i in range(N):
if i == 0:
layers.append(IndividualBlock1(input_features, output_features, stride, subsample_input, increase_filters))
else:
layers.append(IndividualBlockN(output_features, output_features, stride=1))
self.nblockLayer = nn.Sequential(*layers)
def forward(self, x):
return self.nblockLayer(x)
class WideResNet(nn.Module):
def __init__(self, d, k, n_classes, input_features, output_features, strides):
super(WideResNet, self).__init__()
self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)
filters = [16 * k, 32 * k, 64 * k]
self.out_filters = filters[-1]
N = (d - 4) // 6
increase_filters = k > 1
self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])
self.batch_norm = nn.BatchNorm2d(filters[-1])
self.activation = nn.ReLU(inplace=True)
self.avg_pool = nn.AvgPool2d(kernel_size=8)
self.fc = nn.Linear(filters[-1], n_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
x = self.conv1(x)
attention1 = self.block1(x)
attention2 = self.block2(attention1)
attention3 = self.block3(attention2)
out = self.batch_norm(attention3)
out = self.activation(out)
out = self.avg_pool(out)
out = out.view(-1, self.out_filters)
return self.fc(out), attention1, attention2, attention3
if __name__ == '__main__':
# change d and k if you want to check a model other than WRN-40-2
d = 40
k = 2
strides = [1, 1, 2, 2]
net = WideResNet(d=d, k=k, n_classes=10, input_features=3, output_features=16, strides=strides)
# verify that an output is produced
sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
net(sample_input)
# Summarize model
summary(net, input_size=(3, 32, 32))