diff --git a/resnet_cifar10.py b/resnet_cifar10.py index 821c0e0..e1f217d 100755 --- a/resnet_cifar10.py +++ b/resnet_cifar10.py @@ -13,113 +13,158 @@ import torchvision from jaxtyping import Array, Float, Int, PyTree -class CNN(eqx.Module): - layers: list +# class CNN(eqx.Module): +# layers: list +# +# def __init__(self, key): +# keys = jax.random.split(key, 4) +# keys = list(keys) +# +# self.layers = [ +# eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]), +# eqx.nn.MaxPool2d(kernel_size=2), +# jax.nn.relu, +# jnp.ravel, +# eqx.nn.Linear(1728, 512, key=keys[1]), +# jax.nn.sigmoid, +# eqx.nn.Linear(512, 64, key=keys[2]), +# jax.nn.relu, +# eqx.nn.Linear(64, 10, key=keys[3]), +# jax.nn.log_softmax, +# ] +# +# def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]: +# for layer in self.layers: +# x = layer(x) +# return x +# +# +# def cross_entropy( +# y: Int[Array, " batch"], pred_y: Int[Array, " batch"] +# ) -> Float[Array, ""]: +# pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) +# return -jnp.mean(pred_y) +# +# +# @eqx.filter_jit +# def loss( +# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +# ) -> Float[Array, ""]: +# pred_y = jax.vmap(model)(x) +# return cross_entropy(y, pred_y) +# +# +# @eqx.filter_jit +# def compute_accuracy( +# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +# ) -> Float[Array, ""]: +# pred_y = jax.vmap(model)(x) +# pred_y = jnp.argmax(pred_y, axis=1) +# return jnp.mean(y == pred_y) +# +# +# def evaluate(model: CNN, testloader: torch.utils.data.DataLoader): +# avg_loss = 0 +# avg_acc = 0 +# +# for x, y in testloader: +# x = jnp.array(x.numpy()) +# y = jnp.array(y.numpy()) +# +# avg_loss += loss(model, x, y) +# avg_acc += compute_accuracy(model, x, y) +# +# return avg_loss / len(testloader), avg_acc / len(testloader) +# +# +# def train( +# model: CNN, +# trainloader: torch.utils.data.DataLoader, +# testloader: torch.utils.data.DataLoader, +# optim: optax.GradientTransformation, +# steps: int, +# print_every: int, +# ) -> CNN: +# @eqx.filter_jit +# def make_step( +# model: CNN, +# opt_state: PyTree, +# x: Float[Array, "batch 1 28 28"], +# y: Int[Array, "batch"], +# ): +# loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y) +# updates, opt_state = optim.update( +# grads, opt_state, eqx.filter(model, eqx.is_array) +# ) +# model = eqx.apply_updates(model, updates) +# return model, opt_state, loss_value +# +# def infinite_data(loader: torch.utils.data.DataLoader): +# while True: +# yield from loader # Yields from loader until exhausted +# +# opt_state = optim.init(eqx.filter(model, eqx.is_array)) +# +# for step, (x, y) in zip(range(steps), infinite_data(trainloader)): +# x = jnp.array(x.numpy()) +# y = jnp.array(y.numpy()) +# +# model, opt_state, train_loss = make_step(model, opt_state, x, y) +# +# if (step % print_every) == 0 or step == steps - 1: +# avg_loss, avg_acc = evaluate(model, testloader) +# +# jax.debug.print("==== step {} ====", step) +# jax.debug.print("train loss = {}", train_loss) +# jax.debug.print("test loss = {}", avg_loss) +# jax.debug.print("text accuracy = {}", avg_acc) +# +# return model - def __init__(self, key): - keys = jax.random.split(key, 4) - keys = list(keys) +class HParams(): + nb_classes: int + is_bottleneck: bool - self.layers = [ - eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]), - eqx.nn.MaxPool2d(kernel_size=2), - jax.nn.relu, - jnp.ravel, - eqx.nn.Linear(1728, 512, key=keys[1]), - jax.nn.sigmoid, - eqx.nn.Linear(512, 64, key=keys[2]), - jax.nn.relu, - eqx.nn.Linear(64, 10, key=keys[3]), - jax.nn.log_softmax, - ] +class ResidualBlock(eqx.Module): + bn1: eqx.nn.BatchNorm - def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]: - for layer in self.layers: - x = layer(x) - return x + def __init__(self, in_channels: int, out_channels: int, stride: int, key): + keys = jax.random.split(key, 2) + + self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, + eps=0.001) -def cross_entropy( - y: Int[Array, " batch"], pred_y: Int[Array, " batch"] -) -> Float[Array, ""]: - pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) - return -jnp.mean(pred_y) + +class ResNet(eqx.Module): + conv1: eqx.nn.Conv2d + bn1: eqx.nn.BatchNorm + layer1: ResidualBlock + layer2: ResidualBlock + layer3: ResidualBlock + linear: eqx.nn.Linear + hps: HParams -@eqx.filter_jit -def loss( - model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] -) -> Float[Array, ""]: - pred_y = jax.vmap(model)(x) - return cross_entropy(y, pred_y) + self __init__(self, hps: HParams): + self.hps = hps + keys = jax.random.split(key, 5) + self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1, key=keys[0]) + self.bn1 = eqx.nn.BatchNorm(16, "batch", momentum=0.9, eps=0.001) -@eqx.filter_jit -def compute_accuracy( - model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] -) -> Float[Array, ""]: - pred_y = jax.vmap(model)(x) - pred_y = jnp.argmax(pred_y, axis=1) - return jnp.mean(y == pred_y) + if hps.is_bottleneck: + res_func = BottleneckBlock + filters = [16, 64, 128, 256] + else: + res_func = ResidualBlock + filters = [16, 16, 32, 64] + self.layer1 = [] + self.layer2 = [] + self.layer3 = [] -def evaluate(model: CNN, testloader: torch.utils.data.DataLoader): - avg_loss = 0 - avg_acc = 0 - - for x, y in testloader: - x = jnp.array(x.numpy()) - y = jnp.array(y.numpy()) - - avg_loss += loss(model, x, y) - avg_acc += compute_accuracy(model, x, y) - - return avg_loss / len(testloader), avg_acc / len(testloader) - - -def train( - model: CNN, - trainloader: torch.utils.data.DataLoader, - testloader: torch.utils.data.DataLoader, - optim: optax.GradientTransformation, - steps: int, - print_every: int, -) -> CNN: - @eqx.filter_jit - def make_step( - model: CNN, - opt_state: PyTree, - x: Float[Array, "batch 1 28 28"], - y: Int[Array, "batch"], - ): - loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y) - updates, opt_state = optim.update( - grads, opt_state, eqx.filter(model, eqx.is_array) - ) - model = eqx.apply_updates(model, updates) - return model, opt_state, loss_value - - def infinite_data(loader: torch.utils.data.DataLoader): - while True: - yield from loader # Yields from loader until exhausted - - opt_state = optim.init(eqx.filter(model, eqx.is_array)) - - for step, (x, y) in zip(range(steps), infinite_data(trainloader)): - x = jnp.array(x.numpy()) - y = jnp.array(y.numpy()) - - model, opt_state, train_loss = make_step(model, opt_state, x, y) - - if (step % print_every) == 0 or step == steps - 1: - avg_loss, avg_acc = evaluate(model, testloader) - - jax.debug.print("==== step {} ====", step) - jax.debug.print("train loss = {}", train_loss) - jax.debug.print("test loss = {}", avg_loss) - jax.debug.print("text accuracy = {}", avg_acc) - - return model + self.linear = eqx.nn.Linear(filters[3], hps.nb_classes, key=keys[4]) def build_dataloader(is_train): @@ -166,7 +211,6 @@ def build_dataloader(is_train): return DataLoaderWrapper(dataloader, 10) - # ╔─────────────────────────────────────────────────────────────────────────────╗ # │ Main script | # ╚─────────────────────────────────────────────────────────────────────────────╝ @@ -189,7 +233,8 @@ SEED = 5678 key = jax.random.PRNGKey(SEED) key, subkey = jax.random.split(key, 2) -dataloader = build_dataloader(False) +train_loader = build_dataloader(False) +test_loader = build_dataloader(False) print(dataloader) @@ -197,12 +242,7 @@ x = next(iter(dataloader)) print(type(x), len(x)) print(type(x[0]), type(x[1])) print(x[0].shape, x[1].shape) -# x[1] = jnp.array(x[1]) -# print(f"Max: {jnp.max(x[1])}") -# print(f"Min: {jnp.min(x[1])}") -# print(f"Mean: {jnp.mean(x[1])}") print(f"First: {x[0][0, 0]}") -# print(f"1hot: {jax.nn.one_hot(x[1][0], 10)}") exit(1)