182 lines
5.2 KiB
Python
Executable file
182 lines
5.2 KiB
Python
Executable file
#!/usr/bin/env python3
|
||
"""
|
||
This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/,
|
||
just using it to learn equinox
|
||
"""
|
||
import equinox as eqx
|
||
import jax.numpy as jnp
|
||
import jax
|
||
import optax
|
||
import time
|
||
import torch
|
||
import torchvision
|
||
from jaxtyping import Array, Float, Int, PyTree
|
||
|
||
|
||
class CNN(eqx.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    The network is an ordered list of callables applied one after another:
    a 4x4 convolution (1 -> 3 channels), a 2x2 max-pool, ReLU, flattening,
    then three linear layers (1728 -> 512 -> 64 -> 10) separated by sigmoid
    and ReLU activations, and a final log-softmax over the 10 outputs.
    """

    layers: list

    def __init__(self, key):
        """Build the layer stack.

        Args:
            key: JAX PRNG key. It is split into four subkeys, one per layer
                that owns learnable parameters.
        """
        conv_key, hidden_key, mid_key, out_key = jax.random.split(key, 4)

        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=conv_key),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            # 1728 = 3 * 24 * 24. This assumes the conv maps 28x28 -> 25x25
            # and the pool (stride 1, Equinox's default) maps 25x25 -> 24x24.
            # Re-check if the layer sizes above change.
            eqx.nn.Linear(1728, 512, key=hidden_key),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=mid_key),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=out_key),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Feed a single (unbatched) image through every layer in order.

        Returns the 10 log-probabilities produced by the final log-softmax.
        """
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
|
||
|
||
|
||
def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Mean negative log-likelihood of the true labels.

    Args:
        y: Integer class label for each example in the batch.
        pred_y: Per-class log-probabilities for each example, e.g. the output
            of ``CNN`` (which ends in ``jax.nn.log_softmax``). These values are
            already logarithms, so no further ``log`` is applied here.

    Returns:
        Scalar equal to minus the mean of ``pred_y[i, y[i]]`` over the batch.
    """
    # For every row, pick out the log-probability assigned to its true label.
    true_log_probs = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(true_log_probs)
|
||
|
||
|
||
@eqx.filter_jit
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Average cross-entropy of ``model`` on one batch.

    ``model`` is called on one image at a time, so it is vectorised over the
    leading batch axis with ``jax.vmap``. The resulting per-example
    log-probabilities are then scored by ``cross_entropy``.
    """
    batched_model = jax.vmap(model)
    log_probs = batched_model(x)
    return cross_entropy(y, log_probs)
|
||
|
||
|
||
@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Fraction of the batch whose highest-scoring class equals its label.

    ``model`` is vectorised over the batch axis with ``jax.vmap``. The
    predicted class is the argmax over axis 1 of its output.
    """
    scores = jax.vmap(model)(x)
    predicted_labels = scores.argmax(axis=1)
    hits = predicted_labels == y
    return hits.mean()
|
||
|
||
|
||
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """Compute the average loss and accuracy of ``model`` over a whole loader.

    ``loss`` and ``compute_accuracy`` return per-batch means. Each mean is
    weighted by its batch size before averaging. Without this, a smaller final
    batch (dataset size not a multiple of the batch size) would count as much
    as a full batch and skew the result.

    Args:
        model: The network to evaluate.
        testloader: Yields ``(images, labels)`` pairs of torch tensors.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over every example the loader yields.

    Raises:
        ValueError: If the loader yields no examples.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_examples = 0

    for x, y in testloader:
        # torch tensor -> NumPy -> JAX array.
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        batch_size = y.shape[0]

        # Turn the batch means back into sums so each example counts equally.
        total_loss += loss(model, x, y) * batch_size
        total_acc += compute_accuracy(model, x, y) * batch_size
        num_examples += batch_size

    if num_examples == 0:
        raise ValueError("testloader yielded no examples")

    return total_loss / num_examples, total_acc / num_examples
|
||
|
||
|
||
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Train ``model`` for a fixed number of optimiser steps.

    Training batches are drawn from ``trainloader``, cycling through it again
    whenever it is exhausted. Every ``print_every`` steps, and on the final
    step, the model is evaluated on ``testloader``, and the train loss, test
    loss and test accuracy are printed.

    Args:
        model: Initial network.
        trainloader: Yields ``(images, labels)`` torch tensor pairs for training.
        testloader: Yields ``(images, labels)`` torch tensor pairs for evaluation.
        optim: Optax optimiser used to update the model's array leaves.
        steps: Number of optimisation steps to run.
        print_every: Evaluation/logging interval, in steps.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``print_every`` is not positive (it is used as a modulus).
    """
    if print_every <= 0:
        raise ValueError("print_every must be a positive integer")

    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        """Run one optimisation step on a single batch.

        Takes the gradient of ``loss`` with respect to the model's array
        leaves and applies the optimiser's update to the model.
        """
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # The current parameters (array leaves only) are passed as the third
        # argument. Optimisers with weight decay, such as adamw, need them.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def infinite_data(loader: torch.utils.data.DataLoader):
        """Cycle through ``loader`` forever."""
        while True:
            yield from loader  # Yields from loader until exhausted

    # Optimiser state covers only the array leaves; activations and other
    # non-array fields of the model are filtered out.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())

        model, opt_state, train_loss = make_step(model, opt_state, x, y)

        if (step % print_every) == 0 or step == steps - 1:
            avg_loss, avg_acc = evaluate(model, testloader)

            jax.debug.print("==== step {} ====", step)
            jax.debug.print("train loss = {}", train_loss)
            jax.debug.print("test loss = {}", avg_loss)
            jax.debug.print("test accuracy = {}", avg_acc)

    return model
|
||
|
||
|
||
# ╔─────────────────────────────────────────────────────────────────────────────╗
|
||
# │ Main script                                                                 │
|
||
# ╚─────────────────────────────────────────────────────────────────────────────╝
|
||
jax.config.update("jax_platform_name", "gpu")  # Make GPU the default platform
# NOTE(review): `jax_platform_name` is documented as deprecated in favour of
# `jax_platforms`. It presumably errors rather than falling back to CPU when no
# GPU is available; confirm before running on CPU-only machines.

# Checking to make sure gpu is being used
from jax.extend import backend

print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")

# Hyperparameters
BATCH_SIZE = 1024
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678

key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)

# Data preprocessing: convert images to tensors, then normalise each pixel
# as (x - 0.5) / 0.5.
normalize_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalize_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalize_data,
)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# Evaluation gains nothing from shuffling. A fixed order also keeps the
# composition of the (possibly smaller) last batch stable, so test metrics
# are reproducible between runs.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

model = CNN(subkey)
optim = optax.adamw(LEARNING_RATE)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time. time.time() can jump if the wall clock is adjusted.
start = time.perf_counter()
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
cease = time.perf_counter()

print(f"Took {cease-start}s")
|