From 3b062bc31bdc0cdcfc7b52985b0f8d783f8ddc1c Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Mon, 16 Dec 2024 20:56:41 -0700
Subject: [PATCH] Main: increase jit use

---
 src/main.py | 64 ++++++++++++++++++++++++-----------------------
 1 file changed, 29 insertions(+), 35 deletions(-)

diff --git a/src/main.py b/src/main.py
index 159fb2a..2d3be5a 100644
--- a/src/main.py
+++ b/src/main.py
@@ -85,25 +85,23 @@ class CNN(eqx.Module):
         return jnp.mean(y == pred_y)

     @staticmethod
+    @partial(jax.jit, static_argnums=(1,3))
     def make_step(
-        model,
+        params,
+        statics,
         opt_state: PyTree,
         optim: optax.GradientTransformation,
         x: Float[Array, "batch 1 28 28"],
         y: Float[Array, "batch"],
     ):
-        params, statics = eqx.partition(model, eqx.is_array)
-
-        loss_with_grad = lambda p, s, x, y: model.loss(eqx.combine(p,s), x, y)
+        loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
         loss_with_grad = jax.value_and_grad(loss_with_grad)
         loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))

         loss, grad = loss_with_grad(params, statics, x, y)

-        updates, opt_state = optim.update(
-            grad, opt_state, eqx.filter(model, eqx.is_array)
-        )
-        model = eqx.apply_updates(model, updates)
-        return model, opt_state, loss
+        updates, opt_state = optim.update(grad, opt_state, params)
+        params = optax.apply_updates(params, updates)
+        return params, opt_state, loss

     @staticmethod
@@ -115,35 +113,46 @@ class CNN(eqx.Module):
         steps: int,
         log_every: int,
     ):
-        opt_state = optim.init(eqx.filter(model, eqx.is_array))
-
         def infinite_dl(dl):
             while True:
                 yield from dl

+        params, statics = eqx.partition(model, eqx.is_array)
+        opt_state = optim.init(params)
+
         for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
             x = x.numpy()
             y = y.numpy()
-            model, opt_state, loss = model.make_step(model, opt_state, optim, x, y)
+
+            params, opt_state, loss = CNN.make_step(
+                params, statics, opt_state, optim, x, y
+            )

             if (step % log_every == 0) or (step == steps - 1):
-                test_loss, test_acc = evaluate_model(model, train_dl)
-                print(f"{step=}, train_loss={loss.item()}, test_loss={test_loss.item()}")
-                print(f"{step=}, train_loss={loss.item()}")
+                test_loss, test_acc = evaluate_model(params, statics, train_dl)
+                print(f"{step=}, train_loss={loss.item()}, accuracy={test_acc.item()}")

+        model = eqx.combine(params, statics)
         return model


-def evaluate_model(model, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
+@partial(jax.jit, static_argnums=(1))
+def get_stats(params, statics, x, y):
+    loss = CNN.loss2(params, statics, x, y)
+    acc = CNN.get_accuracy(params, statics, x, y)
+    return loss, acc
+
+
+def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
     avg_loss = 0
     avg_acc = 0

     for x, y in test_dl:
         x = x.numpy()
         y = y.numpy()
-        params, statics = eqx.partition(model, eqx.is_array)
-        avg_loss += CNN.loss2(params, statics, x, y)
-        avg_acc += CNN.get_accuracy(params, statics, x, y)
+        loss, acc = get_stats(params, statics, x, y)
+        avg_loss += loss
+        avg_acc += acc

     avg_loss /= len(test_dl)
     avg_acc /= len(test_dl)
@@ -184,23 +193,8 @@ if __name__ == '__main__':

     key = jax.random.PRNGKey(SEED)
     key, key1 = jax.random.split(key, 2)
+
     model = CNN(key1)
     optim = optax.adamw(LEARNING_RATE)

-    dummy_x, dummy_y = next(iter(train_dl))
-    dummy_x = dummy_x.numpy() # 64x1x28x28
-    dummy_y = dummy_y.numpy() # 64
-
-    print(jax.vmap(model)(dummy_x).shape)
-    print(CNN.loss(model, dummy_x, dummy_y))
-
-    params, statics = eqx.partition(model, eqx.is_array)
-    loss2 = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
-
     model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
-
-    #print(evaluate_model(model, test_dl))
-    #loss_w_grad = jax.value_and_grad(loss2)
-    #l, g = loss_w_grad(params, statics, dummy_x, dummy_y)
-    #print(type(l))
-    #print(type(g))
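
Note (not part of the patch): the change above follows the usual JAX pattern of jitting the whole
training step and marking non-array arguments as static, here the static partition of the Equinox
module (argument 1) and the optax optimizer (argument 3). Below is a minimal, self-contained sketch
of that pattern using a toy dict-of-arrays model and a hypothetical step function instead of the
CNN from the patch; names like step, w, and b are illustrative only.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import optax

    # The optimizer is a pytree of functions, not arrays, so it is marked static.
    # In the patch, the static partition of the Equinox module plays the same role.
    @partial(jax.jit, static_argnums=(3,))
    def step(params, opt_state, batch, optim):
        x, y = batch

        def loss_fn(p):
            pred = x @ p["w"] + p["b"]           # toy linear model
            return jnp.mean((pred - y) ** 2)     # mean squared error

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optim.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
    optim = optax.adamw(1e-3)
    opt_state = optim.init(params)
    x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
    params, opt_state, loss = step(params, opt_state, (x, y), optim)

Threading params and opt_state through the step and returning them, rather than mutating a model in
place, is what lets the step compile once and be reused on every iteration; that is the same reason
make_step in the patch now takes and returns the array partition of the model.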