#!/usr/bin/env python3
"""
This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/,
just using it to learn equinox
"""

import time

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from jax.extend import backend
from jaxtyping import Array, Float, Int, PyTree


class CNN(eqx.Module):
    """Small convolutional classifier mapping one 1x28x28 image to 10 log-probabilities.

    The model is a plain sequence of layers/callables applied in order. It
    operates on a single un-batched example; use ``jax.vmap`` to apply it to a
    batch.
    """

    layers: list

    def __init__(self, key: Array):
        """Build the layer stack.

        Args:
            key: JAX PRNG key; it is split into four sub-keys, one for each
                layer that has parameters.
        """
        k_conv, k_fc1, k_fc2, k_out = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=k_conv),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            # 1728 = 3 * 24 * 24: the flattened size the first Linear layer
            # expects after the conv + pool stages (pool stride presumably
            # defaults to 1 -- see the Equinox MaxPool2d docs).
            eqx.nn.Linear(1728, 512, key=k_fc1),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=k_fc2),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=k_out),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run ``x`` through every layer in order and return the final output."""
        for layer in self.layers:
            x = layer(x)
        return x


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Return the mean negative log-likelihood of the true labels.

    Args:
        y: Integer class label for each example.
        pred_y: Per-class log-probabilities for each example (the model ends
            in ``log_softmax``).

    Returns:
        Scalar: minus the mean of the log-probability assigned to each true label.
    """
    # Pick, for every row, the entry at the column given by its label.
    true_class_log_prob = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(true_class_log_prob)


@eqx.filter_jit
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the cross-entropy of ``model`` on the batch ``(x, y)``."""
    # The model handles one example at a time, so vmap it over the batch axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the fraction of examples in the batch whose argmax prediction equals ``y``."""
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


def evaluate(
    model: CNN, testloader: torch.utils.data.DataLoader
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """Return the ``(loss, accuracy)`` of ``model`` averaged over every sample in ``testloader``.

    Each batch is weighted by its number of samples, so a smaller final batch
    does not count as much as a full one and the result is the true
    per-sample mean over the whole loader.

    Raises:
        ZeroDivisionError: if ``testloader`` yields no samples.
    """
    total_loss = 0
    total_acc = 0
    num_samples = 0
    for x, y in testloader:
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        batch_size = x.shape[0]
        # loss / compute_accuracy return per-batch means; scale back to sums.
        total_loss += loss(model, x, y) * batch_size
        total_acc += compute_accuracy(model, x, y) * batch_size
        num_samples += batch_size
    return total_loss / num_samples, total_acc / num_samples


def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Train ``model`` for ``steps`` optimizer steps and return the updated model.

    Args:
        model: Model to train.
        trainloader: Source of training batches; it is cycled indefinitely.
        testloader: Loader used for the periodic evaluation.
        optim: Optax optimizer.
        steps: Number of gradient steps to take.
        print_every: Evaluate and print every this many steps (the last step
            is always reported).

    Returns:
        The trained model.
    """

    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        """Take one gradient step; return ``(model, opt_state, loss_value)``."""
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # The optimizer is given only the array leaves of the model.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def infinite_data(loader: torch.utils.data.DataLoader):
        """Yield batches from ``loader`` forever, restarting it whenever it is exhausted."""
        while True:
            yield from loader  # Yields from loader until exhausted

    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # zip() stops after `steps` items, which bounds the infinite generator.
    for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or step == steps - 1:
            avg_loss, avg_acc = evaluate(model, testloader)
            jax.debug.print("==== step {} ====", step)
            jax.debug.print("train loss = {}", train_loss)
            jax.debug.print("test loss = {}", avg_loss)
            jax.debug.print("test accuracy = {}", avg_acc)
    return model


# ╔─────────────────────────────────────────────────────────────────────────────╗
# │ Main script                                                                 │
# ╚─────────────────────────────────────────────────────────────────────────────╝

jax.config.update("jax_platform_name", "gpu")  # Sets preferred device

# Checking to make sure gpu is being used
print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")

# Hyperparameters
BATCH_SIZE = 1024
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678

key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)

# Data preprocessing
normalize_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalize_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalize_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

model = CNN(subkey)
optim = optax.adamw(LEARNING_RATE)

# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments during the run.
start = time.perf_counter()
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
cease = time.perf_counter()
print(f"Took {cease-start}s")