2024-11-12 18:40:08 -07:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
"""
|
|
|
|
This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/,
|
|
|
|
used here to learn Equinox; the file is being adapted into a ResNet for CIFAR-10.
|
|
|
|
"""
|
|
|
|
import equinox as eqx
|
|
|
|
import jax.numpy as jnp
|
|
|
|
import jax
|
|
|
|
import optax
|
|
|
|
import time
|
|
|
|
import torch
|
|
|
|
import torchvision
|
|
|
|
from jaxtyping import Array, Float, Int, PyTree
|
|
|
|
|
|
|
|
|
2024-11-12 20:06:16 -07:00
|
|
|
# class CNN(eqx.Module):
|
|
|
|
# layers: list
|
|
|
|
#
|
|
|
|
# def __init__(self, key):
|
|
|
|
# keys = jax.random.split(key, 4)
|
|
|
|
# keys = list(keys)
|
|
|
|
#
|
|
|
|
# self.layers = [
|
|
|
|
# eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
|
|
|
|
# eqx.nn.MaxPool2d(kernel_size=2),
|
|
|
|
# jax.nn.relu,
|
|
|
|
# jnp.ravel,
|
|
|
|
# eqx.nn.Linear(1728, 512, key=keys[1]),
|
|
|
|
# jax.nn.sigmoid,
|
|
|
|
# eqx.nn.Linear(512, 64, key=keys[2]),
|
|
|
|
# jax.nn.relu,
|
|
|
|
# eqx.nn.Linear(64, 10, key=keys[3]),
|
|
|
|
# jax.nn.log_softmax,
|
|
|
|
# ]
|
|
|
|
#
|
|
|
|
# def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
|
|
|
|
# for layer in self.layers:
|
|
|
|
# x = layer(x)
|
|
|
|
# return x
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# def cross_entropy(
|
|
|
|
# y: Int[Array, " batch"], pred_y: Int[Array, " batch"]
|
|
|
|
# ) -> Float[Array, ""]:
|
|
|
|
# pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
|
|
|
|
# return -jnp.mean(pred_y)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# @eqx.filter_jit
|
|
|
|
# def loss(
|
|
|
|
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
|
|
|
|
# ) -> Float[Array, ""]:
|
|
|
|
# pred_y = jax.vmap(model)(x)
|
|
|
|
# return cross_entropy(y, pred_y)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# @eqx.filter_jit
|
|
|
|
# def compute_accuracy(
|
|
|
|
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
|
|
|
|
# ) -> Float[Array, ""]:
|
|
|
|
# pred_y = jax.vmap(model)(x)
|
|
|
|
# pred_y = jnp.argmax(pred_y, axis=1)
|
|
|
|
# return jnp.mean(y == pred_y)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
|
|
|
|
# avg_loss = 0
|
|
|
|
# avg_acc = 0
|
|
|
|
#
|
|
|
|
# for x, y in testloader:
|
|
|
|
# x = jnp.array(x.numpy())
|
|
|
|
# y = jnp.array(y.numpy())
|
|
|
|
#
|
|
|
|
# avg_loss += loss(model, x, y)
|
|
|
|
# avg_acc += compute_accuracy(model, x, y)
|
|
|
|
#
|
|
|
|
# return avg_loss / len(testloader), avg_acc / len(testloader)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# def train(
|
|
|
|
# model: CNN,
|
|
|
|
# trainloader: torch.utils.data.DataLoader,
|
|
|
|
# testloader: torch.utils.data.DataLoader,
|
|
|
|
# optim: optax.GradientTransformation,
|
|
|
|
# steps: int,
|
|
|
|
# print_every: int,
|
|
|
|
# ) -> CNN:
|
|
|
|
# @eqx.filter_jit
|
|
|
|
# def make_step(
|
|
|
|
# model: CNN,
|
|
|
|
# opt_state: PyTree,
|
|
|
|
# x: Float[Array, "batch 1 28 28"],
|
|
|
|
# y: Int[Array, "batch"],
|
|
|
|
# ):
|
|
|
|
# loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
|
|
|
|
# updates, opt_state = optim.update(
|
|
|
|
# grads, opt_state, eqx.filter(model, eqx.is_array)
|
|
|
|
# )
|
|
|
|
# model = eqx.apply_updates(model, updates)
|
|
|
|
# return model, opt_state, loss_value
|
|
|
|
#
|
|
|
|
# def infinite_data(loader: torch.utils.data.DataLoader):
|
|
|
|
# while True:
|
|
|
|
# yield from loader # Yields from loader until exhausted
|
|
|
|
#
|
|
|
|
# opt_state = optim.init(eqx.filter(model, eqx.is_array))
|
|
|
|
#
|
|
|
|
# for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
|
|
|
|
# x = jnp.array(x.numpy())
|
|
|
|
# y = jnp.array(y.numpy())
|
|
|
|
#
|
|
|
|
# model, opt_state, train_loss = make_step(model, opt_state, x, y)
|
|
|
|
#
|
|
|
|
# if (step % print_every) == 0 or step == steps - 1:
|
|
|
|
# avg_loss, avg_acc = evaluate(model, testloader)
|
|
|
|
#
|
|
|
|
# jax.debug.print("==== step {} ====", step)
|
|
|
|
# jax.debug.print("train loss = {}", train_loss)
|
|
|
|
# jax.debug.print("test loss = {}", avg_loss)
|
|
|
|
# jax.debug.print("text accuracy = {}", avg_acc)
|
|
|
|
#
|
|
|
|
# return model
|
|
|
|
|
|
|
|
class HParams:
    """Model hyperparameters read by ``ResNet``.

    Attributes:
        nb_classes: number of output classes (output size of the final
            ``eqx.nn.Linear`` built by ``ResNet``).
        is_bottleneck: if True, ``ResNet`` selects ``BottleneckBlock`` units
            and wider filters; otherwise ``ResidualBlock`` units.
    """

    nb_classes: int
    is_bottleneck: bool

    def __init__(self, nb_classes: int = 10, is_bottleneck: bool = False):
        # The class previously only declared annotations, so instances had no
        # attributes until they were assigned by hand. Defaults match the
        # 10-class CIFAR-10 setup used by this script.
        # Kept as a plain (identity-hashable) class rather than an eq=True
        # dataclass, because it is stored as a non-array field of an
        # eqx.Module and is presumably treated as static under filter_jit.
        self.nb_classes = nb_classes
        self.is_bottleneck = is_bottleneck

    def __repr__(self):
        return (f"{type(self).__name__}(nb_classes={self.nb_classes!r}, "
                f"is_bottleneck={self.is_bottleneck!r})")
