Main: basic one-run scoring

This commit is contained in:
Akemi Izuko 2024-12-22 20:16:54 -07:00
parent ba4e063cd3
commit b4d03f86b5
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -8,14 +8,14 @@ import optax
import torch import torch
import torchvision import torchvision
from functools import partial from functools import partial
from jaxtyping import Array, Float, Int, PyTree from jaxtyping import Array, Float, Int, Bool, PyTree
from typing import Tuple, ForwardRef from typing import Tuple, ForwardRef
DATA_ROOT="data" DATA_ROOT="data"
BATCH_SIZE = 64 BATCH_SIZE = 1#64
LEARNING_RATE = 3e-4 LEARNING_RATE = 0.03 #3e-4
STEPS = 300 STEPS = 10#300
PRINT_EVERY = 30 PRINT_EVERY = 30
SEED = 5678 SEED = 5678
@ -150,7 +150,33 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
return avg_loss, avg_acc return avg_loss, avg_acc
def load_data(m, key): def score_model(model: PyTree, attack_ds: torch.utils.data.TensorDataset, membership: Bool[Array, "m"]) -> Float[Array, "m"]:
assert len(attack_ds) == len(membership)
k = 800
params, statics = eqx.partition(model, eqx.is_array)
scores = list()
for i in range(len(attack_ds)):
x = np.expand_dims(attack_ds[i][0].numpy(), axis=0)
y = np.expand_dims(attack_ds[i][1].numpy(), axis=0)
loss, acc = get_stats(params, statics, x, y)
scores.append(loss.item())
scores = np.array(scores)
perm = np.argsort(scores)
scores = scores[perm]
membership = membership[perm]
correct = np.sum(membership[:k]) + np.sum(~membership[-k:])
print(scores[:10])
print(perm[:10])
print(f"Accuracy = {correct} / {2*k}")
def load_data(m: int, key):
normalise_data = torchvision.transforms.Compose([ normalise_data = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,)), torchvision.transforms.Normalize((0.5,), (0.5,)),
@ -198,7 +224,7 @@ def load_data(m, key):
attack_ds = torch.utils.data.TensorDataset( attack_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_attack), torch.from_numpy(x_attack),
torch.from_numpy(y_attack) torch.from_numpy(y_attack),
) )
x_train = np.concatenate([x_train, x_attack[membership]]) x_train = np.concatenate([x_train, x_attack[membership]])
y_train = np.concatenate([y_train, y_attack[membership]]) y_train = np.concatenate([y_train, y_attack[membership]])
@ -236,9 +262,12 @@ if __name__ == '__main__':
key = jax.random.PRNGKey(SEED) key = jax.random.PRNGKey(SEED)
key, key1, key2 = jax.random.split(key, 3) key, key1, key2 = jax.random.split(key, 3)
train_dl, test_dl, attack_ds, membership = load_data(1000, key2) train_dl, test_dl, attack_ds, membership = load_data(5000, key2)
model = CNN(key1) model = CNN(key1)
optim = optax.adamw(LEARNING_RATE) optim = optax.adamw(3e-4)
#optim = optax.sgd(0.03)
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10) model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
scores = score_model(model, attack_ds, membership)