#!/usr/bin/env python3
"""
This is the CNN tutorial from https://docs.kidger.site/equinox/examples/mnist/,
just using it to learn equinox
"""
from dataclasses import dataclass

import equinox as eqx
import jax.numpy as jnp
import jax
import optax
import time
import torch
import torchvision
from jaxtyping import Array, Float, Int, PyTree
# class CNN(eqx.Module):
# layers: list
#
# def __init__(self, key):
# keys = jax.random.split(key, 4)
# keys = list(keys)
#
# self.layers = [
# eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
# eqx.nn.MaxPool2d(kernel_size=2),
# jax.nn.relu,
# jnp.ravel,
# eqx.nn.Linear(1728, 512, key=keys[1]),
# jax.nn.sigmoid,
# eqx.nn.Linear(512, 64, key=keys[2]),
# jax.nn.relu,
# eqx.nn.Linear(64, 10, key=keys[3]),
# jax.nn.log_softmax,
# ]
#
# def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
# for layer in self.layers:
# x = layer(x)
# return x
#
#
# def cross_entropy(
# y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
# ) -> Float[Array, ""]:
# pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
# return -jnp.mean(pred_y)
#
#
# @eqx.filter_jit
# def loss(
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
# pred_y = jax.vmap(model)(x)
# return cross_entropy(y, pred_y)
#
#
# @eqx.filter_jit
# def compute_accuracy(
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
# pred_y = jax.vmap(model)(x)
# pred_y = jnp.argmax(pred_y, axis=1)
# return jnp.mean(y == pred_y)
#
#
# def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
# avg_loss = 0
# avg_acc = 0
#
# for x, y in testloader:
# x = jnp.array(x.numpy())
# y = jnp.array(y.numpy())
#
# avg_loss += loss(model, x, y)
# avg_acc += compute_accuracy(model, x, y)
#
# return avg_loss / len(testloader), avg_acc / len(testloader)
#
#
# def train(
# model: CNN,
# trainloader: torch.utils.data.DataLoader,
# testloader: torch.utils.data.DataLoader,
# optim: optax.GradientTransformation,
# steps: int,
# print_every: int,
# ) -> CNN:
# @eqx.filter_jit
# def make_step(
# model: CNN,
# opt_state: PyTree,
# x: Float[Array, "batch 1 28 28"],
# y: Int[Array, "batch"],
# ):
# loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
# updates, opt_state = optim.update(
# grads, opt_state, eqx.filter(model, eqx.is_array)
# )
# model = eqx.apply_updates(model, updates)
# return model, opt_state, loss_value
#
# def infinite_data(loader: torch.utils.data.DataLoader):
# while True:
# yield from loader # Yields from loader until exhausted
#
# opt_state = optim.init(eqx.filter(model, eqx.is_array))
#
# for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
# x = jnp.array(x.numpy())
# y = jnp.array(y.numpy())
#
# model, opt_state, train_loss = make_step(model, opt_state, x, y)
#
# if (step % print_every) == 0 or step == steps - 1:
# avg_loss, avg_acc = evaluate(model, testloader)
#
# jax.debug.print("==== step {} ====", step)
# jax.debug.print("train loss = {}", train_loss)
# jax.debug.print("test loss = {}", avg_loss)
# jax.debug.print("text accuracy = {}", avg_acc)
#
# return model
@dataclass
class HParams:
    nb_classes: int
    is_bottleneck: bool


class ResidualBlock(eqx.Module):
    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    avg_pool: eqx.nn.AvgPool2d | None
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 2)
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        # bn2 follows conv1, which outputs out_channels.
        self.bn2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9,
                                    eps=0.001)
        # padding=1 keeps the 3x3 convs "same"-sized so the residual add works.
        self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                   stride=stride, padding=1, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                   padding=1, key=keys[1])
        if stride != 1 or in_channels != out_channels:
            # Downsample the skip connection to match conv2's spatial size.
            self.avg_pool = eqx.nn.AvgPool2d(stride, stride=stride)
        else:
            self.avg_pool = None

    def __call__(
        self,
        x: Float[Array, "channels h w"],
        state: eqx.nn.State,
    ) -> tuple[Float[Array, "channels h w"], eqx.nn.State]:
        if self.is_b4_res:
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv1(x)
        x, state = self.bn2(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)
        if self.avg_pool is not None:
            orig_x = self.avg_pool(orig_x)
            # Zero-pad channels so the identity shortcut matches out_channels.
            extra = x.shape[0] - orig_x.shape[0]
            if extra > 0:
                orig_x = jnp.pad(orig_x,
                                 ((extra // 2, extra - extra // 2),
                                  (0, 0), (0, 0)))
        x += orig_x
        return x, state


class BottleneckBlock(eqx.Module):
    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    bn3: eqx.nn.BatchNorm
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    project_conv: eqx.nn.Conv2d | None
    stride: int
    relu_leak: float
    is_b4_res: bool

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 relu_leak: float, is_b4_res: bool, key):
        self.stride = stride
        self.relu_leak = relu_leak
        self.is_b4_res = is_b4_res

        keys = jax.random.split(key, 4)
        self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
                                    eps=0.001)
        # bn2 and bn3 follow the 1x1 and 3x3 convs, which output mid_channels.
        mid_channels = out_channels // 4
        self.bn2 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)
        self.bn3 = eqx.nn.BatchNorm(mid_channels, "batch", momentum=0.9,
                                    eps=0.001)
        # Only the first 1x1 conv carries the stride; the 3x3 conv is padded
        # so spatial size is preserved, and the 1x1 convs need no padding.
        self.conv1 = eqx.nn.Conv2d(in_channels, mid_channels, kernel_size=1,
                                   stride=stride, key=keys[0])
        self.conv2 = eqx.nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                                   padding=1, key=keys[1])
        self.conv3 = eqx.nn.Conv2d(mid_channels, out_channels, kernel_size=1,
                                   key=keys[2])
        if in_channels != out_channels:
            # 1x1 projection aligns the skip connection's channels and stride.
            self.project_conv = eqx.nn.Conv2d(in_channels, out_channels,
                                              kernel_size=1, stride=stride,
                                              key=keys[3])
        else:
            self.project_conv = None

    def __call__(
        self,
        x: Float[Array, "channels h w"],
        state: eqx.nn.State,
    ) -> tuple[Float[Array, "channels h w"], eqx.nn.State]:
        if self.is_b4_res:
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
            orig_x = x
        else:
            orig_x = x
            x, state = self.bn1(x, state)
            x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv1(x)
        x, state = self.bn2(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv2(x)
        x, state = self.bn3(x, state)
        x = jax.nn.leaky_relu(x, negative_slope=self.relu_leak)
        x = self.conv3(x)
        if self.project_conv is not None:
            orig_x = self.project_conv(orig_x)
        x += orig_x
        return x, state
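
# NOTE (illustrative sketch, not part of the original file): eqx.nn.BatchNorm
# with axis_name="batch" computes batch statistics across a vmapped axis and
# threads its running statistics through an eqx.nn.State, so a block defined
# above would typically be constructed and applied along these lines:
#
#     block, state = eqx.nn.make_with_state(ResidualBlock)(
#         16, 16, stride=1, relu_leak=0.1, is_b4_res=True, key=subkey
#     )
#     out, state = jax.vmap(
#         block, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
#     )(image_batch, state)
#
# Here `image_batch`, `subkey`, and the relu_leak value are placeholders.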


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    layer1: list
    layer2: list
    layer3: list
    linear: eqx.nn.Linear
    hps: HParams

    def __init__(self, hps: HParams, key):
        self.hps = hps
        keys = jax.random.split(key, 5)
        self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1, key=keys[0])
        self.bn1 = eqx.nn.BatchNorm(16, "batch", momentum=0.9, eps=0.001)
        if hps.is_bottleneck:
            res_func = BottleneckBlock
            filters = [16, 64, 128, 256]
        else:
            res_func = ResidualBlock
            filters = [16, 16, 32, 64]
        # TODO: stack the residual blocks for the three stages (keys[1:4]).
        self.layer1 = []
        self.layer2 = []
        self.layer3 = []
        self.linear = eqx.nn.Linear(filters[3], hps.nb_classes, key=keys[4])
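

# NOTE (hypothetical sketch, not in the original file): one way the three
# residual stages above could be built, following the usual CIFAR ResNet
# layout where each stage downsamples once and then repeats blocks at
# stride 1. `nb_units` and `_make_stage` itself are assumed names, not part
# of the original code.
def _make_stage(res_func, in_channels, out_channels, stride, nb_units,
                relu_leak, is_b4_res, key):
    keys = jax.random.split(key, nb_units)
    # The first block changes channels / spatial size; the rest keep them fixed.
    blocks = [res_func(in_channels, out_channels, stride, relu_leak,
                       is_b4_res, key=keys[0])]
    for k in keys[1:]:
        blocks.append(res_func(out_channels, out_channels, 1, relu_leak,
                               False, key=k))
    return blocks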
def build_dataloader(is_train):
global BATCH_SIZE
transform_train = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.247, 0.243, 0.261))
])
transform_test = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.247, 0.243, 0.261))
])
dataset = torchvision.datasets.CIFAR10(
"data",
train=is_train,
download=True,
transform=(transform_train if is_train else transform_test)
)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=is_train
    )

    class DataLoaderWrapper:
        def __init__(self, dataloader, nb_classes):
            self.dataloader = dataloader
            self.nb_classes = nb_classes

        def __iter__(self):
            for images, labels in self.dataloader:
                # Convert torch tensors to JAX arrays and one-hot the labels.
                images = jnp.array(images)
                labels = jnp.array(labels)
                labels = jax.nn.one_hot(labels, self.nb_classes)
                yield (images, labels)

    return DataLoaderWrapper(dataloader, 10)
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ Main script                                                                 ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
jax.config.update("jax_platform_name", "gpu") # Sets preferred device
# Checking to make sure gpu is being used
from jax.extend import backend
print(backend.get_backend().platform)
print(f"JAX devices: {jax.devices()}")
print(f"Default device: {jax.default_backend()}")
# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 300
SEED = 5678
key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)

train_loader = build_dataloader(True)
test_loader = build_dataloader(False)
print(train_loader)
x = next(iter(train_loader))
print(type(x), len(x))
print(type(x[0]), type(x[1]))
print(x[0].shape, x[1].shape)
print(f"First: {x[0][0, 0]}")
exit(1)
# model = CNN(subkey)
# optim = optax.adamw(LEARNING_RATE)
#
# start = time.time()
# model = train(model, train_loader, test_loader, optim, STEPS, PRINT_EVERY)
# cease = time.time()
# print(f"Took {cease-start}s")