Main: increase jit use
This commit is contained in:
parent
a9700e22bb
commit
3b062bc31b
1 changed files with 29 additions and 35 deletions
64
src/main.py
64
src/main.py
|
@ -85,25 +85,23 @@ class CNN(eqx.Module):
|
||||||
return jnp.mean(y == pred_y)
|
return jnp.mean(y == pred_y)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@partial(jax.jit, static_argnums=(1,3))
|
||||||
def make_step(
|
def make_step(
|
||||||
model,
|
params,
|
||||||
|
statics,
|
||||||
opt_state: PyTree,
|
opt_state: PyTree,
|
||||||
optim: optax.GradientTransformation,
|
optim: optax.GradientTransformation,
|
||||||
x: Float[Array, "batch 1 28 28"],
|
x: Float[Array, "batch 1 28 28"],
|
||||||
y: Float[Array, "batch"],
|
y: Float[Array, "batch"],
|
||||||
):
|
):
|
||||||
params, statics = eqx.partition(model, eqx.is_array)
|
loss_with_grad = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
|
||||||
|
|
||||||
loss_with_grad = lambda p, s, x, y: model.loss(eqx.combine(p,s), x, y)
|
|
||||||
loss_with_grad = jax.value_and_grad(loss_with_grad)
|
loss_with_grad = jax.value_and_grad(loss_with_grad)
|
||||||
loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
|
loss_with_grad = jax.jit(loss_with_grad, static_argnums=(1,))
|
||||||
|
|
||||||
loss, grad = loss_with_grad(params, statics, x, y)
|
loss, grad = loss_with_grad(params, statics, x, y)
|
||||||
updates, opt_state = optim.update(
|
updates, opt_state = optim.update(grad, opt_state, params)
|
||||||
grad, opt_state, eqx.filter(model, eqx.is_array)
|
params = optax.apply_updates(params, updates)
|
||||||
)
|
return params, opt_state, loss
|
||||||
model = eqx.apply_updates(model, updates)
|
|
||||||
return model, opt_state, loss
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -115,35 +113,46 @@ class CNN(eqx.Module):
|
||||||
steps: int,
|
steps: int,
|
||||||
log_every: int,
|
log_every: int,
|
||||||
):
|
):
|
||||||
opt_state = optim.init(eqx.filter(model, eqx.is_array))
|
|
||||||
|
|
||||||
def infinite_dl(dl):
|
def infinite_dl(dl):
|
||||||
while True:
|
while True:
|
||||||
yield from dl
|
yield from dl
|
||||||
|
|
||||||
|
params, statics = eqx.partition(model, eqx.is_array)
|
||||||
|
opt_state = optim.init(params)
|
||||||
|
|
||||||
for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
|
for step, (x, y) in zip(range(steps), infinite_dl(train_dl)):
|
||||||
x = x.numpy()
|
x = x.numpy()
|
||||||
y = y.numpy()
|
y = y.numpy()
|
||||||
model, opt_state, loss = model.make_step(model, opt_state, optim, x, y)
|
|
||||||
|
params, opt_state, loss = CNN.make_step(
|
||||||
|
params, statics, opt_state, optim, x, y
|
||||||
|
)
|
||||||
|
|
||||||
if (step % log_every == 0) or (step == steps - 1):
|
if (step % log_every == 0) or (step == steps - 1):
|
||||||
test_loss, test_acc = evaluate_model(model, train_dl)
|
test_loss, test_acc = evaluate_model(params, statics, train_dl)
|
||||||
print(f"{step=}, train_loss={loss.item()}, test_loss={test_loss.item()}")
|
print(f"{step=}, train_loss={loss.item()}, accuracy={test_acc.item()}")
|
||||||
print(f"{step=}, train_loss={loss.item()}")
|
|
||||||
|
|
||||||
|
model = eqx.combine(params, statics)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(model, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
@partial(jax.jit, static_argnums=(1))
|
||||||
|
def get_stats(params, statics, x, y):
|
||||||
|
loss = CNN.loss2(params, statics, x, y)
|
||||||
|
acc = CNN.get_accuracy(params, statics, x, y)
|
||||||
|
return loss, acc
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(params, statics, test_dl: torch.utils.data.DataLoader) -> Tuple[Float, Float]:
|
||||||
avg_loss = 0
|
avg_loss = 0
|
||||||
avg_acc = 0
|
avg_acc = 0
|
||||||
|
|
||||||
for x, y in test_dl:
|
for x, y in test_dl:
|
||||||
x = x.numpy()
|
x = x.numpy()
|
||||||
y = y.numpy()
|
y = y.numpy()
|
||||||
params, statics = eqx.partition(model, eqx.is_array)
|
loss, acc = get_stats(params, statics, x, y)
|
||||||
avg_loss += CNN.loss2(params, statics, x, y)
|
avg_loss += loss
|
||||||
avg_acc += CNN.get_accuracy(params, statics, x, y)
|
avg_acc += acc
|
||||||
|
|
||||||
avg_loss /= len(test_dl)
|
avg_loss /= len(test_dl)
|
||||||
avg_acc /= len(test_dl)
|
avg_acc /= len(test_dl)
|
||||||
|
@ -184,23 +193,8 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
key = jax.random.PRNGKey(SEED)
|
key = jax.random.PRNGKey(SEED)
|
||||||
key, key1 = jax.random.split(key, 2)
|
key, key1 = jax.random.split(key, 2)
|
||||||
|
|
||||||
model = CNN(key1)
|
model = CNN(key1)
|
||||||
optim = optax.adamw(LEARNING_RATE)
|
optim = optax.adamw(LEARNING_RATE)
|
||||||
|
|
||||||
dummy_x, dummy_y = next(iter(train_dl))
|
|
||||||
dummy_x = dummy_x.numpy() # 64x1x28x28
|
|
||||||
dummy_y = dummy_y.numpy() # 64
|
|
||||||
|
|
||||||
print(jax.vmap(model)(dummy_x).shape)
|
|
||||||
print(CNN.loss(model, dummy_x, dummy_y))
|
|
||||||
|
|
||||||
params, statics = eqx.partition(model, eqx.is_array)
|
|
||||||
loss2 = lambda p, s, x, y: CNN.loss(eqx.combine(p,s), x, y)
|
|
||||||
|
|
||||||
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
|
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
|
||||||
|
|
||||||
#print(evaluate_model(model, test_dl))
|
|
||||||
#loss_w_grad = jax.value_and_grad(loss2)
|
|
||||||
#l, g = loss_w_grad(params, statics, dummy_x, dummy_y)
|
|
||||||
#print(type(l))
|
|
||||||
#print(type(g))
|
|
||||||
|
|
Loading…
Reference in a new issue