Main: basic one-run scoring
This commit is contained in:
parent
ba4e063cd3
commit
b4d03f86b5
1 changed files with 37 additions and 8 deletions
45
src/main.py
45
src/main.py
|
@ -8,14 +8,14 @@ import optax
|
|||
import torch
|
||||
import torchvision
|
||||
from functools import partial
|
||||
from jaxtyping import Array, Float, Int, PyTree
|
||||
from jaxtyping import Array, Float, Int, Bool, PyTree
|
||||
from typing import Tuple, ForwardRef
|
||||
|
||||
|
||||
DATA_ROOT="data"
|
||||
BATCH_SIZE = 64
|
||||
LEARNING_RATE = 3e-4
|
||||
STEPS = 300
|
||||
BATCH_SIZE = 1#64
|
||||
LEARNING_RATE = 0.03 #3e-4
|
||||
STEPS = 10#300
|
||||
PRINT_EVERY = 30
|
||||
SEED = 5678
|
||||
|
||||
|
@ -150,7 +150,33 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
|
|||
return avg_loss, avg_acc
|
||||
|
||||
|
||||
def load_data(m, key):
|
||||
def score_model(model: PyTree, attack_ds: torch.utils.data.TensorDataset, membership: Bool[Array, "m"]) -> Float[Array, "m"]:
|
||||
assert len(attack_ds) == len(membership)
|
||||
k = 800
|
||||
|
||||
params, statics = eqx.partition(model, eqx.is_array)
|
||||
scores = list()
|
||||
|
||||
for i in range(len(attack_ds)):
|
||||
x = np.expand_dims(attack_ds[i][0].numpy(), axis=0)
|
||||
y = np.expand_dims(attack_ds[i][1].numpy(), axis=0)
|
||||
|
||||
loss, acc = get_stats(params, statics, x, y)
|
||||
scores.append(loss.item())
|
||||
|
||||
scores = np.array(scores)
|
||||
perm = np.argsort(scores)
|
||||
scores = scores[perm]
|
||||
membership = membership[perm]
|
||||
|
||||
correct = np.sum(membership[:k]) + np.sum(~membership[-k:])
|
||||
|
||||
print(scores[:10])
|
||||
print(perm[:10])
|
||||
print(f"Accuracy = {correct} / {2*k}")
|
||||
|
||||
|
||||
def load_data(m: int, key):
|
||||
normalise_data = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||
|
@ -198,7 +224,7 @@ def load_data(m, key):
|
|||
|
||||
attack_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_attack),
|
||||
torch.from_numpy(y_attack)
|
||||
torch.from_numpy(y_attack),
|
||||
)
|
||||
x_train = np.concatenate([x_train, x_attack[membership]])
|
||||
y_train = np.concatenate([y_train, y_attack[membership]])
|
||||
|
@ -236,9 +262,12 @@ if __name__ == '__main__':
|
|||
key = jax.random.PRNGKey(SEED)
|
||||
key, key1, key2 = jax.random.split(key, 3)
|
||||
|
||||
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
|
||||
train_dl, test_dl, attack_ds, membership = load_data(5000, key2)
|
||||
|
||||
model = CNN(key1)
|
||||
optim = optax.adamw(LEARNING_RATE)
|
||||
optim = optax.adamw(3e-4)
|
||||
#optim = optax.sgd(0.03)
|
||||
|
||||
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
|
||||
|
||||
scores = score_model(model, attack_ds, membership)
|
||||
|
|
Loading…
Reference in a new issue