Resent: add bottleneck block
This commit is contained in:
parent
92b0024a27
commit
456eb30050
1 changed files with 71 additions and 0 deletions
|
@ -146,6 +146,7 @@ class ResidualBlock(eqx.Module):
|
||||||
eps=0.001)
|
eps=0.001)
|
||||||
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||||
eps=0.001)
|
eps=0.001)
|
||||||
|
# TODO: is bn2 in_channels correct?
|
||||||
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
||||||
stride=stride, key=keys[0])
|
stride=stride, key=keys[0])
|
||||||
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
||||||
|
@ -182,6 +183,76 @@ class ResidualBlock(eqx.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(eqx.Module):
|
||||||
|
bn1: eqx.nn.BatchNorm
|
||||||
|
bn2: eqx.nn.BatchNorm
|
||||||
|
bn3: eqx.nn.BatchNorm
|
||||||
|
conv1: eqx.nn.Conv2d
|
||||||
|
conv2: eqx.nn.Conv2d
|
||||||
|
conv3: eqx.nn.Conv2d
|
||||||
|
project_conv: eqx.nn.Conv2d
|
||||||
|
relu_leak: float
|
||||||
|
is_b4_res: bool
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, stride: int,
|
||||||
|
relu_leak: float, is_b4_res: bool, key):
|
||||||
|
self.stride = stride
|
||||||
|
self.relu_leak = relu_leak
|
||||||
|
self.is_b4_res = is_b4_res
|
||||||
|
|
||||||
|
keys = jax.random.split(key, 4)
|
||||||
|
|
||||||
|
# TODO: channels might be wrong for 2 and 3
|
||||||
|
self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||||
|
eps=0.001)
|
||||||
|
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||||
|
eps=0.001)
|
||||||
|
self.bn3 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||||
|
eps=0.001)
|
||||||
|
|
||||||
|
# TODO: /4 seems like it wouldn't work, maybe // or ceil?
|
||||||
|
# TODO: what's up with the lack of padding?
|
||||||
|
# TODO: does sride make sense for conv 2&3?
|
||||||
|
mid_channels = out_channels / 4
|
||||||
|
self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
|
||||||
|
stride=stride, key=keys[0])
|
||||||
|
self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
|
||||||
|
stride=[1,1,1,1], key=keys[1])
|
||||||
|
self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
|
||||||
|
stride=[1,1,1,1], key=keys[2])
|
||||||
|
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.project = eqx.nn.Conv2d(in_channels, out_channels,
|
||||||
|
kernel_size=1, stride=stride, key=keys[3])
|
||||||
|
else:
|
||||||
|
self.project = None
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: Float[Array, "batch channels w h"]
|
||||||
|
) -> Float[Array, "batch channels w h"]:
|
||||||
|
if self.is_b4_res:
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||||
|
orig_x = x
|
||||||
|
else:
|
||||||
|
orig_x = x
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.bn3(x)
|
||||||
|
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||||
|
x = self.conv3(x)
|
||||||
|
|
||||||
|
if self.project:
|
||||||
|
orig_x = self.project(orig_x)
|
||||||
|
|
||||||
|
x += orig_x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResNet(eqx.Module):
|
class ResNet(eqx.Module):
|
||||||
|
|
Loading…
Reference in a new issue