Init

commit a9700e22bb

1 changed file with 206 additions and 0 deletions

src/main.py (new file)
@@ -0,0 +1,206 @@
# Code from:
# https://docs.kidger.site/equinox/examples/mnist/

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from functools import partial
from jaxtyping import Array, Float, Int, PyTree
from typing import Tuple
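
# Stack: Equinox modules for the model, Optax for the optimiser, and
# torch/torchvision only as a convenient MNIST loader.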

DATA_ROOT = "data"
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678


class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]
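        # Shape check: a 4x4 conv (no padding) takes 28x28 -> 25x25, and
        # Equinox's MaxPool2d defaults to stride=1 (unlike PyTorch), so
        # 25x25 -> 24x24, i.e. 3 * 24 * 24 = 1728 inputs to the first Linear.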

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x
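
    # Because __call__ ends in log_softmax, predictions are per-class
    # log-probabilities; cross-entropy is then just the negated mean of the
    # true class's log-probability (NLL).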
    @staticmethod
    def cross_entropy(
        y: Int[Array, " batch"],
        pred_y: Float[Array, "batch 10"],
    ) -> Float[Array, ""]:
        pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
        return -jnp.mean(pred_y)
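
    # loss takes the whole model; jax.vmap maps the single-example __call__
    # over the leading batch axis.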
    @staticmethod
    def loss(
        model,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)
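
    # loss2 is the jit-friendly variant: eqx.partition splits the model into
    # array leaves (params) and everything else (statics); statics is marked
    # static so jax.jit can handle the non-array parts of the PyTree.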
    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def loss2(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)
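
    # Accuracy: argmax over the 10 log-probabilities versus the labels.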
    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def get_accuracy(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        pred_y = jnp.argmax(pred_y, axis=1)
        return jnp.mean(y == pred_y)
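
    # One optimisation step: differentiate the loss w.r.t. the array leaves,
    # then apply the optax update and rebuild the model.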
    @staticmethod
    def make_step(
        model,
        opt_state: PyTree,
        optim: optax.GradientTransformation,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        params, statics = eqx.partition(model, eqx.is_array)

        # NB: rebuilding and re-jitting this closure on every call forces a
        # retrace each step; hoisting it out of make_step would avoid that.
        loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p, s), x, y)
        loss_with_grad = jax.value_and_grad(loss_with_grad)
        loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))

        loss, grad = loss_with_grad(params, statics, x, y)
        updates, opt_state = optim.update(
            grad, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss
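
    # Training loop: cycle the DataLoader indefinitely, convert each torch
    # batch to numpy arrays (which JAX accepts directly), and periodically
    # evaluate on the held-out test set.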
    @staticmethod
    def train(
        model,
        train_dl: torch.utils.data.DataLoader,
        test_dl: torch.utils.data.DataLoader,
        optim: optax.GradientTransformation,
        steps: int,
        log_every: int,
    ):
        opt_state = optim.init(eqx.filter(model, eqx.is_array))

        def infinite_dl(dl):
            while True:
                yield from dl

        for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
            x = x.numpy()
            y = y.numpy()
            model, opt_state, loss = CNN.make_step(model, opt_state, optim, x, y)

            if (step % log_every == 0) or (step == steps - 1):
                test_loss, test_acc = evaluate_model(model, test_dl)
                print(
                    f"{step=}, train_loss={loss.item()}, "
                    f"test_loss={test_loss.item()}, test_acc={test_acc.item()}"
                )

        return model
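

# Averages the jitted loss/accuracy over all test batches; the final batch
# may be smaller, so this is a batch-weighted approximation.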
def evaluate_model(
    model, test_dl: torch.utils.data.DataLoader
) -> Tuple[Float[Array, ""], Float[Array, ""]]:
    avg_loss = 0
    avg_acc = 0

    # The model does not change during evaluation, so partition it once.
    params, statics = eqx.partition(model, eqx.is_array)
    for x, y in test_dl:
        x = x.numpy()
        y = y.numpy()
        avg_loss += CNN.loss2(params, statics, x, y)
        avg_acc += CNN.get_accuracy(params, statics, x, y)

    avg_loss /= len(test_dl)
    avg_acc /= len(test_dl)
    return avg_loss, avg_acc
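

# Standard torchvision MNIST pipeline; Normalize((0.5,), (0.5,)) maps
# pixel values from [0, 1] to [-1, 1].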
def load_data():
    normalise_data = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=True,
        download=True,
        transform=normalise_data,
    )
    test_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=False,
        download=True,
        transform=normalise_data,
    )

    train_dl = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True
    )
    test_dl = torch.utils.data.DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=True
    )

    return train_dl, test_dl
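

# Smoke-test the forward pass and loss on one batch, then train.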
if __name__ == '__main__':
    train_dl, test_dl = load_data()

    key = jax.random.PRNGKey(SEED)
    key, key1 = jax.random.split(key, 2)
    model = CNN(key1)
    optim = optax.adamw(LEARNING_RATE)

    dummy_x, dummy_y = next(iter(train_dl))
    dummy_x = dummy_x.numpy()  # 64x1x28x28
    dummy_y = dummy_y.numpy()  # 64

    print(jax.vmap(model)(dummy_x).shape)
    print(CNN.loss(model, dummy_x, dummy_y))

    model = CNN.train(model, train_dl, test_dl, optim, STEPS, PRINT_EVERY)

    # params, statics = eqx.partition(model, eqx.is_array)
    # loss2 = lambda p, s, x, y: CNN.loss(eqx.combine(p, s), x, y)
    # print(evaluate_model(model, test_dl))
    # loss_w_grad = jax.value_and_grad(loss2)
    # l, g = loss_w_grad(params, statics, dummy_x, dummy_y)
    # print(type(l))
    # print(type(g))