diff --git a/resnet_cifar10.py b/resnet_cifar10.py
new file mode 100755
index 0000000..821c0e0
--- /dev/null
+++ b/resnet_cifar10.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+"""
+This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/;
+I'm just using it to learn Equinox.
+"""
+import equinox as eqx
+import jax.numpy as jnp
+import jax
+import optax
+import time
+import torch
+import torchvision
+from jaxtyping import Array, Float, Int, PyTree
+
+
+class CNN(eqx.Module):
+    layers: list
+
+    def __init__(self, key):
+        keys = jax.random.split(key, 4)
+        keys = list(keys)
+
+        self.layers = [
+            eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
+            eqx.nn.MaxPool2d(kernel_size=2),
+            jax.nn.relu,
+            jnp.ravel,
+            eqx.nn.Linear(1728, 512, key=keys[1]),
+            jax.nn.sigmoid,
+            eqx.nn.Linear(512, 64, key=keys[2]),
+            jax.nn.relu,
+            eqx.nn.Linear(64, 10, key=keys[3]),
+            jax.nn.log_softmax,
+        ]
+
+    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+def cross_entropy(
+    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
+) -> Float[Array, ""]:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
+    return -jnp.mean(pred_y)
+
+
+@eqx.filter_jit
+def loss(
+    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
+) -> Float[Array, ""]:
+    pred_y = jax.vmap(model)(x)
+    return cross_entropy(y, pred_y)
+
+
+@eqx.filter_jit
+def compute_accuracy(
+    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
+) -> Float[Array, ""]:
+    pred_y = jax.vmap(model)(x)
+    pred_y = jnp.argmax(pred_y, axis=1)
+    return jnp.mean(y == pred_y)
+
+
+def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
+    avg_loss = 0
+    avg_acc = 0
+
+    for x, y in testloader:
+        x = jnp.array(x.numpy())
+        y = jnp.array(y.numpy())
+
+        avg_loss += loss(model, x, y)
+        avg_acc += compute_accuracy(model, x, y)
+
+    return avg_loss / len(testloader), avg_acc / len(testloader)
+
+
+def train(
+    model: CNN,
+    trainloader: torch.utils.data.DataLoader,
+    testloader: torch.utils.data.DataLoader,
+    optim: optax.GradientTransformation,
+    steps: int,
+    print_every: int,
+) -> CNN:
+    @eqx.filter_jit
+    def make_step(
+        model: CNN,
+        opt_state: PyTree,
+        x: Float[Array, "batch 1 28 28"],
+        y: Int[Array, " batch"],
+    ):
+        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
+        updates, opt_state = optim.update(
+            grads, opt_state, eqx.filter(model, eqx.is_array)
+        )
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state, loss_value
+
+    def infinite_data(loader: torch.utils.data.DataLoader):
+        while True:
+            yield from loader  # Yields from loader until exhausted, then restarts it
+
+    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+    for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
+        x = jnp.array(x.numpy())
+        y = jnp.array(y.numpy())
+
+        model, opt_state, train_loss = make_step(model, opt_state, x, y)
+
+        if (step % print_every) == 0 or step == steps - 1:
+            avg_loss, avg_acc = evaluate(model, testloader)
+
+            jax.debug.print("==== step {} ====", step)
+            jax.debug.print("train loss = {}", train_loss)
+            jax.debug.print("test loss = {}", avg_loss)
+            jax.debug.print("test accuracy = {}", avg_acc)
+
+    return model
+
+
+def build_dataloader(is_train):
+    global BATCH_SIZE
+
+    transform_train = torchvision.transforms.Compose([
+        torchvision.transforms.RandomCrop(32, padding=4),
+        torchvision.transforms.RandomHorizontalFlip(),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                         (0.247, 0.243, 0.261))
+    ])
+    transform_test = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
+                                         (0.247, 0.243, 0.261))
+    ])
+
+    dataset = torchvision.datasets.CIFAR10(
+        "data",
+        train=is_train,
+        download=True,
+        transform=(transform_train if is_train else transform_test)
+    )
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=BATCH_SIZE, shuffle=True
+    )
+
+    class DataLoaderWrapper:
+        def __init__(self, dataloader, nb_classes):
+            self.dataloader = dataloader
+            self.nb_classes = nb_classes
+
+        def __iter__(self):
+            for images, labels in self.dataloader:
+                images = jnp.array(images)
+
+                labels = jnp.array(labels)
+                labels = jax.nn.one_hot(labels, 10)
+
+                yield (images, labels)
+
+    return DataLoaderWrapper(dataloader, 10)
+
+
+
+# ┌─────────────┐
+# │ Main script │
+# └─────────────┘
+jax.config.update("jax_platform_name", "gpu")  # Sets the preferred device
+
+# Check that the GPU backend is actually being used
+from jax.extend import backend
+
+print(backend.get_backend().platform)
+print(f"JAX devices: {jax.devices()}")
+print(f"Default device: {jax.default_backend()}")
+
+# Hyperparameters
+BATCH_SIZE = 16
+LEARNING_RATE = 1e-4
+STEPS = 1200
+PRINT_EVERY = 300
+SEED = 5678
+
+key = jax.random.PRNGKey(SEED)
+key, subkey = jax.random.split(key, 2)
+
+dataloader = build_dataloader(False)
+
+print(dataloader)
+
+x = next(iter(dataloader))
+print(type(x), len(x))
+print(type(x[0]), type(x[1]))
+print(x[0].shape, x[1].shape)
+# x[1] = jnp.array(x[1])
+# print(f"Max: {jnp.max(x[1])}")
+# print(f"Min: {jnp.min(x[1])}")
+# print(f"Mean: {jnp.mean(x[1])}")
+print(f"First: {x[0][0, 0]}")
+# print(f"1hot: {jax.nn.one_hot(x[1][0], 10)}")
+
+exit(1)
+
+# model = CNN(subkey)
+# optim = optax.adamw(LEARNING_RATE)
+#
+# start = time.time()
+# model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
+# cease = time.time()
+
+# print(f"Took {cease - start}s")  # start/cease are only set by the commented-out block above
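Note on re-enabling the commented-out training loop: the CNN above is still the MNIST version (it declares "1 28 28" inputs and a 1-channel Conv2d), while build_dataloader serves CIFAR-10 batches of shape (batch, 3, 32, 32), already converted to jnp arrays with one-hot labels, whereas cross_entropy indexes with integer class labels and train/evaluate call .numpy() on the batches. Below is a minimal sketch of the shape side of that adaptation only, assuming the tutorial architecture is kept and just the input shape changes; the class name CIFAR10CNN and the 2352 feature count are illustrative and not part of this diff (32 - 4 + 1 = 29 after the conv, 28 x 28 after the pool, times 3 channels, using the same pooling arithmetic that gives 1728 in the MNIST version).

# Hypothetical CIFAR-10-shaped variant of the tutorial CNN, shown only to
# illustrate how the layer sizes would change for 3x32x32 inputs.
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


class CIFAR10CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 4)
        self.layers = [
            # 3 input channels instead of 1; a 4x4 conv takes 32x32 -> 29x29
            eqx.nn.Conv2d(3, 3, kernel_size=4, key=keys[0]),
            # same pooling as in the diff (the 1728 figure there implies stride 1),
            # so 29x29 -> 28x28
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            # 3 * 28 * 28 = 2352 flattened features (vs. 1728 for MNIST)
            eqx.nn.Linear(2352, 512, key=keys[1]),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=keys[2]),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=keys[3]),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "3 32 32"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

The label side (one-hot vectors vs. integer indices, jnp arrays vs. torch tensors in train/evaluate) still needs a decision and is deliberately not sketched here.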