Main: add layernorms
This commit is contained in:
parent
59cf5aea85
commit
69eefb9264
1 changed files with 4 additions and 1 deletions
|
@ -24,15 +24,18 @@ class CNN(eqx.Module):
|
||||||
layers: list
|
layers: list
|
||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
key1, key2, key3, key4 = jax.random.split(key, 4)
|
key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)
|
||||||
self.layers = [
|
self.layers = [
|
||||||
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
|
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
|
||||||
eqx.nn.MaxPool2d(kernel_size=2),
|
eqx.nn.MaxPool2d(kernel_size=2),
|
||||||
jax.nn.relu,
|
jax.nn.relu,
|
||||||
jnp.ravel,
|
jnp.ravel,
|
||||||
|
eqx.nn.GroupNorm(32, 1728),
|
||||||
eqx.nn.Linear(1728, 512, key=key2),
|
eqx.nn.Linear(1728, 512, key=key2),
|
||||||
|
eqx.nn.GroupNorm(16, 512),
|
||||||
jax.nn.sigmoid,
|
jax.nn.sigmoid,
|
||||||
eqx.nn.Linear(512, 64, key=key3),
|
eqx.nn.Linear(512, 64, key=key3),
|
||||||
|
eqx.nn.GroupNorm(8, 64),
|
||||||
jax.nn.relu,
|
jax.nn.relu,
|
||||||
eqx.nn.Linear(64, 10, key=key4),
|
eqx.nn.Linear(64, 10, key=key4),
|
||||||
jax.nn.log_softmax,
|
jax.nn.log_softmax,
|
||||||
|
|
Loading…
Reference in a new issue