From 69eefb9264071e75376a57f2e69adccd993872ef Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Mon, 16 Dec 2024 23:20:24 -0700 Subject: [PATCH] Main: add layernorms --- src/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index 7526c74..28c32c1 100644 --- a/src/main.py +++ b/src/main.py @@ -24,15 +24,18 @@ class CNN(eqx.Module): layers: list def __init__(self, key): - key1, key2, key3, key4 = jax.random.split(key, 4) + key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6) self.layers = [ eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1), eqx.nn.MaxPool2d(kernel_size=2), jax.nn.relu, jnp.ravel, + eqx.nn.GroupNorm(32, 1728), eqx.nn.Linear(1728, 512, key=key2), + eqx.nn.GroupNorm(16, 512), jax.nn.sigmoid, eqx.nn.Linear(512, 64, key=key3), + eqx.nn.GroupNorm(8, 64), jax.nn.relu, eqx.nn.Linear(64, 10, key=key4), jax.nn.log_softmax,