|
|
|
|
|
|
|
|
class ResidualBlock(eqx.Module):
    """Residual unit: (BN -> leaky ReLU -> 3x3 conv) twice, plus a shortcut.

    When ``stride != 1`` or the channel count changes, the shortcut is
    average-pooled with window/stride ``stride`` and zero-padded along the
    channel axis, so it can be added to the residual branch.
    """

    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    avg_pool: eqx.nn.AvgPool2d | None  # None when the shortcut is the identity
    stride: int
    relu_leak: float
    # Must be a declared field: undeclared attributes are not part of the
    # module's pytree and are lost when it is rebuilt (e.g. apply_updates).
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        """Build the block.

        Args:
            in_channels: channels of the block input.
            out_channels: channels produced by both convolutions.
            stride: stride of the first convolution; also the pooling
                window/stride applied to the shortcut when it is pooled.
            relu_leak: ``negative_slope`` passed to ``jax.nn.leaky_relu``.
            is_b4_res: if True, bn1 + activation run before the shortcut is
                captured (the shortcut is the activated input); otherwise the
                raw input is the shortcut.
            key: PRNG key used to initialise the two convolutions.
        """
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 2)

        # bn1 normalises the block input; bn2 normalises conv1's output, so it
        # is sized for out_channels (it was previously in_channels).
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9,
                                    eps=0.001)

        # padding=1 keeps the 3x3 convolutions from shrinking the feature map
        # (only the stride reduces it), so the residual add lines up with the
        # pooled shortcut for even spatial sizes.
        self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=stride, padding=1, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                   padding=1, key=keys[1])

        if stride != 1 or in_channels != out_channels:
            self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
        else:
            self.avg_pool = None

    def __call__(
        self,
        x: Float[Array, "batch channels w h"]
    ) -> Float[Array, "batch channels w h"]:
        """Apply the block and return residual branch + shortcut.

        NOTE(review): Equinox Conv2d/AvgPool2d take an unbatched
        (channels, h, w) array, and BatchNorm's "batch" axis name implies this
        runs under ``jax.vmap(..., axis_name="batch")``; the "batch" in the
        annotation presumably refers to that vmapped axis — confirm.
        NOTE(review): current Equinox BatchNorm is stateful and is called as
        ``bn(x, state) -> (x, state)``; the ``self.bn*(x)`` calls below need
        ``state`` threaded through (which changes this signature) before the
        block can run — confirm against the installed Equinox version.
        """
        if self.is_b4_res:
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)

        x = self.conv1(x)
        x = self.bn2(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)

        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
            # Pooling only fixes the spatial size; zero-pad the channel axis
            # (third from the end) so the shortcut matches x's channel count.
            extra = x.shape[-3] - orig_x.shape[-3]
            if extra > 0:
                pad_width = [(0, 0)] * orig_x.ndim
                pad_width[-3] = (extra // 2, extra - extra // 2)
                orig_x = jnp.pad(orig_x, pad_width)

        return x + orig_x
|
|
|
|
|
2024-11-12 20:06:16 -07:00
|
|
|
|
2024-11-13 14:57:16 -07:00
|
|
|
class BottleneckBlock(eqx.Module):
    """Bottleneck residual unit: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    The middle width is ``out_channels // 4``. When the stride or channel
    count changes, the shortcut is a strided 1x1 projection convolution.
    """

    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    bn3: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    # Previously declared as ``project_conv`` while ``__init__`` assigned
    # ``self.project``, leaving the declared field uninitialised.
    project: eqx.nn.Conv2d | None  # None when the shortcut is the identity
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        """Build the block.

        Args:
            in_channels: channels of the block input.
            out_channels: channels of the block output; the bottleneck width
                is ``out_channels // 4``.
            stride: stride of the first 1x1 convolution and of the projection.
            relu_leak: ``negative_slope`` passed to ``jax.nn.leaky_relu``.
            is_b4_res: if True, bn1 + activation run before the shortcut is
                captured; otherwise the raw input is the shortcut.
            key: PRNG key used to initialise the convolutions.
        """
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 4)

        # Integer division: channel counts must be ints (``/`` gave a float).
        mid_channels = out_channels // 4

        # Each BN normalises the input of the conv that follows it:
        # bn1 -> conv1 (in_channels); bn2 -> conv2 and bn3 -> conv3 both see
        # mid_channels.
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn2 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn3 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)

        # Only the first 1x1 conv is strided. Equinox Conv2d expects ``stride``
        # as an int (or one value per spatial dim), not a TF-style 4-element
        # list. padding=1 keeps the 3x3 conv from shrinking the feature map.
        self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
                                   stride=stride, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                                   stride=1, padding=1, key=keys[1])
        self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
                                   stride=1, key=keys[2])

        # Project whenever the identity shortcut would not match the residual
        # branch's shape, including a strided block with unchanged channels.
        if stride != 1 or in_channels != out_channels:
            self.project = eqx.nn.Conv2d(in_channels, out_channels,
                                         kernel_size=1, stride=stride,
                                         key=keys[3])
        else:
            self.project = None

    def __call__(
        self,
        x: Float[Array, "batch channels w h"]
    ) -> Float[Array, "batch channels w h"]:
        """Apply the block and return residual branch + shortcut.

        NOTE(review): Equinox Conv2d takes an unbatched (channels, h, w) array,
        and BatchNorm's "batch" axis name implies this runs under
        ``jax.vmap(..., axis_name="batch")``; confirm the annotation's meaning.
        NOTE(review): current Equinox BatchNorm is stateful and is called as
        ``bn(x, state) -> (x, state)``; the ``self.bn*(x)`` calls below need
        ``state`` threaded through before the block can run — confirm against
        the installed Equinox version.
        """
        if self.is_b4_res:
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x = self.bn1(x)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)

        x = self.conv1(x)
        x = self.bn2(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)
        x = self.bn3(x)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv3(x)

        # Explicit None check rather than truthiness of a Module.
        if self.project is not None:
            orig_x = self.project(orig_x)

        return x + orig_x
