Akemi Izuko 2024-12-16 20:19:37 -07:00
commit a9700e22bb
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

src/main.py (normal file, 206 lines added)

@@ -0,0 +1,206 @@
# Code from:
# https://docs.kidger.site/equinox/examples/mnist/
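#
# Trains a small CNN on MNIST with Equinox (model), Optax (optimiser), and
# torchvision (data download and batching only); batches are converted from
# torch tensors to numpy arrays before being handed to JAX.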
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from functools import partial
from jaxtyping import Array, Float, Int, PyTree
from typing import Tuple

DATA_ROOT = "data"
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678
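
# Shape walk-through for one 1x28x28 input (assuming Equinox's default
# MaxPool2d stride of 1): Conv2d(1->3, k=4) gives 3x25x25, MaxPool2d(k=2)
# gives 3x24x24, ravel gives 1728 features, then 1728 -> 512 -> 64 -> 10
# with a final log_softmax, so __call__ returns per-class log-probabilities.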
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def cross_entropy(
        y: Int[Array, " batch"],
        pred_y: Float[Array, "batch 10"]
    ) -> Float[Array, ""]:
        # pred_y holds per-class log-probabilities; select the log-probability
        # of the true class and average the negative log-likelihood.
        pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
        return -jnp.mean(pred_y)

    @staticmethod
    def loss(
        model,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"]
    ) -> Float[Array, ""]:
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)

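    # The layers list mixes arrays with plain Python callables, so the model
    # cannot be passed straight through jax.jit as a traced argument.  loss2
    # therefore splits it with eqx.partition into array leaves ("params") and
    # everything else ("statics"), marks the statics argument as static, and
    # rebuilds the model inside the jitted function with eqx.combine.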
    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def loss2(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"]
    ) -> Float[Array, ""]:
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        return CNN.cross_entropy(y, pred_y)

    @staticmethod
    @partial(jax.jit, static_argnums=(1,))
    def get_accuracy(
        params,
        statics,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        model = eqx.combine(params, statics)
        pred_y = jax.vmap(model)(x)
        pred_y = jnp.argmax(pred_y, axis=1)
        return jnp.mean(y == pred_y)

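    # One optimisation step: take value-and-grad of the loss with respect to
    # the array leaves only, ask optax for the parameter updates, and apply
    # them to the model with eqx.apply_updates.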
    @staticmethod
    def make_step(
        model,
        opt_state: PyTree,
        optim: optax.GradientTransformation,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        params, statics = eqx.partition(model, eqx.is_array)
        loss_with_grad = lambda p, s, x, y: model.loss(eqx.combine(p, s), x, y)
        loss_with_grad = jax.value_and_grad(loss_with_grad)
        loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
        loss, grad = loss_with_grad(params, statics, x, y)
        updates, opt_state = optim.update(
            grad, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

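    # Training loop: the PyTorch DataLoader is wrapped in an endless generator
    # so the loop runs for a fixed number of optimiser steps rather than whole
    # epochs; test loss/accuracy are logged every log_every steps and on the
    # final step.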
    @staticmethod
    def train(
        model,
        train_dl: torch.utils.data.DataLoader,
        test_dl: torch.utils.data.DataLoader,
        optim: optax.GradientTransformation,
        steps: int,
        log_every: int,
    ):
        opt_state = optim.init(eqx.filter(model, eqx.is_array))

        def infinite_dl(dl):
            while True:
                yield from dl

        for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
            x = x.numpy()
            y = y.numpy()
            model, opt_state, loss = model.make_step(model, opt_state, optim, x, y)
            if (step % log_every == 0) or (step == steps - 1):
                test_loss, test_acc = evaluate_model(model, test_dl)
                print(f"{step=}, train_loss={loss.item()}, "
                      f"test_loss={test_loss.item()}, test_acc={test_acc.item()}")
        return model

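# Average the jitted loss and accuracy over every batch in the dataloader.
# Note this is a mean of per-batch means, so a smaller final batch is weighted
# slightly more than its share of examples.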
def evaluate_model(model, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
    avg_loss = 0
    avg_acc = 0
    for x, y in test_dl:
        x = x.numpy()
        y = y.numpy()
        params, statics = eqx.partition(model, eqx.is_array)
        avg_loss += CNN.loss2(params, statics, x, y)
        avg_acc += CNN.get_accuracy(params, statics, x, y)
    avg_loss /= len(test_dl)
    avg_acc /= len(test_dl)
    return avg_loss, avg_acc

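# Download MNIST via torchvision and build shuffled DataLoaders.  ToTensor
# scales pixels to [0, 1] and Normalize((0.5,), (0.5,)) recentres them to
# [-1, 1]; batches come back as torch tensors and are converted to numpy
# at the point of use.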
def load_data():
    normalise_data = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=True,
        download=True,
        transform=normalise_data,
    )
    test_ds = torchvision.datasets.MNIST(
        root=DATA_ROOT,
        train=False,
        download=True,
        transform=normalise_data,
    )
    train_dl = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True
    )
    test_dl = torch.utils.data.DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=True
    )
    return train_dl, test_dl

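# Quick sanity checks (output shape and initial loss on one batch) before
# running the training loop.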
if __name__ == '__main__':
    train_dl, test_dl = load_data()
    key = jax.random.PRNGKey(SEED)
    key, key1 = jax.random.split(key, 2)
    model = CNN(key1)
    optim = optax.adamw(LEARNING_RATE)

    dummy_x, dummy_y = next(iter(train_dl))
    dummy_x = dummy_x.numpy()  # 64x1x28x28
    dummy_y = dummy_y.numpy()  # 64
    print(jax.vmap(model)(dummy_x).shape)
    print(CNN.loss(model, dummy_x, dummy_y))

    params, statics = eqx.partition(model, eqx.is_array)
    loss2 = lambda p, s, x, y: CNN.loss(eqx.combine(p, s), x, y)

    model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
    #print(evaluate_model(model, test_dl))
    #loss_w_grad = jax.value_and_grad(loss2)
    #l, g = loss_w_grad(params, statics, dummy_x, dummy_y)
    #print(type(l))
    #print(type(g))