Main: increase jit use

This commit is contained in:
Akemi Izuko 2024-12-16 20:56:41 -07:00
parent a9700e22bb
commit 3b062bc31b
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -85,25 +85,23 @@ class CNN(eqx.Module):
return jnp.mean(y == pred_y)
@staticmethod
@partial(jax.jit, static_argnums=(1,3))
def make_step(
model,
params,
statics,
opt_state: PyTree,
optim: optax.GradientTransformation,
x: Float[Array, "batch 1 28 28"],
y: Float[Array, "batch"],
):
params, statics = eqx.partition(model, eqx.is_array)
loss_with_grad = lambda p, s, x, y: model.loss(eqx.combine(p,s), x, y)
loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
loss_with_grad = jax.value_and_grad(loss_with_grad)
loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
loss, grad = loss_with_grad(params, statics, x, y)
updates, opt_state = optim.update(
grad, opt_state, eqx.filter(model, eqx.is_array)
)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
updates, opt_state = optim.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
@staticmethod
@ -115,35 +113,46 @@ class CNN(eqx.Module):
steps: int,
log_every: int,
):
opt_state = optim.init(eqx.filter(model, eqx.is_array))
def infinite_dl(dl):
while True:
yield from dl
params, statics = eqx.partition(model, eqx.is_array)
opt_state = optim.init(params)
for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
x = x.numpy()
y = y.numpy()
model, opt_state, loss = model.make_step(model, opt_state, optim, x, y)
params, opt_state, loss = CNN.make_step(
params, statics, opt_state, optim, x, y
)
if (step % log_every == 0) or (step == steps - 1):
test_loss, test_acc = evaluate_model(model, train_dl)
print(f"{step=}, train_loss={loss.item()}, test_loss={test_loss.item()}")
print(f"{step=}, train_loss={loss.item()}")
test_loss, test_acc = evaluate_model(params, statics, train_dl)
print(f"{step=}, train_loss={loss.item()}, accuracy={test_acc.item()}")
model = eqx.combine(params, statics)
return model
def evaluate_model(model, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
@partial(jax.jit, static_argnums=(1))
def get_stats(params, statics, x, y):
loss = CNN.loss2(params, statics, x, y)
acc = CNN.get_accuracy(params, statics, x, y)
return loss, acc
def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
avg_loss = 0
avg_acc = 0
for x, y in test_dl:
x = x.numpy()
y = y.numpy()
params, statics = eqx.partition(model, eqx.is_array)
avg_loss += CNN.loss2(params, statics, x, y)
avg_acc += CNN.get_accuracy(params, statics, x, y)
loss, acc = get_stats(params, statics, x, y)
avg_loss += loss
avg_acc += acc
avg_loss /= len(test_dl)
avg_acc /= len(test_dl)
@ -184,23 +193,8 @@ if __name__ == '__main__':
key = jax.random.PRNGKey(SEED)
key, key1 = jax.random.split(key, 2)
model = CNN(key1)
optim = optax.adamw(LEARNING_RATE)
dummy_x, dummy_y = next(iter(train_dl))
dummy_x = dummy_x.numpy() # 64x1x28x28
dummy_y = dummy_y.numpy() # 64
print(jax.vmap(model)(dummy_x).shape)
print(CNN.loss(model, dummy_x, dummy_y))
params, statics = eqx.partition(model, eqx.is_array)
loss2 = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
#print(evaluate_model(model, test_dl))
#loss_w_grad = jax.value_and_grad(loss2)
#l, g = loss_w_grad(params, statics, dummy_x, dummy_y)
#print(type(l))
#print(type(g))