Resnet: resnet init

This commit is contained in:
Akemi Izuko 2024-11-12 20:06:16 -07:00
parent 3cb444e529
commit 4dd4868c55
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -13,113 +13,158 @@ import torchvision
from jaxtyping import Array, Float, Int, PyTree from jaxtyping import Array, Float, Int, PyTree
class CNN(eqx.Module): # class CNN(eqx.Module):
layers: list # layers: list
#
# def __init__(self, key):
# keys = jax.random.split(key, 4)
# keys = list(keys)
#
# self.layers = [
# eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
# eqx.nn.MaxPool2d(kernel_size=2),
# jax.nn.relu,
# jnp.ravel,
# eqx.nn.Linear(1728, 512, key=keys[1]),
# jax.nn.sigmoid,
# eqx.nn.Linear(512, 64, key=keys[2]),
# jax.nn.relu,
# eqx.nn.Linear(64, 10, key=keys[3]),
# jax.nn.log_softmax,
# ]
#
# def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
# for layer in self.layers:
# x = layer(x)
# return x
#
#
# def cross_entropy(
# y: Int[Array, " batch"], pred_y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
# pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
# return -jnp.mean(pred_y)
#
#
# @eqx.filter_jit
# def loss(
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
# pred_y = jax.vmap(model)(x)
# return cross_entropy(y, pred_y)
#
#
# @eqx.filter_jit
# def compute_accuracy(
# model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
# ) -> Float[Array, ""]:
# pred_y = jax.vmap(model)(x)
# pred_y = jnp.argmax(pred_y, axis=1)
# return jnp.mean(y == pred_y)
#
#
# def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
# avg_loss = 0
# avg_acc = 0
#
# for x, y in testloader:
# x = jnp.array(x.numpy())
# y = jnp.array(y.numpy())
#
# avg_loss += loss(model, x, y)
# avg_acc += compute_accuracy(model, x, y)
#
# return avg_loss / len(testloader), avg_acc / len(testloader)
#
#
# def train(
# model: CNN,
# trainloader: torch.utils.data.DataLoader,
# testloader: torch.utils.data.DataLoader,
# optim: optax.GradientTransformation,
# steps: int,
# print_every: int,
# ) -> CNN:
# @eqx.filter_jit
# def make_step(
# model: CNN,
# opt_state: PyTree,
# x: Float[Array, "batch 1 28 28"],
# y: Int[Array, "batch"],
# ):
# loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
# updates, opt_state = optim.update(
# grads, opt_state, eqx.filter(model, eqx.is_array)
# )
# model = eqx.apply_updates(model, updates)
# return model, opt_state, loss_value
#
# def infinite_data(loader: torch.utils.data.DataLoader):
# while True:
# yield from loader # Yields from loader until exhausted
#
# opt_state = optim.init(eqx.filter(model, eqx.is_array))
#
# for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
# x = jnp.array(x.numpy())
# y = jnp.array(y.numpy())
#
# model, opt_state, train_loss = make_step(model, opt_state, x, y)
#
# if (step % print_every) == 0 or step == steps - 1:
# avg_loss, avg_acc = evaluate(model, testloader)
#
# jax.debug.print("==== step {} ====", step)
# jax.debug.print("train loss = {}", train_loss)
# jax.debug.print("test loss = {}", avg_loss)
# jax.debug.print("text accuracy = {}", avg_acc)
#
# return model
def __init__(self, key): class HParams():
keys = jax.random.split(key, 4) nb_classes: int
keys = list(keys) is_bottleneck: bool
self.layers = [ class ResidualBlock(eqx.Module):
eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]), bn1: eqx.nn.BatchNorm
eqx.nn.MaxPool2d(kernel_size=2),
jax.nn.relu,
jnp.ravel,
eqx.nn.Linear(1728, 512, key=keys[1]),
jax.nn.sigmoid,
eqx.nn.Linear(512, 64, key=keys[2]),
jax.nn.relu,
eqx.nn.Linear(64, 10, key=keys[3]),
jax.nn.log_softmax,
]
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]: def __init__(self, in_channels: int, out_channels: int, stride: int, key):
for layer in self.layers: keys = jax.random.split(key, 2)
x = layer(x)
return x self.bn1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9,
eps=0.001)
def cross_entropy(
y: Int[Array, " batch"], pred_y: Int[Array, " batch"] class ResNet(eqx.Module):
) -> Float[Array, ""]: conv1: eqx.nn.Conv2d
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) bn1: eqx.nn.BatchNorm
return -jnp.mean(pred_y) layer1: ResidualBlock
layer2: ResidualBlock
layer3: ResidualBlock
linear: eqx.nn.Linear
hps: HParams
@eqx.filter_jit self __init__(self, hps: HParams):
def loss( self.hps = hps
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] keys = jax.random.split(key, 5)
) -> Float[Array, ""]:
pred_y = jax.vmap(model)(x)
return cross_entropy(y, pred_y)
self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, padding=1, key=keys[0])
self.bn1 = eqx.nn.BatchNorm(16, "batch", momentum=0.9, eps=0.001)
@eqx.filter_jit if hps.is_bottleneck:
def compute_accuracy( res_func = BottleneckBlock
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] filters = [16, 64, 128, 256]
) -> Float[Array, ""]: else:
pred_y = jax.vmap(model)(x) res_func = ResidualBlock
pred_y = jnp.argmax(pred_y, axis=1) filters = [16, 16, 32, 64]
return jnp.mean(y == pred_y)
self.layer1 = []
self.layer2 = []
self.layer3 = []
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader): self.linear = eqx.nn.Linear(filters[3], hps.nb_classes, key=keys[4])
avg_loss = 0
avg_acc = 0
for x, y in testloader:
x = jnp.array(x.numpy())
y = jnp.array(y.numpy())
avg_loss += loss(model, x, y)
avg_acc += compute_accuracy(model, x, y)
return avg_loss / len(testloader), avg_acc / len(testloader)
def train(
model: CNN,
trainloader: torch.utils.data.DataLoader,
testloader: torch.utils.data.DataLoader,
optim: optax.GradientTransformation,
steps: int,
print_every: int,
) -> CNN:
@eqx.filter_jit
def make_step(
model: CNN,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, "batch"],
):
loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
updates, opt_state = optim.update(
grads, opt_state, eqx.filter(model, eqx.is_array)
)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss_value
def infinite_data(loader: torch.utils.data.DataLoader):
while True:
yield from loader # Yields from loader until exhausted
opt_state = optim.init(eqx.filter(model, eqx.is_array))
for step, (x, y) in zip(range(steps), infinite_data(trainloader)):
x = jnp.array(x.numpy())
y = jnp.array(y.numpy())
model, opt_state, train_loss = make_step(model, opt_state, x, y)
if (step % print_every) == 0 or step == steps - 1:
avg_loss, avg_acc = evaluate(model, testloader)
jax.debug.print("==== step {} ====", step)
jax.debug.print("train loss = {}", train_loss)
jax.debug.print("test loss = {}", avg_loss)
jax.debug.print("text accuracy = {}", avg_acc)
return model
def build_dataloader(is_train): def build_dataloader(is_train):
@ -166,7 +211,6 @@ def build_dataloader(is_train):
return DataLoaderWrapper(dataloader, 10) return DataLoaderWrapper(dataloader, 10)
# ╔─────────────────────────────────────────────────────────────────────────────╗ # ╔─────────────────────────────────────────────────────────────────────────────╗
# │ Main script | # │ Main script |
# ╚─────────────────────────────────────────────────────────────────────────────╝ # ╚─────────────────────────────────────────────────────────────────────────────╝
@ -189,7 +233,8 @@ SEED = 5678
key = jax.random.PRNGKey(SEED) key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2) key, subkey = jax.random.split(key, 2)
dataloader = build_dataloader(False) train_loader = build_dataloader(False)
test_loader = build_dataloader(False)
print(dataloader) print(dataloader)
@ -197,12 +242,7 @@ x = next(iter(dataloader))
print(type(x), len(x)) print(type(x), len(x))
print(type(x[0]), type(x[1])) print(type(x[0]), type(x[1]))
print(x[0].shape, x[1].shape) print(x[0].shape, x[1].shape)
# x[1] = jnp.array(x[1])
# print(f"Max: {jnp.max(x[1])}")
# print(f"Min: {jnp.min(x[1])}")
# print(f"Mean: {jnp.mean(x[1])}")
print(f"First: {x[0][0, 0]}") print(f"First: {x[0][0, 0]}")
# print(f"1hot: {jax.nn.one_hot(x[1][0], 10)}")
exit(1) exit(1)