ResNet: add residual block

This commit is contained in:
Akemi Izuko 2024-11-13 12:40:34 -07:00
parent 4dd4868c55
commit 92b0024a27
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -127,12 +127,60 @@ class HParams():
class ResidualBlock(eqx.Module):
    """Pre-activation residual unit with a parameter-free shortcut.

    The residual branch is ``BN -> leaky ReLU -> 3x3 conv`` applied twice.
    The shortcut is the identity when ``stride == 1`` and the channel count
    is unchanged; otherwise it is average-pooled with ``stride`` and
    zero-padded along the channel axis up to ``out_channels``.

    Both batch-norm layers use ``axis_name="batch"``, so the block has to be
    called under ``jax.vmap(..., axis_name="batch")`` on single examples.

    NOTE(review): ``eqx.nn.BatchNorm`` is a stateful layer whose call takes
    ``(x, state)`` and returns ``(out, state)``. ``__call__`` below still
    passes only ``x``; threading the state through changes this block's call
    signature and must be done together with the caller, so it is left
    untouched here.
    """

    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    # None when the shortcut is the identity.
    avg_pool: eqx.nn.AvgPool2d
    stride: int
    relu_leak: float
    # Declared explicitly: eqx.Module is a frozen dataclass and refuses
    # assignment to attributes that are not declared fields.
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        """Build the two BN/conv pairs and, if needed, the pooling shortcut.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the block.
            stride: Stride of the first convolution and of the shortcut pool.
            relu_leak: Negative slope passed to ``jax.nn.leaky_relu``.
            is_b4_res: If true, the shortcut branches off after the first
                BN + activation instead of from the raw input.
            key: PRNG key, split once per convolution.

        Raises:
            ValueError: If ``out_channels < in_channels``; the zero-padding
                shortcut can widen the channel axis but not narrow it.
        """
        if out_channels < in_channels:
            raise ValueError(
                "ResidualBlock cannot reduce channels: "
                f"in_channels={in_channels}, out_channels={out_channels}"
            )

        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 2)

        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        # bn2 normalises the output of conv1, which has out_channels features.
        self.bn2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9,
                                    eps=0.001)

        # padding=1 keeps a 3x3 convolution size-preserving at stride 1, so
        # the residual branch lines up with the (pooled) shortcut.
        self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=stride, padding=1, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                   padding=1, key=keys[1])

        if stride != 1 or in_channels != out_channels:
            self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
        else:
            self.avg_pool = None

    def __call__(
        self,
        x: Float[Array, "in_channels h w"]
    ) -> Float[Array, "out_channels h_out w_out"]:
        """Apply the residual unit to one unbatched feature map.

        Equinox convolution and pooling layers take a single example with
        the channel axis first, so ``x`` carries no batch axis here.
        """
        if self.is_b4_res:
            # Shortcut sees the normalised, activated input.
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            # Shortcut sees the raw input.
            orig_x = x
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)

        x = self.conv1(x)

        x = self.bn2(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)

        if self.avg_pool is not None:
            # Spatial sizes agree when h and w are divisible by the stride.
            orig_x = self.avg_pool(orig_x)
            # Widen the shortcut with zero channels, split evenly on both
            # sides, so it can be added to the residual branch.
            extra = x.shape[-3] - orig_x.shape[-3]
            if extra > 0:
                before = extra // 2
                pad_width = [(0, 0)] * orig_x.ndim
                pad_width[-3] = (before, extra - before)
                orig_x = jax.numpy.pad(orig_x, pad_width)

        return x + orig_x
@ -146,7 +194,7 @@ class ResNet(eqx.Module):
hps: HParams hps: HParams
self __init__(self, hps: HParams): def __init__(self, hps: HParams):
self.hps = hps self.hps = hps
keys = jax.random.split(key, 5) keys = jax.random.split(key, 5)
@ -236,9 +284,9 @@ key, subkey = jax.random.split(key, 2)
# NOTE(review): both loaders are built with the same argument (False);
# confirm against build_dataloader whether the train loader should differ.
train_loader = build_dataloader(False)
test_loader = build_dataloader(False)

# Sanity-print one batch from the training loader: container type and
# length, then the type and shape of its first two elements.
print(train_loader)
x = next(iter(train_loader))
print(type(x), len(x))
print(type(x[0]), type(x[1]))
print(x[0].shape, x[1].shape)