class BottleneckBlock(eqx.Module):
    """Pre-activation bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The first 1x1 convolution reduces the input to ``out_channels // 4``
    channels (and applies ``stride``), the 3x3 convolution works at that
    reduced width, and the last 1x1 convolution expands to ``out_channels``.
    Each convolution is preceded by batch norm and a leaky ReLU.  The result
    is added to a shortcut branch, which is passed through a 1x1 projection
    whenever its shape would not otherwise match the main branch.

    NOTE(review): recent Equinox releases make ``eqx.nn.BatchNorm`` stateful
    (called with a state, returning ``(output, state)``).  The single-argument
    calls below are kept as written -- confirm against the installed version.
    NOTE(review): Equinox convolutions are documented as taking unbatched
    ``(channels, h, w)`` inputs; the ``batch`` axis in the ``__call__``
    annotation is kept from the original -- confirm how callers vmap this.
    """

    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    bn3: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    # 1x1 shortcut convolution; None when the identity shortcut already
    # has the same shape as the main branch.
    project: eqx.nn.Conv2d
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        """Build the block's layers.

        Args:
            in_channels: Number of channels of the block input.
            out_channels: Number of channels of the block output.  The
                internal bottleneck width is ``out_channels // 4``.
            stride: Spatial stride of the first 1x1 convolution and of the
                shortcut projection.
            relu_leak: Negative slope used by every leaky ReLU in the block.
            is_b4_res: If true, the shortcut branch starts from the
                activated (batch norm + leaky ReLU) input instead of the
                raw input.
            key: JAX PRNG key used to initialise the convolutions.

        Raises:
            ValueError: If ``out_channels`` is smaller than 4, which would
                leave the bottleneck with zero channels.
        """
        # Channel counts must be ints, so use floor division rather than "/".
        mid_channels = out_channels // 4
        if mid_channels < 1:
            raise ValueError(
                "out_channels must be at least 4 for a bottleneck block, "
                f"got {out_channels}"
            )

        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        conv1_key, conv2_key, conv3_key, project_key = jax.random.split(key, 4)

        # Each batch norm is sized for the tensor it actually normalises:
        # bn1 sees the block input, bn2/bn3 see the bottleneck activations.
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn2 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn3 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)

        # Only the first convolution downsamples.
        self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
                                   stride=stride, key=conv1_key)
        # padding=1 keeps the spatial size unchanged through the 3x3 kernel,
        # so the residual addition lines up with the shortcut.
        self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                                   stride=1, padding=1, key=conv2_key)
        self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
                                   stride=1, key=conv3_key)

        # The shortcut needs a projection whenever the main branch changes
        # the channel count or the spatial resolution.
        if in_channels != out_channels or stride != 1:
            self.project = eqx.nn.Conv2d(in_channels, out_channels,
                                         kernel_size=1, stride=stride,
                                         key=project_key)
        else:
            self.project = None

    def __call__(
        self,
        x: Float[Array, "batch channels w h"]
    ) -> Float[Array, "batch channels w h"]:
        """Apply the bottleneck block to ``x`` and return the residual sum."""
        activated = jax.nn.leaky_relu(self.bn1(x),
                                      negative_slope=self.relu_leak)
        # The shortcut starts either from the activated input or from the
        # raw input, depending on is_b4_res.
        shortcut = activated if self.is_b4_res else x

        out = self.conv1(activated)
        out = jax.nn.leaky_relu(self.bn2(out), negative_slope=self.relu_leak)
        out = self.conv2(out)
        out = jax.nn.leaky_relu(self.bn3(out), negative_slope=self.relu_leak)
        out = self.conv3(out)

        if self.project is not None:
            shortcut = self.project(shortcut)

        return out + shortcut