Main: improve typing

Akemi Izuko 2024-12-16 21:13:12 -07:00
parent 3b062bc31b
commit 59cf5aea85
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC


@@ -3,12 +3,13 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import torch
 import torchvision
 from functools import partial
 from jaxtyping import Array, Float, Int, PyTree
-from typing import Tuple
+from typing import Tuple, ForwardRef
 
 
 DATA_ROOT="data"
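
Note: the new numpy import supports the np.ndarray annotations added further down, and ForwardRef supports the self-referential hint on train. For readers unfamiliar with jaxtyping, the annotations used throughout this file encode dtype and shape together. A minimal illustration (the function is hypothetical, not part of this file):

    import jax.numpy as jnp
    from jaxtyping import Array, Float, Int

    # Hypothetical helper: a batch of single-channel 28x28 float images and
    # integer class labels in, a scalar (shape "") float out.
    def example(
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ) -> Float[Array, ""]:
        return jnp.zeros(())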
@@ -50,20 +51,11 @@ class CNN(eqx.Module):
         pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
         return -jnp.mean(pred_y)
 
-    @staticmethod
-    def loss(
-        model,
-        x: Float[Array, "batch 1 28 28"],
-        y: Int[Array, " batch"]
-    ) -> Float[Array, ""]:
-        pred_y = jax.vmap(model)(x)
-        return CNN.cross_entropy(y, pred_y)
-
     @staticmethod
     @partial(jax.jit, static_argnums=(1,))
     def loss2(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         x: Float[Array, "batch 1 28 28"],
         y: Int[Array, " batch"]
     ) -> Float[Array, ""]:
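
Removing loss leaves loss2, which works on the partitioned model, as the single loss entry point. It takes the model split into two pytrees, with statics (position 1) marked static so jax.jit treats the non-array leaves as compile-time constants. Roughly, the split-and-recombine pattern looks like this (a sketch; the constructor call is illustrative and loss2's body is outside this hunk):

    model = CNN(jax.random.PRNGKey(0))                    # illustrative constructor
    params, statics = eqx.partition(model, eqx.is_array)  # array leaves vs everything else
    # inside loss2, the two halves are recombined before the forward pass:
    model = eqx.combine(params, statics)
    pred_y = jax.vmap(model)(x)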
@@ -74,8 +66,8 @@ class CNN(eqx.Module):
     @staticmethod
     @partial(jax.jit, static_argnums=(1,))
     def get_accuracy(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         x: Float[Array, "batch 1 28 28"],
         y: Float[Array, "batch"],
     ) -> Float[Array, ""]:
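
get_accuracy's body falls outside this hunk; an accuracy helper with this signature typically recombines the model and compares argmax predictions against the labels. A sketch, not the actual body:

    model = eqx.combine(params, statics)
    pred_y = jax.vmap(model)(x)                       # (batch, n_classes) scores
    return jnp.mean(jnp.argmax(pred_y, axis=1) == y)  # fraction correct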
@@ -87,18 +79,14 @@ class CNN(eqx.Module):
     @staticmethod
     @partial(jax.jit, static_argnums=(1,3))
     def make_step(
-        params,
-        statics,
+        params: PyTree,
+        statics: PyTree,
         opt_state: PyTree,
         optim: optax.GradientTransformation,
         x: Float[Array, "batch 1 28 28"],
         y: Float[Array, "batch"],
     ):
-        loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
-        loss_with_grad = jax.value_and_grad(loss_with_grad)
-        loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
-        loss, grad = loss_with_grad(params, statics, x, y)
-
+        loss, grad = jax.value_and_grad(CNN.loss2)(params, statics, x, y)
         updates, opt_state = optim.update(grad, opt_state, params)
         params = optax.apply_updates(params, updates)
         return params, opt_state, loss
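
This is the substantive change: the old body built a throwaway lambda, wrapped it in jax.value_and_grad, and then re-wrapped it in a second jax.jit inside a function that is already jitted; the inner jit is redundant under the outer one. Applying jax.value_and_grad directly to loss2 computes the same loss and gradient, differentiating with respect to the first argument (params) by default, which is exactly what the update needs. A usage sketch for the full step (optimizer choice and hyperparameters are illustrative):

    optim = optax.adamw(3e-4)                             # illustrative optimizer
    params, statics = eqx.partition(model, eqx.is_array)
    opt_state = optim.init(params)

    for x, y in train_dl:
        x, y = x.numpy(), y.numpy()                       # torch tensors -> numpy for JAX
        params, opt_state, loss = CNN.make_step(params, statics, opt_state, optim, x, y)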
@@ -106,7 +94,7 @@ class CNN(eqx.Module):
 
     @staticmethod
     def train(
-        model,
+        model: ForwardRef('CNN'),
        train_dl: torch.utils.data.DataLoader,
        test_dl: torch.utils.data.DataLoader,
        optim: optax.GradientTransformation,
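
ForwardRef('CNN') is accepted at runtime, but typing.ForwardRef is documented as the internal wrapper that the typing machinery builds from a string annotation, and type checkers generally expect the string form. The conventional spelling of a self-referential hint is a plain string literal; a trimmed sketch of the signature, other parameters elided:

    class CNN(eqx.Module):
        @staticmethod
        def train(model: "CNN"):  # string forward reference, resolved lazily
            return model

Alternatively, from __future__ import annotations makes every annotation lazy and lets the class name be written bare.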
@@ -137,13 +125,13 @@
         @partial(jax.jit, static_argnums=(1))
-        def get_stats(params, statics, x, y):
+        def get_stats(params: PyTree, statics: PyTree, x: np.ndarray, y: np.ndarray):
             loss = CNN.loss2(params, statics, x, y)
             acc = CNN.get_accuracy(params, statics, x, y)
             return loss, acc
 
-        def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
+        def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
             avg_loss = 0
             avg_acc = 0
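
The hunk is cut off inside evaluate_model. A helper of this shape would typically average get_stats over the test loader; a speculative sketch of the continuation, not part of the diff:

    for x, y in test_dl:
        x, y = x.numpy(), y.numpy()
        loss, acc = get_stats(params, statics, x, y)
        avg_loss += loss.item()
        avg_acc += acc.item()
    return avg_loss / len(test_dl), avg_acc / len(test_dl)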