2017-08-17 12:43:17 -06:00
import math
from functools import partial
2017-12-11 12:35:48 -07:00
import torch
2017-08-17 12:43:17 -06:00
from torch import nn
from torch.nn import functional as F
class BasicBlock(nn.Module):
    """Post-activation residual block made of two 3x3 convolutions.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by each 3x3 conv.
        stride: stride of the first conv (and of the projection shortcut).

    The output has ``planes * expansion`` channels and is passed through a
    final ReLU. A 1x1 conv + BN projection replaces the identity shortcut
    whenever the stride or the channel count changes.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or inplanes != out_planes
        if needs_projection:
            # Match the residual branch's spatial size and channel count.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()

    def forward(self, inputs):
        residual = F.relu(self.bn1(self.conv1(inputs)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(inputs))
2017-12-11 12:35:48 -07:00
class StochasticBlock(nn.Module):
    """Basic residual block whose residual branch may be dropped in training.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by each 3x3 conv.
        stride: stride of the first conv; also the pooling size of the shortcut.
        survival_rate: probability that the residual branch is kept while
            ``self.training`` is true.

    The shortcut has no parameters. It average-pools when ``stride != 1`` and
    zero-pads channels when the width changes (only doubling is allowed).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, survival_rate=1):
        super().__init__()
        self.survival_rate = survival_rate
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.increasing = inplanes != (planes * self.expansion)
        if self.increasing:
            # forward() pads with exactly one copy's worth of zero channels.
            assert ((1. * planes * self.expansion) / inplanes) == 2

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(nn.AvgPool2d(stride))

    def forward(self, inputs):
        identity = self.shortcut(inputs)
        if self.increasing:
            # Append zero channels so the shortcut matches the doubled width.
            identity = torch.cat([identity, identity.mul(0)], 1)

        # A random draw is made only in training mode. If the branch is
        # dropped, the block reduces to ReLU(shortcut).
        if self.training and not (torch.rand(1)[0] <= self.survival_rate):
            return F.relu(identity)

        residual = F.relu(self.bn1(self.conv1(inputs)))
        residual = self.bn2(self.conv2(residual))
        if self.training:
            # Scale surviving branches by 1/p so that their expected
            # contribution equals the unscaled eval-mode contribution.
            residual = residual / self.survival_rate
        return F.relu(residual + identity)
2017-08-17 12:43:17 -06:00
class PreActBlock(nn.Module):
    """Pre-activation basic block: (BN -> ReLU -> 3x3 conv) twice, plus a shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by each 3x3 conv.
        stride: stride of the first conv (and of the projection shortcut).

    No activation is applied to the output. When the stride or width changes,
    a 1x1 conv projection (without BN) is applied to the pre-activated input
    rather than to the raw input.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)

        self.increasing = stride != 1 or inplanes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False)
            )
            if self.increasing
            else nn.Sequential()
        )

    def forward(self, inputs):
        activated = F.relu(self.bn1(inputs))
        # A projection shortcut consumes the pre-activated tensor; the
        # identity shortcut carries the raw input.
        skip_source = activated if self.increasing else inputs
        out = self.conv1(activated)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + self.shortcut(skip_source)
class Bottleneck(nn.Module):
    """Post-activation bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * expansion``
            channels.
        stride: stride of the 3x3 conv (and of the projection shortcut).

    A 1x1 conv + BN projection shortcut is used whenever the stride or the
    channel count changes. The output passes through a final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Derive the output width from ``expansion`` rather than a literal 4,
        # so the residual branch always matches the shortcut (also when a
        # subclass overrides ``expansion``).
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or inplanes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, inputs):
        H = F.relu(self.bn1(self.conv1(inputs)))
        H = F.relu(self.bn2(self.conv2(H)))
        H = self.bn3(self.conv3(H))
        H += self.shortcut(inputs)
        return F.relu(H)
class ResNeXtBottleneck(nn.Module):
    """ResNeXt bottleneck: 1x1 reduce, grouped 3x3, 1x1 expand (post-activation).

    Args:
        inplanes: number of input channels.
        planes: base planes; the block outputs ``planes * expansion`` channels.
        stride: stride of the grouped 3x3 conv (and of the projection shortcut).
        cardinality: number of groups in the 3x3 conv.
        base_width: scales the per-group width, which is
            ``floor(planes * base_width / 64)``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, cardinality=32,
                 base_width=4):
        super().__init__()
        width = math.floor(planes * (base_width / 64.0))
        # Total channels inside the grouped part of the block.
        inner = width * cardinality
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, inner, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = nn.Conv2d(inner, inner, 3,
                               groups=cardinality, padding=1, stride=stride,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(inner)
        # Use ``expansion`` (not a literal 4) so conv3 matches the shortcut.
        self.conv3 = nn.Conv2d(inner, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or inplanes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, inputs):
        H = F.relu(self.bn1(self.conv1(inputs)))
        H = F.relu(self.bn2(self.conv2(H)))
        H = self.bn3(self.conv3(H))
        H += self.shortcut(inputs)
        return F.relu(H)
class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck: BN -> ReLU before each of 1x1, 3x3, 1x1 convs.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * expansion``
            channels.
        stride: stride of the 3x3 conv (and of the projection shortcut).

    No activation is applied to the output. When a projection shortcut is
    needed, it consumes the pre-activated input rather than the raw input.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, stride=stride,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        # Use ``expansion`` (not a literal 4) so conv3 matches the shortcut.
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)

        self.increasing = stride != 1 or inplanes != out_planes
        if self.increasing:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, 1, stride=stride, bias=False)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, inputs):
        H = F.relu(self.bn1(inputs))
        if self.increasing:
            # Project the pre-activated tensor, not the raw input.
            inputs = H
        H = self.conv1(H)
        H = self.conv2(F.relu(self.bn2(H)))
        H = self.conv3(F.relu(self.bn3(H)))
        H += self.shortcut(inputs)
        return H
class ResNet(nn.Module):
    """Residual network assembled from a configurable block type.

    Layout: 3x3 stem conv -> ``len(layers)`` sections of blocks -> global
    average pooling -> linear classifier. The first block of every section
    except the first uses stride 2.

    Args:
        Block: block constructor called as ``Block(inplanes, planes, stride=...)``.
            It must expose ``expansion`` and ``__name__``. A name containing
            ``'Pre'`` selects pre-activation wiring: BN + ReLU is applied after
            the last section instead of right after the stem.
        layers: number of blocks in each section.
        filters: ``planes`` for each section (paired with ``layers`` by zip).
        num_classes: output size of the final linear layer.
        inplanes: channels produced by the stem conv; ``filters[0]`` if falsy.
    """

    def __init__(self, Block, layers, filters, num_classes=10, inplanes=None):
        super().__init__()
        self.inplanes = inplanes or filters[0]
        self.pre_act = 'Pre' in Block.__name__

        # Stem: the conv expects a 3-channel input.
        self.conv1 = nn.Conv2d(3, self.inplanes, 3, padding=1, bias=False)
        if not self.pre_act:
            self.bn1 = nn.BatchNorm2d(self.inplanes)

        self.num_sections = len(layers)
        for idx, (num_blocks, planes) in enumerate(zip(layers, filters)):
            blocks = []
            for position in range(num_blocks):
                # Only the leading block of sections 1..N downsamples.
                stride = 2 if (idx != 0 and position == 0) else 1
                blocks.append(Block(self.inplanes, planes, stride=stride))
                self.inplanes = planes * Block.expansion
            setattr(self, f'section_{idx}', nn.Sequential(*blocks))

        if self.pre_act:
            # Pre-activation nets normalise the final feature map instead.
            self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.fc = nn.Linear(filters[-1] * Block.expansion, num_classes)

        # He normal init (fan-out) for convs; BN starts with weight 1, bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (module.kernel_size[0] * module.kernel_size[1]
                           * module.out_channels)
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, inputs):
        H = self.conv1(inputs)
        if not self.pre_act:
            H = F.relu(self.bn1(H))

        for idx in range(self.num_sections):
            H = getattr(self, f'section_{idx}')(H)

        if self.pre_act:
            H = F.relu(self.bn1(H))

        # Global average pool over the spatial dims, then flatten.
        H = F.avg_pool2d(H, H.size()[2:])
        return self.fc(H.view(H.size(0), -1))
2017-12-11 12:35:48 -07:00
class StochasticResNet(ResNet):
    """ResNet whose blocks are given stochastic-depth survival rates.

    Blocks are numbered ``l = 1..L`` across all sections, with
    ``L = sum(layers)``. Block ``l`` has its ``survival_rate`` attribute set
    to:

    * ``decay='linear'``:  ``1 - (l / L) * (1 - min_survival_rate)``
    * ``decay='uniform'``: ``min_survival_rate``

    Raises:
        NotImplementedError: if ``decay`` is neither ``'linear'`` nor
            ``'uniform'``. This is checked before the network is built.
    """

    def __init__(self, Block, layers, filters, num_classes=10, inplanes=None,
                 min_survival_rate=1.0, decay='linear'):
        # Fail fast: building a deep network only to reject ``decay`` is
        # wasteful, and with zero blocks the old in-loop check never ran.
        if decay not in ('linear', 'uniform'):
            raise NotImplementedError(
                f"{decay} decay has not been implemented.")
        super().__init__(Block, layers, filters,
                         num_classes=num_classes, inplanes=inplanes)

        L = sum(layers)
        depth_index = 1
        for section_index in range(self.num_sections):
            section = getattr(self, f'section_{section_index}')
            for module in section.children():
                if decay == 'linear':
                    survival_rate = 1 - ((depth_index / L)
                                         * (1 - min_survival_rate))
                else:
                    survival_rate = min_survival_rate
                module.survival_rate = survival_rate
                depth_index += 1
        # Every block must have received exactly one rate.
        assert (depth_index - 1) == L
2017-08-17 12:43:17 -06:00
# From "Deep Residual Learning for Image Recognition"
def ResNet20():
    """ResNet-20 for CIFAR: 3 sections x 3 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[3, 3, 3])


def ResNet32():
    """ResNet-32 for CIFAR: 3 sections x 5 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[5, 5, 5])


def ResNet44():
    """ResNet-44 for CIFAR: 3 sections x 7 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[7, 7, 7])


def ResNet56():
    """ResNet-56 for CIFAR: 3 sections x 9 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[9, 9, 9])


def ResNet110():
    """ResNet-110 for CIFAR: 3 sections x 18 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[18, 18, 18])


def ResNet1202():
    """ResNet-1202 for CIFAR: 3 sections x 200 BasicBlocks (16/32/64 planes)."""
    return ResNet(BasicBlock, filters=[16, 32, 64], layers=[200, 200, 200])
2017-12-11 12:35:48 -07:00
# From "Identity Mappings in Deep Residual Networks"
def PreActResNet110():
    """Pre-activation ResNet-110: 3 sections x 18 PreActBlocks."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[18, 18, 18])


def PreActResNet164():
    """Pre-activation ResNet-164: 3 sections x 18 PreActBottlenecks."""
    return ResNet(PreActBottleneck, filters=[16, 32, 64], layers=[18, 18, 18])


def PreActResNet1001():
    """Pre-activation ResNet-1001: 3 sections x 111 PreActBottlenecks."""
    return ResNet(PreActBottleneck, filters=[16, 32, 64],
                  layers=[111, 111, 111])


# Based on but not in "Identity Mappings in Deep Residual Networks"
def PreActResNet8():
    """Pre-activation ResNet-8: 3 sections x 1 PreActBlock."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[1, 1, 1])


def PreActResNet14():
    """Pre-activation ResNet-14: 3 sections x 2 PreActBlocks."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[2, 2, 2])
2017-08-17 12:43:17 -06:00
def PreActResNet20():
    """Pre-activation ResNet-20: 3 sections x 3 PreActBlocks."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[3, 3, 3])


