Main: improve typing
This commit is contained in:
parent
3b062bc31b
commit
59cf5aea85
1 changed files with 12 additions and 24 deletions
36
src/main.py
36
src/main.py
|
@ -3,12 +3,13 @@
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
import optax
|
import optax
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from jaxtyping import Array, Float, Int, PyTree
|
from jaxtyping import Array, Float, Int, PyTree
|
||||||
from typing import Tuple
|
from typing import Tuple, ForwardRef
|
||||||
|
|
||||||
|
|
||||||
DATA_ROOT="data"
|
DATA_ROOT="data"
|
||||||
|
@ -50,20 +51,11 @@ class CNN(eqx.Module):
|
||||||
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
|
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
|
||||||
return -jnp.mean(pred_y)
|
return -jnp.mean(pred_y)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def loss(
|
|
||||||
model,
|
|
||||||
x: Float[Array, "batch 1 28 28"],
|
|
||||||
y: Int[Array, " batch"]
|
|
||||||
) -> Float[Array, ""]:
|
|
||||||
pred_y = jax.vmap(model)(x)
|
|
||||||
return CNN.cross_entropy(y, pred_y)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@partial(jax.jit, static_argnums=(1,))
|
@partial(jax.jit, static_argnums=(1,))
|
||||||
def loss2(
|
def loss2(
|
||||||
params,
|
params: PyTree,
|
||||||
statics,
|
statics: PyTree,
|
||||||
x: Float[Array, "batch 1 28 28"],
|
x: Float[Array, "batch 1 28 28"],
|
||||||
y: Int[Array, " batch"]
|
y: Int[Array, " batch"]
|
||||||
) -> Float[Array, ""]:
|
) -> Float[Array, ""]:
|
||||||
|
@ -74,8 +66,8 @@ class CNN(eqx.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@partial(jax.jit, static_argnums=(1,))
|
@partial(jax.jit, static_argnums=(1,))
|
||||||
def get_accuracy(
|
def get_accuracy(
|
||||||
params,
|
params: PyTree,
|
||||||
statics,
|
statics: PyTree,
|
||||||
x: Float[Array, "batch 1 28 28"],
|
x: Float[Array, "batch 1 28 28"],
|
||||||
y: Float[Array, "batch"],
|
y: Float[Array, "batch"],
|
||||||
) -> Float[Array, ""]:
|
) -> Float[Array, ""]:
|
||||||
|
@ -87,18 +79,14 @@ class CNN(eqx.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@partial(jax.jit, static_argnums=(1,3))
|
@partial(jax.jit, static_argnums=(1,3))
|
||||||
def make_step(
|
def make_step(
|
||||||
params,
|
params: PyTree,
|
||||||
statics,
|
statics: PyTree,
|
||||||
opt_state: PyTree,
|
opt_state: PyTree,
|
||||||
optim: optax.GradientTransformation,
|
optim: optax.GradientTransformation,
|
||||||
x: Float[Array, "batch 1 28 28"],
|
x: Float[Array, "batch 1 28 28"],
|
||||||
y: Float[Array, "batch"],
|
y: Float[Array, "batch"],
|
||||||
):
|
):
|
||||||
loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
|
loss, grad = jax.value_and_grad(CNN.loss2)(params, statics, x, y)
|
||||||
loss_with_grad = jax.value_and_grad(loss_with_grad)
|
|
||||||
loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
|
|
||||||
|
|
||||||
loss, grad = loss_with_grad(params, statics, x, y)
|
|
||||||
updates, opt_state = optim.update(grad, opt_state, params)
|
updates, opt_state = optim.update(grad, opt_state, params)
|
||||||
params = optax.apply_updates(params, updates)
|
params = optax.apply_updates(params, updates)
|
||||||
return params, opt_state, loss
|
return params, opt_state, loss
|
||||||
|
@ -106,7 +94,7 @@ class CNN(eqx.Module):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def train(
|
def train(
|
||||||
model,
|
model: ForwardRef('CNN'),
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
test_dl: torch.utils.data.DataLoader,
|
test_dl: torch.utils.data.DataLoader,
|
||||||
optim: optax.GradientTransformation,
|
optim: optax.GradientTransformation,
|
||||||
|
@ -137,13 +125,13 @@ class CNN(eqx.Module):
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(1))
|
@partial(jax.jit, static_argnums=(1))
|
||||||
def get_stats(params, statics, x, y):
|
def get_stats(params: PyTree, statics: PyTree, x: np.ndarray, y: np.ndarray):
|
||||||
loss = CNN.loss2(params, statics, x, y)
|
loss = CNN.loss2(params, statics, x, y)
|
||||||
acc = CNN.get_accuracy(params, statics, x, y)
|
acc = CNN.get_accuracy(params, statics, x, y)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
||||||
avg_loss = 0
|
avg_loss = 0
|
||||||
avg_acc = 0
|
avg_acc = 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue