#!/usr/bin/env python3
"""
Started from the Equinox CNN tutorial at
https://docs.kidger.site/equinox/examples/mnist/ (kept below, commented out),
now being extended into a CIFAR-10 ResNet port, just to learn Equinox.
"""
from dataclasses import dataclass
from typing import Optional

import equinox as eqx
import jax.numpy as jnp
import jax
import optax
import time
import torch
import torchvision
from jaxtyping import Array, Float, Int, PyTree

# class CNN(eqx.Module):
#     layers: list
#
#     def __init__(self, key):
#         keys = jax.random.split(key, 4)
#         keys = list(keys)
#
#         self.layers = [
#             eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
#             eqx.nn.MaxPool2d(kernel_size=2),
#             jax.nn.relu,
#             jnp.ravel,
#             eqx.nn.Linear(1728, 512, key=keys[1]),
#             jax.nn.sigmoid,
#             eqx.nn.Linear(512, 64, key=keys[2]),
#             jax.nn.relu,
#             eqx.nn.Linear(64, 10, key=keys[3]),
#             jax.nn.log_softmax,
#         ]
#
#     def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
#         for layer in self.layers:
#             x = layer(x)
#         return x
#
#
# def cross_entropy(
#     y: Int[Array, " batch"], pred_y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
#     pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
#     return -jnp.mean(pred_y)
#
#
# @eqx.filter_jit
# def loss(
#     model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
#     pred_y = jax.vmap(model)(x)
#     return cross_entropy(y, pred_y)
#
#
# @eqx.filter_jit
# def compute_accuracy(
#     model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
#     pred_y = jax.vmap(model)(x)
#     pred_y = jnp.argmax(pred_y, axis=1)
#     return jnp.mean(y == pred_y)
#
#
# def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
#     avg_loss = 0
#     avg_acc = 0
#
#     for x, y in testloader:
#         x = jnp.array(x.numpy())
#         y = jnp.array(y.numpy())
#
#         avg_loss += loss(model, x, y)
#         avg_acc += compute_accuracy(model, x, y)
#
#     return avg_loss / len(testloader), avg_acc / len(testloader)
#
#
# def train(
#     model: CNN,
#     trainloader: torch.utils.data.DataLoader,
#     testloader: torch.utils.data.DataLoader,
#     optim: optax.GradientTransformation,
#     steps: int,
#     print_every: int,
# ) -> CNN:
#     @eqx.filter_jit
#     def make_step(
#         model: CNN,
#         opt_state: PyTree,
#         x: Float[Array, "batch 1 28 28"],
#         y: Int[Array, "batch"],
#     ):
#         loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
#         updates, opt_state = optim.update(
#             grads, opt_state, eqx.filter(model, eqx.is_array)
#         )
#         model = eqx.apply_updates(model, updates)
#         return model, opt_state, loss_value
#
#     def infinite_data(loader: torch.utils.data.DataLoader):
#         while True:
#             yield from loader  # Yields from loader until exhausted
#
#     opt_state = optim.init(eqx.filter(model, eqx.is_array))
#
#     for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
#         x = jnp.array(x.numpy())
#         y = jnp.array(y.numpy())
#
#         model, opt_state, train_loss = make_step(model, opt_state, x, y)
#
#         if (step % print_every) == 0 or step == steps - 1:
#             avg_loss, avg_acc = evaluate(model, testloader)
#
#             jax.debug.print("==== step {} ====", step)
#             jax.debug.print("train loss = {}", train_loss)
#             jax.debug.print("test loss = {}", avg_loss)
#             jax.debug.print("test accuracy = {}", avg_acc)
#
#     return model


@dataclass(frozen=True)
class HParams:
    nb_classes: int
    is_bottleneck: bool


class ResidualBlock(eqx.Module):
    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    avg_pool: Optional[eqx.nn.AvgPool2d]
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res
        keys = jax.random.split(key, 2)
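        # Pre-activation layout: BatchNorm -> leaky ReLU -> conv, twice, with
        # the (possibly downsampled) input added back at the end. Two convs,
        # so two PRNG keys; the BatchNorm layers take no key.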
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, eps=0.001)
        # bn2 runs on the output of conv1, so it normalises out_channels.
        self.bn2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, eps=0.001)
        # padding=1 keeps the 3x3 convs shape-preserving (up to the stride), so
        # the residual addition below lines up.
        self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=stride, padding=1, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                   padding=1, key=keys[1])

        if stride != 1 or in_channels != out_channels:
            # Downsample the shortcut so it matches the strided conv path.
            self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
        else:
            self.avg_pool = None

    def __call__(
        self, x: Float[Array, "channels h w"], state: eqx.nn.State
    ) -> tuple[Float[Array, "channels h w"], eqx.nn.State]:
        # Operates on a single example; batching comes from
        # jax.vmap(..., axis_name="batch"). eqx.nn.BatchNorm is stateful, so
        # the running-statistics State is threaded through every call.
        if self.is_b4_res:
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)

        x = self.conv1(x)
        x, state = self.bn2(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)

        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
            # Zero-pad the shortcut's channels when the block widens, so the
            # addition is shape-compatible (as in the original ResNet).
            ch_diff = x.shape[0] - orig_x.shape[0]
            if ch_diff > 0:
                orig_x = jnp.pad(orig_x, ((0, ch_diff), (0, 0), (0, 0)))

        x += orig_x
        return x, state


class BottleneckBlock(eqx.Module):
    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    bn3: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    project_conv: Optional[eqx.nn.Conv2d]
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res
        keys = jax.random.split(key, 4)

        # Integer division: the bottleneck narrows to a quarter of the output
        # width, and Conv2d needs an int channel count.
        mid_channels = out_channels // 4

        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, eps=0.001)
        # bn2 and bn3 run after conv1 and conv2, so they normalise mid_channels.
        self.bn2 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9, eps=0.001)
        self.bn3 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9, eps=0.001)
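        # 1x1 conv down to mid_channels, 3x3 conv (padding 1 to preserve the
        # spatial size), then 1x1 conv back up to out_channels. Only conv1
        # carries the stride; conv2 and conv3 keep stride 1 so the projected
        # shortcut below stays shape-compatible.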
        self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
                                   stride=stride, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                                   padding=1, key=keys[1])
        self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
                                   key=keys[2])

        if in_channels != out_channels:
            # 1x1 projection so the shortcut matches the output shape.
            self.project_conv = eqx.nn.Conv2d(in_channels, out_channels,
                                              kernel_size=1, stride=stride,
                                              key=keys[3])
        else:
            self.project_conv = None

    def __call__(
        self, x: Float[Array, "channels h w"], state: eqx.nn.State
    ) -> tuple[Float[Array, "channels h w"], eqx.nn.State]:
        # Single-example forward pass, mirroring ResidualBlock: the BatchNorm
        # State is threaded through every call.
        if self.is_b4_res:
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)

        x = self.conv1(x)
        x, state = self.bn2(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)
        x, state = self.bn3(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv3(x)

        if self.project_conv is not None:
            orig_x = self.project_conv(orig_x)

        x += orig_x
        return x, state


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    layer1: list
    layer2: list
    layer3: list
    linear: eqx.nn.Linear
    hps: HParams

    def __init__(self, hps: HParams, key):
        self.hps = hps
        keys = jax.random.split(key, 5)

        self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1, key=keys[0])
        self.bn1 = eqx.nn.BatchNorm(16, "batch", momentum=0.9, eps=0.001)

        if hps.is_bottleneck:
            res_func = BottleneckBlock
            filters = [16, 64, 128, 256]
        else:
            res_func = ResidualBlock
            filters = [16, 16, 32, 64]

        # TODO: populate the three stages with res_func blocks
        # (filters[i] -> filters[i + 1], keys[1..3]) and add a __call__.
        self.layer1 = []
        self.layer2 = []
        self.layer3 = []

        self.linear = eqx.nn.Linear(filters[3], hps.nb_classes, key=keys[4])


def build_dataloader(is_train):
    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.247, 0.243, 0.261)),
    ])
    transform_test = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.247, 0.243, 0.261)),
    ])

    dataset = torchvision.datasets.CIFAR10(
        "data",
        train=is_train,
        download=True,
        transform=(transform_train if is_train else transform_test),
    )
    # Only shuffle the training split.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=is_train
    )

    class DataLoaderWrapper:
        """Converts torch batches to jnp arrays and one-hot encodes the labels."""

        def __init__(self, dataloader, nb_classes):
            self.dataloader = dataloader
            self.nb_classes = nb_classes

        def __iter__(self):
            for images, labels in self.dataloader:
                images = jnp.array(images)
                labels = jnp.array(labels)
                labels = jax.nn.one_hot(labels, self.nb_classes)
                yield (images, labels)

    return DataLoaderWrapper(dataloader, 10)


# ╔══════════════════════════════╗
# ║ Main script                  ║
# ╚══════════════════════════════╝

jax.config.update("jax_platform_name", "gpu")  # Sets preferred device

# Check that the GPU is actually being used
from jax.extend import backend
print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678

key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)

train_loader = build_dataloader(True)
test_loader = build_dataloader(False)
print(train_loader)

# Peek at one batch to sanity-check the data pipeline.
x = next(iter(train_loader))
print(type(x), len(x))
print(type(x[0]), type(x[1]))
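# With BATCH_SIZE = 16, the wrapper above should yield images of shape
# (16, 3, 32, 32) and one-hot labels of shape (16, 10), both as jnp arrays.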
print(x[0].shape, x[1].shape)
print(f"First: {x[0][0, 0]}")  # first image, first channel

# Debug stop while the ResNet port is in progress; everything below is the
# old MNIST pipeline, kept for reference.
exit(1)

# model = CNN(subkey)
# optim = optax.adamw(LEARNING_RATE)
#
# start = time.time()
# model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
# cease = time.time()
#
# print(f"Took {cease - start}s")
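
# Rough, untested sketch of how the stateful blocks above are meant to be used
# once the ResNet is finished. eqx.nn.BatchNorm keeps running statistics in an
# eqx.nn.State, so the module is built with eqx.nn.make_with_state and the
# state is threaded through each call; batching uses jax.vmap with the same
# axis_name ("batch") that the BatchNorm layers were given. This mirrors the
# stateful-training pattern in the Equinox docs; the block arguments and
# `images` below are placeholders, not real data.
#
# block, bn_state = eqx.nn.make_with_state(ResidualBlock)(
#     3, 16, 1, 0.1, False, key=subkey
# )
# batched_block = jax.vmap(
#     block, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
# )
# images = jnp.zeros((BATCH_SIZE, 3, 32, 32))      # placeholder CIFAR batch
# out, bn_state = batched_block(images, bn_state)  # out: (16, 16, 32, 32)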