diff --git a/src/main.py b/src/main.py index 2d3be5a..7526c74 100644 --- a/src/main.py +++ b/src/main.py @@ -3,12 +3,13 @@ import equinox as eqx import jax import jax.numpy as jnp +import numpy as np import optax import torch import torchvision from functools import partial from jaxtyping import Array, Float, Int, PyTree -from typing import Tuple +from typing import Tuple, ForwardRef DATA_ROOT="data" @@ -50,20 +51,11 @@ class CNN(eqx.Module): pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) return -jnp.mean(pred_y) - @staticmethod - def loss( - model, - x: Float[Array, "batch 1 28 28"], - y: Int[Array, " batch"] - ) -> Float[Array, ""]: - pred_y = jax.vmap(model)(x) - return CNN.cross_entropy(y, pred_y) - @staticmethod @partial(jax.jit, static_argnums=(1,)) def loss2( - params, - statics, + params: PyTree, + statics: PyTree, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] ) -> Float[Array, ""]: @@ -74,8 +66,8 @@ class CNN(eqx.Module): @staticmethod @partial(jax.jit, static_argnums=(1,)) def get_accuracy( - params, - statics, + params: PyTree, + statics: PyTree, x: Float[Array, "batch 1 28 28"], y: Float[Array, "batch"], ) -> Float[Array, ""]: @@ -87,18 +79,14 @@ class CNN(eqx.Module): @staticmethod @partial(jax.jit, static_argnums=(1,3)) def make_step( - params, - statics, + params: PyTree, + statics: PyTree, opt_state: PyTree, optim: optax.GradientTransformation, x: Float[Array, "batch 1 28 28"], y: Float[Array, "batch"], ): - loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y) - loss_with_grad = jax.value_and_grad(loss_with_grad) - loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,)) - - loss, grad = loss_with_grad(params, statics, x, y) + loss, grad = jax.value_and_grad(CNN.loss2)(params, statics, x, y) updates, opt_state = optim.update(grad, opt_state, params) params = optax.apply_updates(params, updates) return params, opt_state, loss @@ -106,7 +94,7 @@ class CNN(eqx.Module): @staticmethod def train( - model, + model: ForwardRef('CNN'), train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, optim: optax.GradientTransformation, @@ -137,13 +125,13 @@ class CNN(eqx.Module): @partial(jax.jit, static_argnums=(1)) -def get_stats(params, statics, x, y): +def get_stats(params: PyTree, statics: PyTree, x: np.ndarray, y: np.ndarray): loss = CNN.loss2(params, statics, x, y) acc = CNN.get_accuracy(params, statics, x, y) return loss, acc -def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]: +def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]: avg_loss = 0 avg_acc = 0