Main: basic one-run scoring
This commit is contained in:
parent
ba4e063cd3
commit
b4d03f86b5
1 changed files with 37 additions and 8 deletions
45
src/main.py
45
src/main.py
|
@ -8,14 +8,14 @@ import optax
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from jaxtyping import Array, Float, Int, PyTree
|
from jaxtyping import Array, Float, Int, Bool, PyTree
|
||||||
from typing import Tuple, ForwardRef
|
from typing import Tuple, ForwardRef
|
||||||
|
|
||||||
|
|
||||||
DATA_ROOT="data"
|
DATA_ROOT="data"
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 1#64
|
||||||
LEARNING_RATE = 3e-4
|
LEARNING_RATE = 0.03 #3e-4
|
||||||
STEPS = 300
|
STEPS = 10#300
|
||||||
PRINT_EVERY = 30
|
PRINT_EVERY = 30
|
||||||
SEED = 5678
|
SEED = 5678
|
||||||
|
|
||||||
|
@ -150,7 +150,33 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
|
||||||
return avg_loss, avg_acc
|
return avg_loss, avg_acc
|
||||||
|
|
||||||
|
|
||||||
def load_data(m, key):
|
def score_model(model: PyTree, attack_ds: torch.utils.data.TensorDataset, membership: Bool[Array, "m"]) -> Float[Array, "m"]:
|
||||||
|
assert len(attack_ds) == len(membership)
|
||||||
|
k = 800
|
||||||
|
|
||||||
|
params, statics = eqx.partition(model, eqx.is_array)
|
||||||
|
scores = list()
|
||||||
|
|
||||||
|
for i in range(len(attack_ds)):
|
||||||
|
x = np.expand_dims(attack_ds[i][0].numpy(), axis=0)
|
||||||
|
y = np.expand_dims(attack_ds[i][1].numpy(), axis=0)
|
||||||
|
|
||||||
|
loss, acc = get_stats(params, statics, x, y)
|
||||||
|
scores.append(loss.item())
|
||||||
|
|
||||||
|
scores = np.array(scores)
|
||||||
|
perm = np.argsort(scores)
|
||||||
|
scores = scores[perm]
|
||||||
|
membership = membership[perm]
|
||||||
|
|
||||||
|
correct = np.sum(membership[:k]) + np.sum(~membership[-k:])
|
||||||
|
|
||||||
|
print(scores[:10])
|
||||||
|
print(perm[:10])
|
||||||
|
print(f"Accuracy = {correct} / {2*k}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(m: int, key):
|
||||||
normalise_data = torchvision.transforms.Compose([
|
normalise_data = torchvision.transforms.Compose([
|
||||||
torchvision.transforms.ToTensor(),
|
torchvision.transforms.ToTensor(),
|
||||||
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
@ -198,7 +224,7 @@ def load_data(m, key):
|
||||||
|
|
||||||
attack_ds = torch.utils.data.TensorDataset(
|
attack_ds = torch.utils.data.TensorDataset(
|
||||||
torch.from_numpy(x_attack),
|
torch.from_numpy(x_attack),
|
||||||
torch.from_numpy(y_attack)
|
torch.from_numpy(y_attack),
|
||||||
)
|
)
|
||||||
x_train = np.concatenate([x_train, x_attack[membership]])
|
x_train = np.concatenate([x_train, x_attack[membership]])
|
||||||
y_train = np.concatenate([y_train, y_attack[membership]])
|
y_train = np.concatenate([y_train, y_attack[membership]])
|
||||||
|
@ -236,9 +262,12 @@ if __name__ == '__main__':
|
||||||
key = jax.random.PRNGKey(SEED)
|
key = jax.random.PRNGKey(SEED)
|
||||||
key, key1, key2 = jax.random.split(key, 3)
|
key, key1, key2 = jax.random.split(key, 3)
|
||||||
|
|
||||||
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
|
train_dl, test_dl, attack_ds, membership = load_data(5000, key2)
|
||||||
|
|
||||||
model = CNN(key1)
|
model = CNN(key1)
|
||||||
optim = optax.adamw(LEARNING_RATE)
|
optim = optax.adamw(3e-4)
|
||||||
|
#optim = optax.sgd(0.03)
|
||||||
|
|
||||||
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
|
model = CNN.train(model, train_dl, test_dl, optim, STEPS, 10)
|
||||||
|
|
||||||
|
scores = score_model(model, attack_ds, membership)
|
||||||
|
|
Loading…
Reference in a new issue