Resnet: add residual block
This commit is contained in:
parent
4dd4868c55
commit
92b0024a27
1 changed files with 52 additions and 4 deletions
|
@ -127,12 +127,60 @@ class HParams():
|
|||
|
||||
class ResidualBlock(eqx.Module):
|
||||
bn1: eqx.nn.BatchNorm
|
||||
bn2: eqx.nn.BatchNorm
|
||||
conv1: eqx.nn.Conv2d
|
||||
conv2: eqx.nn.Conv2d
|
||||
avg_pool: eqx.nn.AvgPool2d
|
||||
stride: int
|
||||
relu_leak: float
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, stride: int,
|
||||
relu_leak: float, is_b4_res: bool, key):
|
||||
self.stride = stride
|
||||
self.relu_leak = relu_leak
|
||||
self.is_b4_res = is_b4_res
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, stride: int, key):
|
||||
keys = jax.random.split(key, 2)
|
||||
|
||||
self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
self.bn2 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
|
||||
eps=0.001)
|
||||
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
||||
stride=stride, key=keys[0])
|
||||
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
||||
key=keys[1])
|
||||
|
||||
if stride != 1 or in_channels != out_channels:
|
||||
self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
|
||||
# TODO: padding might be wrong...
|
||||
else:
|
||||
self.avg_pool = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Float[Array, "batch channels w h"]
|
||||
) -> Float[Array, "batch channels w h"]:
|
||||
if self.is_b4_res:
|
||||
x = self.bn1(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
orig_x = x
|
||||
else:
|
||||
orig_x = x
|
||||
x = self.bn1(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn2(x)
|
||||
x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
|
||||
x = self.conv2(x)
|
||||
|
||||
if self.avg_pool is not None:
|
||||
orig_x = self.avg_pool(orig_x)
|
||||
|
||||
x += orig_x
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -146,7 +194,7 @@ class ResNet(eqx.Module):
|
|||
hps: HParams
|
||||
|
||||
|
||||
self __init__(self, hps: HParams):
|
||||
def __init__(self, hps: HParams):
|
||||
self.hps = hps
|
||||
keys = jax.random.split(key, 5)
|
||||
|
||||
|
@ -236,9 +284,9 @@ key, subkey = jax.random.split(key, 2)
|
|||
train_loader = build_dataloader(False)
|
||||
test_loader = build_dataloader(False)
|
||||
|
||||
print(dataloader)
|
||||
print(train_loader)
|
||||
|
||||
x = next(iter(dataloader))
|
||||
x = next(iter(train_loader))
|
||||
print(type(x), len(x))
|
||||
print(type(x[0]), type(x[1]))
|
||||
print(x[0].shape, x[1].shape)
|
||||
|
|
Loading…
Reference in a new issue