Main: improve typing
This commit is contained in:
parent
3b062bc31b
commit
59cf5aea85
1 changed files with 12 additions and 24 deletions
36
src/main.py
36
src/main.py
|
@ -3,12 +3,13 @@
|
|||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import torch
|
||||
import torchvision
|
||||
from functools import partial
|
||||
from jaxtyping import Array, Float, Int, PyTree
|
||||
from typing import Tuple
|
||||
from typing import Tuple, ForwardRef
|
||||
|
||||
|
||||
DATA_ROOT="data"
|
||||
|
@ -50,20 +51,11 @@ class CNN(eqx.Module):
|
|||
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
|
||||
return -jnp.mean(pred_y)
|
||||
|
||||
@staticmethod
|
||||
def loss(
|
||||
model,
|
||||
x: Float[Array, "batch 1 28 28"],
|
||||
y: Int[Array, " batch"]
|
||||
) -> Float[Array, ""]:
|
||||
pred_y = jax.vmap(model)(x)
|
||||
return CNN.cross_entropy(y, pred_y)
|
||||
|
||||
@staticmethod
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def loss2(
|
||||
params,
|
||||
statics,
|
||||
params: PyTree,
|
||||
statics: PyTree,
|
||||
x: Float[Array, "batch 1 28 28"],
|
||||
y: Int[Array, " batch"]
|
||||
) -> Float[Array, ""]:
|
||||
|
@ -74,8 +66,8 @@ class CNN(eqx.Module):
|
|||
@staticmethod
|
||||
@partial(jax.jit, static_argnums=(1,))
|
||||
def get_accuracy(
|
||||
params,
|
||||
statics,
|
||||
params: PyTree,
|
||||
statics: PyTree,
|
||||
x: Float[Array, "batch 1 28 28"],
|
||||
y: Float[Array, "batch"],
|
||||
) -> Float[Array, ""]:
|
||||
|
@ -87,18 +79,14 @@ class CNN(eqx.Module):
|
|||
@staticmethod
|
||||
@partial(jax.jit, static_argnums=(1,3))
|
||||
def make_step(
|
||||
params,
|
||||
statics,
|
||||
params: PyTree,
|
||||
statics: PyTree,
|
||||
opt_state: PyTree,
|
||||
optim: optax.GradientTransformation,
|
||||
x: Float[Array, "batch 1 28 28"],
|
||||
y: Float[Array, "batch"],
|
||||
):
|
||||
loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
|
||||
loss_with_grad = jax.value_and_grad(loss_with_grad)
|
||||
loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
|
||||
|
||||
loss, grad = loss_with_grad(params, statics, x, y)
|
||||
loss, grad = jax.value_and_grad(CNN.loss2)(params, statics, x, y)
|
||||
updates, opt_state = optim.update(grad, opt_state, params)
|
||||
params = optax.apply_updates(params, updates)
|
||||
return params, opt_state, loss
|
||||
|
@ -106,7 +94,7 @@ class CNN(eqx.Module):
|
|||
|
||||
@staticmethod
|
||||
def train(
|
||||
model,
|
||||
model: ForwardRef('CNN'),
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
test_dl: torch.utils.data.DataLoader,
|
||||
optim: optax.GradientTransformation,
|
||||
|
@ -137,13 +125,13 @@ class CNN(eqx.Module):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1))
|
||||
def get_stats(params, statics, x, y):
|
||||
def get_stats(params: PyTree, statics: PyTree, x: np.ndarray, y: np.ndarray):
|
||||
loss = CNN.loss2(params, statics, x, y)
|
||||
acc = CNN.get_accuracy(params, statics, x, y)
|
||||
return loss, acc
|
||||
|
||||
|
||||
def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
||||
def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
||||
avg_loss = 0
|
||||
avg_acc = 0
|
||||
|
||||
|
|
Loading…
Reference in a new issue