# Recovered from `git show` output:
#   commit a9700e22bb9ebd3f7a841ce59b45ae67b79af311
#   Author: Akemi Izuko
#   Date:   Mon Dec 16 20:19:37 2024 -0700
#   Message: Init
#   File:   src/main.py (new file, blob 159fb2a)
#
# Code from:
# https://docs.kidger.site/equinox/examples/mnist/
from functools import partial
from typing import Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from jaxtyping import Array, Float, Int, PyTree


DATA_ROOT = "data"
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678


class CNN(eqx.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    The network emits log-probabilities over 10 classes (the final layer is
    ``jax.nn.log_softmax``), which is what ``cross_entropy`` expects.
    """

    # Layers applied in order by ``__call__``; a mix of Equinox modules and
    # plain JAX functions.
    layers: list

    def __init__(self, key):
        """Build the layer stack, using one PRNG sub-key per parametrised layer."""
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            # 1728 must equal the flattened size of the pooled feature map
            # (presumably 3 * 24 * 24 with the default stride-1 pooling).
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run a single (unbatched) image through every layer in order."""
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def cross_entropy(
        y: Int[Array, " batch"],
        pred_y: Float[Array, "batch 10"],
    ) -> Float[Array, ""]:
        """Return the mean negative log-probability assigned to the true labels.

        ``pred_y`` must already hold log-probabilities; ``y`` holds the integer
        class index of each sample.
        """
        # Pick, for every row, the log-probability at that row's label index.
        pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
        return -jnp.mean(pred_y)

    @staticmethod
    def loss(
        model,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        """Return the batch cross-entropy of ``model`` on ``(x, y)`` (not jitted)."""
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)

    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def loss2(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        """Jitted batch cross-entropy taking the model as a partitioned pair.

        ``params`` and ``statics`` are the two halves returned by
        ``eqx.partition(model, eqx.is_array)``; ``statics`` is treated as a
        static (hashable) argument by ``jax.jit``.
        """
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)

    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def get_accuracy(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        """Return the fraction of samples whose arg-max prediction equals ``y``."""
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        pred_y = jnp.argmax(pred_y, axis=1)
        return jnp.mean(y == pred_y)

    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def _loss_and_grad(params, statics, x, y):
        """Return ``(loss, d loss / d params)`` for one batch.

        Defined once at class level so the jitted function keeps a stable
        identity: wrapping a fresh lambda in ``jax.jit`` on every step would
        force JAX to re-trace and recompile each time.
        """

        def loss_fn(p):
            return CNN.loss(eqx.combine(p, statics), x, y)

        return jax.value_and_grad(loss_fn)(params)

    @staticmethod
    def make_step(
        model,
        opt_state: PyTree,
        optim: optax.GradientTransformation,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        """Apply one optimiser update computed from the batch ``(x, y)``.

        Returns:
            ``(model, opt_state, loss)`` - the updated model, the updated
            optimiser state and the loss measured *before* the update.
        """
        params, statics = eqx.partition(model, eqx.is_array)
        loss, grad = CNN._loss_and_grad(params, statics, x, y)
        # ``params`` is passed because optimisers such as AdamW need the
        # current parameter values (for weight decay).
        updates, opt_state = optim.update(grad, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    @staticmethod
    def train(
        model,
        train_dl: torch.utils.data.DataLoader,
        test_dl: torch.utils.data.DataLoader,
        optim: optax.GradientTransformation,
        steps: int,
        log_every: int,
    ):
        """Train ``model`` for ``steps`` batches and return the trained model.

        Every ``log_every`` steps, and on the final step, the model is
        evaluated on ``test_dl`` and the metrics are printed.
        """
        opt_state = optim.init(eqx.filter(model, eqx.is_array))

        def infinite_dl(dl):
            # Restart the loader whenever it is exhausted so that ``steps``
            # may exceed the number of batches in one epoch.
            while True:
                yield from dl

        # ``range`` comes first so ``zip`` stops before pulling a spare batch.
        for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
            x = x.numpy()
            y = y.numpy()
            model, opt_state, loss = CNN.make_step(model, opt_state, optim, x, y)

            if (step % log_every == 0) or (step == steps - 1):
                # Evaluate on the held-out loader, not the training one.
                test_loss, test_acc = evaluate_model(model, test_dl)
                print(
                    f"{step=}, train_loss={loss.item()}, "
                    f"test_loss={test_loss.item()}, test_acc={test_acc.item()}"
                )

        return model


def evaluate_model(
    model, test_dl: torch.utils.data.DataLoader
) -> Tuple[Float[Array, ""], Float[Array, ""]]:
    """Return ``(mean loss, mean accuracy)`` of ``model`` over all of ``test_dl``.

    Per-batch results are weighted by batch size, so a short final batch does
    not skew the averages.

    Raises:
        ValueError: if ``test_dl`` yields no samples.
    """
    # The model does not change during evaluation, so partition it once.
    params, statics = eqx.partition(model, eqx.is_array)

    total_loss = 0.0
    total_acc = 0.0
    n_samples = 0

    for x, y in test_dl:
        x = x.numpy()
        y = y.numpy()
        batch_len = y.shape[0]
        total_loss += CNN.loss2(params, statics, x, y) * batch_len
        total_acc += CNN.get_accuracy(params, statics, x, y) * batch_len
        n_samples += batch_len

    if n_samples == 0:
        raise ValueError("evaluate_model() requires a non-empty data loader")

    return total_loss / n_samples, total_acc / n_samples


def load_data():
    """Download MNIST if needed and return ``(train_dl, test_dl)`` loaders.

    Images are converted to tensors and normalised with mean 0.5 and
    standard deviation 0.5; both loaders use ``BATCH_SIZE``.
    """
    normalise_data = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=True,
        download=True,
        transform=normalise_data,
    )
    test_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=False,
        download=True,
        transform=normalise_data,
    )

    train_dl = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True
    )
    # Evaluation visits every sample, so shuffling would only add noise.
    test_dl = torch.utils.data.DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False
    )

    return train_dl, test_dl


def main():
    """Load MNIST, sanity-check an untrained model on one batch, then train."""
    train_dl, test_dl = load_data()

    key = jax.random.PRNGKey(SEED)
    _, model_key = jax.random.split(key, 2)
    model = CNN(model_key)
    optim = optax.adamw(LEARNING_RATE)

    dummy_x, dummy_y = next(iter(train_dl))
    dummy_x = dummy_x.numpy()  # 64x1x28x28
    dummy_y = dummy_y.numpy()  # 64

    # Smoke test: output shape and loss of the untrained model.
    print(jax.vmap(model)(dummy_x).shape)
    print(CNN.loss(model, dummy_x, dummy_y))

    CNN.train(model, train_dl, test_dl, optim, STEPS, PRINT_EVERY)


if __name__ == '__main__':
    main()