From 92b0024a27c154bf595e04924f8876a0464a7a13 Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Wed, 13 Nov 2024 12:40:34 -0700
Subject: [PATCH] Resnet: add residual block

---
 resnet_cifar10.py | 79 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 75 insertions(+), 4 deletions(-)

diff --git a/resnet_cifar10.py b/resnet_cifar10.py
index e1f217d..3d2a556 100755
--- a/resnet_cifar10.py
+++ b/resnet_cifar10.py
@@ -127,12 +127,83 @@ class HParams():
 
 class ResidualBlock(eqx.Module):
     bn1: eqx.nn.BatchNorm
+    bn2: eqx.nn.BatchNorm
+    conv1: eqx.nn.Conv2d
+    conv2: eqx.nn.Conv2d
+    # Set to None in __init__ when no shortcut pooling is needed.
+    avg_pool: eqx.nn.AvgPool2d
+    stride: int
+    relu_leak: float
+    # Must be declared: eqx.Module rejects assignment to undeclared fields.
+    is_b4_res: bool
+
+    def __init__(self, in_channels: int, out_channels: int, stride: int,
+            relu_leak: float, is_b4_res: bool, key):
+        """Create the block's layers.
+
+        Args:
+            in_channels: Channel count of the block input.
+            out_channels: Channel count produced by both convolutions.
+            stride: Stride of conv1 and of the shortcut average pool.
+            relu_leak: Negative slope passed to jax.nn.leaky_relu.
+            is_b4_res: If true, the shortcut is taken after the first
+                BatchNorm + activation instead of before it.
+            key: PRNG key, split between the two convolutions.
+        """
+        self.stride = stride
+        self.relu_leak = relu_leak
+        self.is_b4_res = is_b4_res
 
-    def __init__(self, in_channels: int, out_channels: int, stride: int, key):
         keys = jax.random.split(key, 2)
 
         self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                 eps=0.001)
+        # bn2 runs on conv1's output, which has out_channels channels.
+        self.bn2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9,
+                eps=0.001)
+        # padding=1 stops the 3x3 kernels from shrinking the feature map, so
+        # the conv output lines up with the shortcut for the residual add.
+        self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
+                stride=stride, padding=1, key=keys[0])
+        self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
+                padding=1, key=keys[1])
+
+        if stride != 1 or in_channels != out_channels:
+            self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
+            # TODO: padding might be wrong...
+            # NOTE(review): pooling keeps the channel count, so a block with
+            # in_channels != out_channels presumably still needs channel
+            # padding or a projection on the shortcut -- confirm.
+        else:
+            self.avg_pool = None
+
+    def __call__(
+            self,
+            x: Float[Array, "batch channels w h"]
+    ) -> Float[Array, "batch channels w h"]:
+        """Run BatchNorm -> leaky ReLU -> conv twice, then add the shortcut."""
+        # NOTE(review): recent Equinox releases have eqx.nn.BatchNorm take and
+        # return a `state`; presumably these calls need it -- confirm.
+        if self.is_b4_res:
+            x = self.bn1(x)
+            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
+            orig_x = x
+        else:
+            orig_x = x
+            x = self.bn1(x)
+            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
+
+        x = self.conv1(x)
+        x = self.bn2(x)
+        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
+        x = self.conv2(x)
+
+        if self.avg_pool is not None:
+            orig_x = self.avg_pool(orig_x)
+
+        x += orig_x
+        return x
+
 
 
 
@@ -146,7 +217,7 @@ class ResNet(eqx.Module):
 
     hps: HParams
 
-    self __init__(self, hps: HParams):
+    def __init__(self, hps: HParams):
         self.hps = hps
 
         keys = jax.random.split(key, 5)
@@ -236,9 +307,9 @@ key, subkey = jax.random.split(key, 2)
 
 train_loader = build_dataloader(False)
 test_loader = build_dataloader(False)
-print(dataloader)
+print(train_loader)
 
-x = next(iter(dataloader))
+x = next(iter(train_loader))
 print(type(x), len(x))
 print(type(x[0]), type(x[1]))
 print(x[0].shape, x[1].shape)