|
2024-11-12 20:06:16 -07:00
|
|
|
|
|
|
|
|
|
|
|
class ResNet(eqx.Module):
    """ResNet for CIFAR-style inputs — work in progress.

    So far only the stem (3x3 conv + batch norm) and the final linear
    classifier are built. The three residual stages are still empty lists,
    and the class has no ``__call__`` yet.
    """

    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    # Each stage is a list of residual units (currently assigned ``[]``); the
    # annotations previously claimed a single ``ResidualBlock``.
    layer1: list
    layer2: list
    layer3: list
    linear: eqx.nn.Linear
    hps: HParams

    def __init__(self, hps: HParams, key=None):
        """Build the network.

        Args:
            hps: hyperparameters; ``is_bottleneck`` selects the unit type and
                filter widths, ``nb_classes`` sets the classifier output size.
            key: PRNG key for weight initialisation. This method previously
                read an undeclared module-level ``key`` global; when omitted,
                a fixed ``jax.random.PRNGKey(0)`` is used.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        self.hps = hps
        # keys[0]: stem conv, keys[4]: classifier; keys[1:4] are reserved for
        # the three stages once they are built.
        keys = jax.random.split(key, 5)

        self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1,
                                   key=keys[0])
        self.bn1 = eqx.nn.BatchNorm(16, "batch", momentum=0.9, eps=0.001)

        if hps.is_bottleneck:
            res_func = BottleneckBlock
            filters = [16, 64, 128, 256]
        else:
            res_func = ResidualBlock
            filters = [16, 16, 32, 64]

        # TODO: populate the stages with ``res_func`` units using ``filters``
        # and keys[1:4]; until then ``res_func`` is unused.
        self.layer1 = []
        self.layer2 = []
        self.layer3 = []

        # filters[3] is the channel width fed into the classifier.
        self.linear = eqx.nn.Linear(filters[3], hps.nb_classes, key=keys[4])
|
2024-11-12 18:40:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
def build_dataloader(is_train, batch_size=None, shuffle=True):
    """Build a CIFAR-10 loader that yields JAX arrays.

    Args:
        is_train: if True, load the training split with random-crop and
            horizontal-flip augmentation; otherwise load the test split.
        batch_size: examples per batch; defaults to the module-level
            ``BATCH_SIZE`` when omitted.
        shuffle: passed to ``torch.utils.data.DataLoader``; defaults to True
            (the previous behaviour for both splits).

    Returns:
        An iterable (with ``len()`` = number of batches) yielding
        ``(images, labels)``. ``images`` is the normalised batch as a JAX
        array, and ``labels`` is one-hot encoded over 10 classes.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE  # read-only use; no ``global`` needed

    normalize = torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                 (0.247, 0.243, 0.261))
    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        normalize,
    ])
    transform_test = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        normalize,
    ])

    dataset = torchvision.datasets.CIFAR10(
        "data",
        train=is_train,
        download=True,
        transform=(transform_train if is_train else transform_test)
    )

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    class DataLoaderWrapper:
        """Adapt a torch DataLoader to yield (JAX images, one-hot JAX labels)."""

        def __init__(self, dataloader, nb_classes):
            self.dataloader = dataloader
            self.nb_classes = nb_classes

        def __len__(self):
            # Number of batches, so callers can average per-batch metrics.
            return len(self.dataloader)

        def __iter__(self):
            for images, labels in self.dataloader:
                images = jnp.asarray(images.numpy())
                # Use the configured class count (was hard-coded to 10).
                labels = jax.nn.one_hot(jnp.asarray(labels.numpy()),
                                        self.nb_classes)
                yield images, labels

    return DataLoaderWrapper(dataloader, 10)
|
|
|
|
|
|
|
|
|
|
|
|
# ╔─────────────────────────────────────────────────────────────────────────────╗
# │ Main script |
# ╚─────────────────────────────────────────────────────────────────────────────╝
# NOTE(review): newer JAX releases document this option as "jax_platforms";
# confirm "jax_platform_name" is still honoured by the installed version.
jax.config.update("jax_platform_name", "gpu")  # Sets preferred device

# Checking to make sure gpu is being used
from jax.extend import backend

print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678

key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)

# The training loader must read the training split; it previously passed
# False, so both loaders were built from the test set.
train_loader = build_dataloader(True)
test_loader = build_dataloader(False)

print(train_loader)

# Sanity-check one batch: tuple structure, element types, shapes, a sample.
x = next(iter(train_loader))
print(type(x), len(x))
print(type(x[0]), type(x[1]))
print(x[0].shape, x[1].shape)
print(f"First: {x[0][0, 0]}")

# Deliberate early stop while the model is still being written; everything
# below is disabled. Same exit status as the previous ``exit(1)``, without
# relying on the ``site``-provided ``exit`` helper.
raise SystemExit(1)

# model = CNN(subkey)
# optim = optax.adamw(LEARNING_RATE)
#
# start = time.time()
# model = train(model, train_loader, test_loader, optim, STEPS, PRINT_EVERY)
# cease = time.time()
# print(f"Took {cease-start}s")
|