Resent: add bottleneck block
This commit is contained in:
parent
92b0024a27
commit
456eb30050
1 changed files with 71 additions and 0 deletions
|
@ -146,6 +146,7 @@ class ResidualBlock(eqx.Module):
|
|||
eps=0.001)
|
||||
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
# TODO: is bn2 in_channels correct?
|
||||
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
||||
stride=stride, key=keys[0])
|
||||
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
||||
|
@ -182,6 +183,76 @@ class ResidualBlock(eqx.Module):
|
|||
return x
|
||||
|
||||
|
||||
class BottleneckBlock(eqx.Module):
|
||||
bn1: eqx.nn.BatchNorm
|
||||
bn2: eqx.nn.BatchNorm
|
||||
bn3: eqx.nn.BatchNorm
|
||||
conv1: eqx.nn.Conv2d
|
||||
conv2: eqx.nn.Conv2d
|
||||
conv3: eqx.nn.Conv2d
|
||||
project_conv: eqx.nn.Conv2d
|
||||
relu_leak: float
|
||||
is_b4_res: bool
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, stride: int,
|
||||
relu_leak: float, is_b4_res: bool, key):
|
||||
self.stride = stride
|
||||
self.relu_leak = relu_leak
|
||||
self.is_b4_res = is_b4_res
|
||||
|
||||
keys = jax.random.split(key, 4)
|
||||
|
||||
# TODO: channels might be wrong for 2 and 3
|
||||
self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
self.bn3 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
|
||||
# TODO: /4 seems like it wouldn't work, maybe // or ceil?
|
||||
# TODO: what's up with the lack of padding?
|
||||
# TODO: does sride make sense for conv 2&3?
|
||||
mid_channels = out_channels / 4
|
||||
self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
|
||||
stride=stride, key=keys[0])
|
||||
self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
|
||||
stride=[1,1,1,1], key=keys[1])
|
||||
self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
|
||||
stride=[1,1,1,1], key=keys[2])
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.project = eqx.nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=1, stride=stride, key=keys[3])
|
||||
else:
|
||||
self.project = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Float[Array, "batch channels w h"]
|
||||
) -> Float[Array, "batch channels w h"]:
|
||||
if self.is_b4_res:
|
||||
x = self.bn1(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
orig_x = x
|
||||
else:
|
||||
orig_x = x
|
||||
x = self.bn1(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn2(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
x = self.conv2(x)
|
||||
x = self.bn3(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
x = self.conv3(x)
|
||||
|
||||
if self.project:
|
||||
orig_x = self.project(orig_x)
|
||||
|
||||
x += orig_x
|
||||
return x
|
||||
|
||||
|
||||
class ResNet(eqx.Module):
|
||||
|
|
Loading…
Reference in a new issue