diff --git a/src/main.py b/src/main.py index 8f2fe41..483defa 100644 --- a/src/main.py +++ b/src/main.py @@ -8,14 +8,14 @@ import optax import torch import torchvision from functools import partial -from jaxtyping import Array, Float, Int, PyTree +from jaxtyping import Array, Float, Int, Bool, PyTree from typing import Tuple, ForwardRef DATA_ROOT="data" -BATCH_SIZE = 64 -LEARNING_RATE = 3e-4 -STEPS = 300 +BATCH_SIZE = 1#64 +LEARNING_RATE = 0.03 #3e-4 +STEPS = 10#300 PRINT_EVERY = 30 SEED = 5678 @@ -150,7 +150,33 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da return avg_loss, avg_acc -def load_data(m, key): +def score_model(model: PyTree, attack_ds: torch.utils.data.TensorDataset, membership: Bool[Array, "m"]) -> Float[Array, "m"]: + assert len(attack_ds) == len(membership) + k = 800 + + params, statics = eqx.partition(model, eqx.is_array) + scores = list() + + for i in range(len(attack_ds)): + x = np.expand_dims(attack_ds[i][0].numpy(), axis=0) + y = np.expand_dims(attack_ds[i][1].numpy(), axis=0) + + loss, acc = get_stats(params, statics, x, y) + scores.append(loss.item()) + + scores = np.array(scores) + perm = np.argsort(scores) + scores = scores[perm] + membership = membership[perm] + + correct = np.sum(membership[:k]) + np.sum(~membership[-k:]) + + print(scores[:10]) + print(perm[:10]) + print(f"Accuracy = {correct} / {2*k}") + + +def load_data(m: int, key): normalise_data = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (0.5,)), @@ -198,7 +224,7 @@ def load_data(m, key): attack_ds = torch.utils.data.TensorDataset( torch.from_numpy(x_attack), - torch.from_numpy(y_attack) + torch.from_numpy(y_attack), ) x_train = np.concatenate([x_train, x_attack[membership]]) y_train = np.concatenate([y_train, y_attack[membership]]) @@ -236,9 +262,12 @@ if __name__ == '__main__': key = jax.random.PRNGKey(SEED) key, key1, key2 = jax.random.split(key, 3) - train_dl, test_dl, attack_ds, membership = load_data(1000, key2) + train_dl, test_dl, attack_ds, membership = load_data(5000, key2) model = CNN(key1) - optim = optax.adamw(LEARNING_RATE) + optim = optax.adamw(3e-4) + #optim = optax.sgd(0.03) model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10) + + scores = score_model(model, attack_ds, membership)