From 3fca3d08d598b0631c79adafed0fdd4ba059692f Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Mon, 11 Nov 2024 23:52:56 -0700 Subject: [PATCH] Add equinox tutorial code --- cnn_tutorial.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100755 cnn_tutorial.py diff --git a/cnn_tutorial.py b/cnn_tutorial.py new file mode 100755 index 0000000..e164147 --- /dev/null +++ b/cnn_tutorial.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/, +just using it to learn equinox +""" +import equinox as eqx +import jax.numpy as jnp +import jax +import optax +import time +import torch +import torchvision +from jaxtyping import Array, Float, Int, PyTree + + +class CNN(eqx.Module): + layers: list + + def __init__(self, key): + keys = jax.random.split(key, 4) + keys = list(keys) + + self.layers = [ + eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]), + eqx.nn.MaxPool2d(kernel_size=2), + jax.nn.relu, + jnp.ravel, + eqx.nn.Linear(1728, 512, key=keys[1]), + jax.nn.sigmoid, + eqx.nn.Linear(512, 64, key=keys[2]), + jax.nn.relu, + eqx.nn.Linear(64, 10, key=keys[3]), + jax.nn.log_softmax, + ] + + def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]: + for layer in self.layers: + x = layer(x) + return x + + +def cross_entropy( + y: Int[Array, " batch"], pred_y: Int[Array, " batch"] +) -> Float[Array, ""]: + pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) + return -jnp.mean(pred_y) + + +@eqx.filter_jit +def loss( + model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +) -> Float[Array, ""]: + pred_y = jax.vmap(model)(x) + return cross_entropy(y, pred_y) + + +@eqx.filter_jit +def compute_accuracy( + model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +) -> Float[Array, ""]: + pred_y = jax.vmap(model)(x) + pred_y = jnp.argmax(pred_y, axis=1) + return jnp.mean(y == pred_y) + + +def evaluate(model: CNN, testloader: torch.utils.data.DataLoader): + avg_loss = 0 + avg_acc = 0 + + for x, y in testloader: + x = jnp.array(x.numpy()) + y = jnp.array(y.numpy()) + + avg_loss += loss(model, x, y) + avg_acc += compute_accuracy(model, x, y) + + return avg_loss / len(testloader), avg_acc / len(testloader) + + +def train( + model: CNN, + trainloader: torch.utils.data.DataLoader, + testloader: torch.utils.data.DataLoader, + optim: optax.GradientTransformation, + steps: int, + print_every: int, +) -> CNN: + @eqx.filter_jit + def make_step( + model: CNN, + opt_state: PyTree, + x: Float[Array, "batch 1 28 28"], + y: Int[Array, "batch"], + ): + loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y) + updates, opt_state = optim.update( + grads, opt_state, eqx.filter(model, eqx.is_array) + ) + model = eqx.apply_updates(model, updates) + return model, opt_state, loss_value + + def infinite_data(loader: torch.utils.data.DataLoader): + while True: + yield from loader # Yields from loader until exhausted + + opt_state = optim.init(eqx.filter(model, eqx.is_array)) + + for step, (x, y) in zip(range(steps), infinite_data(trainloader)): + x = jnp.array(x.numpy()) + y = jnp.array(y.numpy()) + + model, opt_state, train_loss = make_step(model, opt_state, x, y) + + if (step % print_every) == 0 or step == steps - 1: + avg_loss, avg_acc = evaluate(model, testloader) + + jax.debug.print("==== step {} ====", step) + jax.debug.print("train loss = {}", train_loss) + jax.debug.print("test loss = {}", avg_loss) + jax.debug.print("text accuracy = {}", avg_acc) + + return model + + +# ╔─────────────────────────────────────────────────────────────────────────────╗ +# │ Mαiη scriρτ | +# ╚─────────────────────────────────────────────────────────────────────────────╝ +jax.config.update("jax_platform_name", "gpu") # Sets preferred device + +# Checking to make sure gpu is being used +from jax.extend import backend + +print(backend.get_backend().platform) +print(f"JAX devices: {jax.devices()}") +print(f"Default device: {jax.default_backend()}") + +# Hyperparameters +BATCH_SIZE = 1024 +LEARNING_RATE = 1e-4 +STEPS = 1200 +PRINT_EVERY = 300 +SEED = 5678 + +key = jax.random.PRNGKey(SEED) +key, subkey = jax.random.split(key, 2) + +# Data preprocessing +normalize_data = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (0.5,)), + ] +) + +train_dataset = torchvision.datasets.MNIST( + "MNIST", + train=True, + download=True, + transform=normalize_data, +) +test_dataset = torchvision.datasets.MNIST( + "MNIST", + train=False, + download=True, + transform=normalize_data, +) + +trainloader = torch.utils.data.DataLoader( + train_dataset, batch_size=BATCH_SIZE, shuffle=True +) +testloader = torch.utils.data.DataLoader( + test_dataset, batch_size=BATCH_SIZE, shuffle=True +) + +model = CNN(subkey) +optim = optax.adamw(LEARNING_RATE) + +start = time.time() +model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY) +cease = time.time() + +print(f"Took {cease-start}s")