Main: add one-run poisoning to dataloader
This commit is contained in:
parent
69eefb9264
commit
000d7ccff2
1 changed file with 62 additions and 5 deletions
67
src/main.py
67
src/main.py
|
@ -150,7 +150,7 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
|
|||
return avg_loss, avg_acc
|
||||
|
||||
|
||||
def load_data():
|
||||
def load_data(m, key):
|
||||
normalise_data = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||
|
@ -168,6 +168,64 @@ def load_data():
|
|||
transform=normalise_data,
|
||||
)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
if m > 0:
|
||||
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
|
||||
y_train = train_ds.targets.numpy().astype(np.int64)
|
||||
|
||||
x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
|
||||
y = test_ds.targets.numpy().astype(np.int64)
|
||||
|
||||
attack_mask = np.full(len(test_ds), False)
|
||||
attack_mask[:m] = True
|
||||
attack_mask = rng.permutation(attack_mask)
|
||||
|
||||
x_attack = x[attack_mask]
|
||||
y_attack = y[attack_mask]
|
||||
x_test = x[~attack_mask]
|
||||
y_test = y[~attack_mask]
|
||||
|
||||
# Intentionally mislabel all attacked points, this is the "poisoning"
|
||||
for i in range(y_attack.shape[0]):
|
||||
while True:
|
||||
a = rng.integers(0, 10)
|
||||
if a != y_attack[i]:
|
||||
y_attack[i] = a
|
||||
break
|
||||
|
||||
membership = rng.choice([True, False], size=m)
|
||||
|
||||
attack_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_attack),
|
||||
torch.from_numpy(y_attack)
|
||||
)
|
||||
train_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
|
||||
torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
|
||||
)
|
||||
test_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_test),
|
||||
torch.from_numpy(y_test),
|
||||
)
|
||||
else:
|
||||
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
|
||||
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
|
||||
y_train = train_ds.targets.numpy().astype(np.int64)
|
||||
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
|
||||
y_test = test_ds.targets.numpy().astype(np.int64)
|
||||
|
||||
train_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_train),
|
||||
torch.from_numpy(y_train),
|
||||
)
|
||||
test_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_test),
|
||||
torch.from_numpy(y_test),
|
||||
)
|
||||
|
||||
print(f"Length of train ds: {len(train_ds)}")
|
||||
print(f"Length of test ds: {len(test_ds)}")
|
||||
train_dl = torch.utils.data.DataLoader(
|
||||
train_ds, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
|
@ -178,12 +236,11 @@ def load_data():
|
|||
return train_dl, test_dl
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_dl, test_dl = load_data()
|
||||
|
||||
key = jax.random.PRNGKey(SEED)
|
||||
key, key1 = jax.random.split(key, 2)
|
||||
key, key1, key2 = jax.random.split(key, 3)
|
||||
|
||||
train_dl, test_dl = load_data(1000, key2)
|
||||
|
||||
model = CNN(key1)
|
||||
optim = optax.adamw(LEARNING_RATE)
|
||||
|
|
Loading…
Reference in a new issue