Main: improve typing

Akemi Izuko 2024-12-16 21:13:12 -07:00
parent 3b062bc31b
commit 59cf5aea85
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC


@@ -3,12 +3,13 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import torch
 import torchvision
 from functools import partial
 from jaxtyping import Array, Float, Int, PyTree
-from typing import Tuple
+from typing import Tuple, ForwardRef
 
 
 DATA_ROOT="data"
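The signatures below lean on the jaxtyping imports added here. As a minimal illustration (not from this repository): the annotation pairs a dtype with Array and a shape string, so Float[Array, "batch 1 28 28"] reads as a batch of 1x28x28 floats and Float[Array, ""] as a scalar. Without a runtime checker such as beartype these shapes act as documentation only.

from jaxtyping import Array, Float, Int

def mean_over_batch(x: Float[Array, "batch 1 28 28"]) -> Float[Array, ""]:
    return x.mean()   # reduces the whole batch to a 0-d (scalar) array

def first_label(y: Int[Array, " batch"]) -> Int[Array, ""]:
    return y[0]       # the leading space in " batch" has no semantic meaning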
@@ -50,20 +51,11 @@ class CNN(eqx.Module):
         pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
         return -jnp.mean(pred_y)
 
-    @staticmethod
-    def loss(
-        model,
-        x: Float[Array, "batch 1 28 28"],
-        y: Int[Array, " batch"]
-    ) -> Float[Array, ""]:
-        pred_y = jax.vmap(model)(x)
-        return CNN.cross_entropy(y, pred_y)
-
     @staticmethod
     @partial(jax.jit, static_argnums=(1,))
     def loss2(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         x: Float[Array, "batch 1 28 28"],
         y: Int[Array, " batch"]
     ) -> Float[Array, ""]:
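The body of loss2 is outside this hunk; the sketch below shows the params/statics pattern it presumably relies on (the old make_step further down already calls eqx.combine(p, s)). The function and variable names here are illustrative, not the repository's.

import equinox as eqx
import jax

def loss_sketch(params, statics, x, y):
    model = eqx.combine(params, statics)   # rebuild the CNN from trainable and static halves
    pred_y = jax.vmap(model)(x)            # per-example forward pass over the batch
    return CNN.cross_entropy(y, pred_y)    # reuse the class's cross-entropy

# Typical split at the call site, keeping only array leaves as trainable:
# params, statics = eqx.partition(model, eqx.is_array)

Splitting the module this way lets statics be passed through static_argnums to jax.jit, since only the params half contains traced arrays.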
@@ -74,8 +66,8 @@ class CNN(eqx.Module):
     @staticmethod
     @partial(jax.jit, static_argnums=(1,))
     def get_accuracy(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         x: Float[Array, "batch 1 28 28"],
         y: Float[Array, "batch"],
     ) -> Float[Array, ""]:
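The accuracy body is also outside the hunk; a plausible implementation with this signature is an argmax over the class axis compared against the labels. This is a hypothetical sketch, not the file's code.

import equinox as eqx
import jax
import jax.numpy as jnp

def accuracy_sketch(params, statics, x, y):
    model = eqx.combine(params, statics)
    pred_y = jax.vmap(model)(x)                        # (batch, num_classes) logits
    return jnp.mean(jnp.argmax(pred_y, axis=1) == y)   # fraction of correct predictions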
@@ -87,18 +79,14 @@ class CNN(eqx.Module):
     @staticmethod
     @partial(jax.jit, static_argnums=(1,3))
     def make_step(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         opt_state: PyTree,
         optim: optax.GradientTransformation,
         x: Float[Array, "batch 1 28 28"],
         y: Float[Array, "batch"],
     ):
-        loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
-        loss_with_grad = jax.value_and_grad(loss_with_grad)
-        loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
-
-        loss, grad = loss_with_grad(params, statics, x, y)
+        loss, grad = jax.value_and_grad(CNN.loss2)(params, statics, x, y)
         updates, opt_state = optim.update(grad, opt_state, params)
         params = optax.apply_updates(params, updates)
         return params, opt_state, loss
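A typical call site for make_step, assuming a model and train_dl already exist and that the loader yields numpy-compatible batches; the optimiser and learning rate below are hypothetical choices, not taken from the repository.

import equinox as eqx
import optax

optim = optax.adamw(1e-3)                             # hypothetical optimiser and learning rate
params, statics = eqx.partition(model, eqx.is_array)  # split once, outside the jitted step
opt_state = optim.init(params)

for x, y in train_dl:
    params, opt_state, loss = CNN.make_step(params, statics, opt_state, optim, x, y)

model = eqx.combine(params, statics)                  # reassemble the trained module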
@@ -106,7 +94,7 @@ class CNN(eqx.Module):
 
     @staticmethod
     def train(
-        model,
+        model: ForwardRef('CNN'),
        train_dl: torch.utils.data.DataLoader,
        test_dl: torch.utils.data.DataLoader,
        optim: optax.GradientTransformation,
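The self-reference to the class being defined can equivalently be written as the string literal "CNN", which type checkers treat as an implicit forward reference (PEP 484); an explicit ForwardRef object is normally constructed by the typing machinery itself. Only the parameters visible in this hunk are repeated in the sketch.

import optax
import torch

def train_sketch(
    model: "CNN",                              # string literal acts as a forward reference
    train_dl: torch.utils.data.DataLoader,
    test_dl: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
):
    ...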
@@ -137,13 +125,13 @@ class CNN(eqx.Module):
         @partial(jax.jit, static_argnums=(1))
-        def get_stats(params, statics, x, y):
+        def get_stats(params: PyTree, statics: PyTree, x: np.ndarray, y: np.ndarray):
             loss = CNN.loss2(params, statics, x, y)
             acc = CNN.get_accuracy(params, statics, x, y)
             return loss, acc
 
-        def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
+        def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
             avg_loss = 0
             avg_acc = 0
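The rest of evaluate_model falls outside the hunk; one plausible completion averages the jitted per-batch stats over the test loader. The .numpy() conversion assumes the DataLoader yields torch tensors, and .item() pulls the scalar results off-device; both are assumptions, not the file's code.

def evaluate_model_sketch(params, statics, test_dl):
    avg_loss, avg_acc = 0.0, 0.0
    for x, y in test_dl:
        loss, acc = get_stats(params, statics, x.numpy(), y.numpy())
        avg_loss += loss.item()
        avg_acc += acc.item()
    return avg_loss / len(test_dl), avg_acc / len(test_dl)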