def PreActResNet56():
    """Pre-activation ResNet-56: 3 sections x 9 PreActBlocks."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[9, 9, 9])


def PreActResNet164Basic():
    """Pre-activation ResNet-164 built from basic blocks (3 x 27 PreActBlocks)."""
    return ResNet(PreActBlock, filters=[16, 32, 64], layers=[27, 27, 27])
2017-12-11 12:35:48 -07:00
# From "Deep Networks with Stochastic Depth"
def StochasticResNet110():
    """110-layer stochastic-depth ResNet; survival decays linearly to 0.5."""
    return StochasticResNet(
        StochasticBlock,
        layers=[18, 18, 18],
        filters=[16, 32, 64],
        min_survival_rate=0.5,
        decay='linear',
    )
2017-08-17 12:43:17 -06:00
2017-12-11 12:35:48 -07:00
def StochasticResNet1202():
    """1202-layer stochastic-depth ResNet; survival decays linearly to 0.5."""
    return StochasticResNet(
        StochasticBlock,
        layers=[200, 200, 200],
        filters=[16, 32, 64],
        min_survival_rate=0.5,
        decay='linear',
    )
2017-08-17 12:43:17 -06:00
2017-12-11 12:35:48 -07:00
# Based on but not in "Deep Networks with Stochastic Depth"
def StochasticResNet56():
    """56-layer stochastic-depth ResNet; survival decays linearly to 0.5."""
    return StochasticResNet(
        StochasticBlock,
        layers=[9, 9, 9],
        filters=[16, 32, 64],
        min_survival_rate=0.5,
        decay='linear',
    )


def StochasticResNet56_08():
    """56-layer stochastic-depth ResNet; survival decays linearly to 0.8."""
    return StochasticResNet(
        StochasticBlock,
        layers=[9, 9, 9],
        filters=[16, 32, 64],
        min_survival_rate=0.8,
        decay='linear',
    )
2017-08-17 12:43:17 -06:00
# From "Wide Residual Networks"
def WRN(n, k):
    """Build WRN-n-k: a pre-activation ResNet of depth ``n`` and width factor ``k``.

    ``n - 4`` must be divisible by 6. Each of the three sections then holds
    ``(n - 4) / 6`` PreActBlocks. Section widths are ``16k, 32k, 64k``, and
    the stem conv produces 16 channels.
    """
    assert (n - 4) % 6 == 0
    widths = [base * k for base in (16, 32, 64)]
    convs_per_block = 2  # "l = 2": each PreActBlock holds two 3x3 convs
    blocks_per_section = int((n - 4) / convs_per_block / 3)
    return ResNet(PreActBlock, layers=[blocks_per_section] * 3,
                  filters=widths, inplanes=16)


def WRN_40_4():
    """WRN-40-4."""
    return WRN(40, 4)


def WRN_16_8():
    """WRN-16-8."""
    return WRN(16, 8)


def WRN_28_10():
    """WRN-28-10."""
    return WRN(28, 10)
# From "Aggregated Residual Transformations for Deep Neural Networks"
def ResNeXt29(cardinality, base_width):
    """ResNeXt-29 for CIFAR: 3 sections x 3 ResNeXt bottlenecks (64/128/256 planes).

    ``ResNet`` reads ``__name__`` and ``expansion`` from the block
    constructor. A bare ``functools.partial`` has neither, so both are
    copied from ``ResNeXtBottleneck``.
    """
    block_fn = partial(ResNeXtBottleneck, cardinality=cardinality,
                       base_width=base_width)
    for attr in ('__name__', 'expansion'):
        setattr(block_fn, attr, getattr(ResNeXtBottleneck, attr))
    return ResNet(block_fn, filters=[64, 128, 256], layers=[3, 3, 3])
# From kuangliu/pytorch-cifar
def ResNet18():
    """ResNet-18: BasicBlocks in sections of [2, 2, 2, 2] (64..512 planes)."""
    return ResNet(BasicBlock, filters=[64, 128, 256, 512], layers=[2, 2, 2, 2])


def ResNet34():
    """ResNet-34: BasicBlocks in sections of [3, 4, 6, 3] (64..512 planes)."""
    return ResNet(BasicBlock, filters=[64, 128, 256, 512], layers=[3, 4, 6, 3])


def ResNet50():
    """ResNet-50: Bottlenecks in sections of [3, 4, 6, 3] (64..512 planes)."""
    return ResNet(Bottleneck, filters=[64, 128, 256, 512], layers=[3, 4, 6, 3])


def ResNet101():
    """ResNet-101: Bottlenecks in sections of [3, 4, 23, 3] (64..512 planes)."""
    return ResNet(Bottleneck, filters=[64, 128, 256, 512],
                  layers=[3, 4, 23, 3])


def ResNet152():
    """ResNet-152: Bottlenecks in sections of [3, 8, 36, 3] (64..512 planes)."""
    return ResNet(Bottleneck, filters=[64, 128, 256, 512],
                  layers=[3, 8, 36, 3])