Main: add layernorms

This commit is contained in:
Akemi Izuko 2024-12-16 23:20:24 -07:00
parent 59cf5aea85
commit 69eefb9264
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -24,15 +24,18 @@ class CNN(eqx.Module):
layers: list layers: list
def __init__(self, key): def __init__(self, key):
key1, key2, key3, key4 = jax.random.split(key, 4) key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)
self.layers = [ self.layers = [
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1), eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
eqx.nn.MaxPool2d(kernel_size=2), eqx.nn.MaxPool2d(kernel_size=2),
jax.nn.relu, jax.nn.relu,
jnp.ravel, jnp.ravel,
eqx.nn.GroupNorm(32, 1728),
eqx.nn.Linear(1728, 512, key=key2), eqx.nn.Linear(1728, 512, key=key2),
eqx.nn.GroupNorm(16, 512),
jax.nn.sigmoid, jax.nn.sigmoid,
eqx.nn.Linear(512, 64, key=key3), eqx.nn.Linear(512, 64, key=key3),
eqx.nn.GroupNorm(8, 64),
jax.nn.relu, jax.nn.relu,
eqx.nn.Linear(64, 10, key=key4), eqx.nn.Linear(64, 10, key=key4),
jax.nn.log_softmax, jax.nn.log_softmax,