Main: basic one-run scoring

This commit is contained in:
Akemi Izuko 2024-12-22 20:16:54 -07:00
parent ba4e063cd3
commit b4d03f86b5
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -8,14 +8,14 @@ import optax
import torch
import torchvision
from functools import partial
from jaxtyping import Array, Float, Int, PyTree
from jaxtyping import Array, Float, Int, Bool, PyTree
from typing import Tuple, ForwardRef
DATA_ROOT="data"
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
BATCH_SIZE = 1#64
LEARNING_RATE = 0.03 #3e-4
STEPS = 10#300
PRINT_EVERY = 30
SEED = 5678
@ -150,7 +150,33 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
return avg_loss, avg_acc
def load_data(m, key):
def score_model(model: PyTree, attack_ds: torch.utils.data.TensorDataset, membership: Bool[Array, "m"]) -> Float[Array, "m"]:
assert len(attack_ds) == len(membership)
k = 800
params, statics = eqx.partition(model, eqx.is_array)
scores = list()
for i in range(len(attack_ds)):
x = np.expand_dims(attack_ds[i][0].numpy(), axis=0)
y = np.expand_dims(attack_ds[i][1].numpy(), axis=0)
loss, acc = get_stats(params, statics, x, y)
scores.append(loss.item())
scores = np.array(scores)
perm = np.argsort(scores)
scores = scores[perm]
membership = membership[perm]
correct = np.sum(membership[:k]) + np.sum(~membership[-k:])
print(scores[:10])
print(perm[:10])
print(f"Accuracy = {correct} / {2*k}")
def load_data(m: int, key):
normalise_data = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,)),
@ -198,7 +224,7 @@ def load_data(m, key):
attack_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_attack),
torch.from_numpy(y_attack)
torch.from_numpy(y_attack),
)
x_train = np.concatenate([x_train, x_attack[membership]])
y_train = np.concatenate([y_train, y_attack[membership]])
@ -236,9 +262,12 @@ if __name__ == '__main__':
key = jax.random.PRNGKey(SEED)
key, key1, key2 = jax.random.split(key, 3)
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
train_dl, test_dl, attack_ds, membership = load_data(5000, key2)
model = CNN(key1)
optim = optax.adamw(LEARNING_RATE)
optim = optax.adamw(3e-4)
#optim = optax.sgd(0.03)
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
scores = score_model(model, attack_ds, membership)