ResNet: add bottleneck block

This commit is contained in:
Akemi Izuko 2024-11-13 14:57:16 -07:00
parent 92b0024a27
commit 456eb30050
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -146,6 +146,7 @@ class ResidualBlock(eqx.Module):
eps=0.001) eps=0.001)
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
eps=0.001) eps=0.001)
# TODO: is bn2 in_channels correct?
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3, self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, key=keys[0]) stride=stride, key=keys[0])
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3, self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
@ -182,6 +183,76 @@ class ResidualBlock(eqx.Module):
return x return x
class BottleneckBlock(eqx.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    Each convolution is preceded by batch norm and a leaky ReLU. The first
    1x1 convolution reduces the width to ``out_channels // 4`` and applies
    ``stride``; the 3x3 convolution keeps that width; the last 1x1
    convolution expands to ``out_channels``. The block input (optionally
    after the first batch norm + activation) is added back to the result,
    passing through a 1x1 projection when its shape would otherwise differ.
    """

    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    bn3: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    # 1x1 convolution applied to the shortcut; None when the shortcut is the
    # identity.
    project_conv: eqx.nn.Conv2d | None
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        """Build the block's layers.

        Args:
            in_channels: Number of channels of the block input.
            out_channels: Number of channels of the block output. The inner
                convolutions use ``out_channels // 4`` channels.
            stride: Stride of the first 1x1 convolution and of the shortcut
                projection.
            relu_leak: Negative slope of the leaky ReLU activations.
            is_b4_res: If true, the shortcut branches off *after* the first
                batch norm + activation; otherwise it uses the raw input.
            key: JAX PRNG key used to initialise the convolutions.

        Raises:
            ValueError: If ``out_channels`` is smaller than 4, which would
                leave the inner convolutions with zero channels.
        """
        # Integer division: channel counts must be ints, not floats.
        mid_channels = out_channels // 4
        if mid_channels < 1:
            raise ValueError(
                f"out_channels must be at least 4, got {out_channels}"
            )

        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 4)

        # Each batch norm is sized for the activations it actually sees:
        # bn1 normalises the block input, bn2/bn3 the outputs of conv1/conv2.
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn2 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn3 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)

        self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
                                   stride=stride, key=keys[0])
        # padding=1 keeps the spatial size unchanged for the 3x3 kernel so
        # the residual addition lines up with the shortcut.
        self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                                   stride=1, padding=1, key=keys[1])
        self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
                                   stride=1, key=keys[2])

        # The shortcut needs a projection whenever the main path changes the
        # shape: a different channel count or spatial downsampling.
        if stride != 1 or in_channels != out_channels:
            self.project_conv = eqx.nn.Conv2d(in_channels, out_channels,
                                              kernel_size=1, stride=stride,
                                              key=keys[3])
        else:
            self.project_conv = None

    def __call__(
        self,
        x: Float[Array, "batch channels w h"]
    ) -> Float[Array, "batch channels w h"]:
        """Apply the block to ``x`` and return the residual sum."""
        # NOTE(review): eqx.nn.BatchNorm is documented as taking and
        # returning a State object, and Equinox layers act on unbatched
        # inputs -- confirm how the caller vmaps this block and threads
        # batch-norm state through it.
        shortcut = x
        x = self.bn1(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        if self.is_b4_res:
            # Shortcut branches off after the shared norm + activation.
            shortcut = x

        x = self.conv1(x)

        x = self.bn2(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)

        x = self.bn3(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv3(x)

        if self.project_conv is not None:
            shortcut = self.project_conv(shortcut)

        return x + shortcut
class ResNet(eqx.Module): class ResNet(eqx.Module):