Add equinox tutorial code

This commit is contained in:
Akemi Izuko 2024-11-11 23:52:56 -07:00
commit 3fca3d08d5
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

182
cnn_tutorial.py Executable file
View file

@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/,
just using it to learn equinox
"""
import equinox as eqx
import jax.numpy as jnp
import jax
import optax
import time
import torch
import torchvision
from jaxtyping import Array, Float, Int, PyTree
class CNN(eqx.Module):
layers: list
def __init__(self, key):
keys = jax.random.split(key, 4)
keys = list(keys)
self.layers = [
eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
eqx.nn.MaxPool2d(kernel_size=2),
jax.nn.relu,
jnp.ravel,
eqx.nn.Linear(1728, 512, key=keys[1]),
jax.nn.sigmoid,
eqx.nn.Linear(512, 64, key=keys[2]),
jax.nn.relu,
eqx.nn.Linear(64, 10, key=keys[3]),
jax.nn.log_softmax,
]
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
for layer in self.layers:
x = layer(x)
return x
def cross_entropy(
y: Int[Array, " batch"], pred_y: Int[Array, " batch"]
) -> Float[Array, ""]:
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
return -jnp.mean(pred_y)
@eqx.filter_jit
def loss(
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
pred_y = jax.vmap(model)(x)
return cross_entropy(y, pred_y)
@eqx.filter_jit
def compute_accuracy(
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
pred_y = jax.vmap(model)(x)
pred_y = jnp.argmax(pred_y, axis=1)
return jnp.mean(y == pred_y)
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
avg_loss = 0
avg_acc = 0
for x, y in testloader:
x = jnp.array(x.numpy())
y = jnp.array(y.numpy())
avg_loss += loss(model, x, y)
avg_acc += compute_accuracy(model, x, y)
return avg_loss / len(testloader), avg_acc / len(testloader)
def train(
model: CNN,
trainloader: torch.utils.data.DataLoader,
testloader: torch.utils.data.DataLoader,
optim: optax.GradientTransformation,
steps: int,
print_every: int,
) -> CNN:
@eqx.filter_jit
def make_step(
model: CNN,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, "batch"],
):
loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
updates, opt_state = optim.update(
grads, opt_state, eqx.filter(model, eqx.is_array)
)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss_value
def infinite_data(loader: torch.utils.data.DataLoader):
while True:
yield from loader # Yields from loader until exhausted
opt_state = optim.init(eqx.filter(model, eqx.is_array))
for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
x = jnp.array(x.numpy())
y = jnp.array(y.numpy())
model, opt_state, train_loss = make_step(model, opt_state, x, y)
if (step % print_every) == 0 or step == steps - 1:
avg_loss, avg_acc = evaluate(model, testloader)
jax.debug.print("==== step {} ====", step)
jax.debug.print("train loss = {}", train_loss)
jax.debug.print("test loss = {}", avg_loss)
jax.debug.print("text accuracy = {}", avg_acc)
return model
# ╔─────────────────────────────────────────────────────────────────────────────╗
# │ Mαiη scriρτ |
# ╚─────────────────────────────────────────────────────────────────────────────╝
jax.config.update("jax_platform_name", "gpu") # Sets preferred device
# Checking to make sure gpu is being used
from jax.extend import backend
print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")
# Hyperparameters
BATCH_SIZE = 1024
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678
key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)
# Data preprocessing
normalize_data = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,)),
]
)
train_dataset = torchvision.datasets.MNIST(
"MNIST",
train=True,
download=True,
transform=normalize_data,
)
test_dataset = torchvision.datasets.MNIST(
"MNIST",
train=False,
download=True,
transform=normalize_data,
)
trainloader = torch.utils.data.DataLoader(
train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
model = CNN(subkey)
optim = optax.adamw(LEARNING_RATE)
start = time.time()
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
cease = time.time()
print(f"Took {cease-start}s")