Main: add one-run poisoning to dataloader

Akemi Izuko 2024-12-22 19:23:35 -07:00
parent 69eefb9264
commit 000d7ccff2
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

@@ -150,7 +150,7 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
     return avg_loss, avg_acc
 
 
-def load_data():
+def load_data(m, key):
     normalise_data = torchvision.transforms.Compose([
         torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize((0.5,), (0.5,)),
@@ -168,6 +168,64 @@ def load_data():
         transform=normalise_data,
     )
 
+    rng = np.random.default_rng()
+
+    if m > 0:
+        x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
+        y_train = train_ds.targets.numpy().astype(np.int64)
+        x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
+        y = test_ds.targets.numpy().astype(np.int64)
+
+        attack_mask = np.full(len(test_ds), False)
+        attack_mask[:m] = True
+        attack_mask = rng.permutation(attack_mask)
+
+        x_attack = x[attack_mask]
+        y_attack = y[attack_mask]
+        x_test = x[~attack_mask]
+        y_test = y[~attack_mask]
+
+        # Intentionally mislabel all attacked points, this is the "poisoning"
+        for i in range(y_attack.shape[0]):
+            while True:
+                a = rng.integers(0, 10)
+                if a != y_attack[i]:
+                    y_attack[i] = a
+                    break
+
+        membership = rng.choice([True, False], size=m)
+
+        attack_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_attack),
+            torch.from_numpy(y_attack)
+        )
+        train_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
+            torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
+        )
+        test_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_test),
+            torch.from_numpy(y_test),
+        )
+    else:
+        # Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
+        x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
+        y_train = train_ds.targets.numpy().astype(np.int64)
+        x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
+        y_test = test_ds.targets.numpy().astype(np.int64)
+
+        train_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_train),
+            torch.from_numpy(y_train),
+        )
+        test_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_test),
+            torch.from_numpy(y_test),
+        )
+
+    print(f"Length of train ds: {len(train_ds)}")
+    print(f"Length of test ds: {len(test_ds)}")
+
     train_dl = torch.utils.data.DataLoader(
         train_ds, batch_size=BATCH_SIZE, shuffle=True
     )
@@ -178,12 +236,11 @@ def load_data():
     return train_dl, test_dl
 
 
 if __name__ == '__main__':
-    train_dl, test_dl = load_data()
     key = jax.random.PRNGKey(SEED)
-    key, key1 = jax.random.split(key, 2)
+    key, key1, key2 = jax.random.split(key, 3)
+    train_dl, test_dl = load_data(1000, key2)
 
     model = CNN(key1)
     optim = optax.adamw(LEARNING_RATE)
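
The new code builds `membership` and `attack_ds`, but as committed `load_data` still returns only `train_dl, test_dl`. For context, a one-run membership-inference audit would later score each mislabeled canary under the trained model and check those guesses against `membership`: canaries that were mixed into the training set tend to have low loss on their poisoned label, held-out canaries tend to have high loss. The sketch below is not part of this commit; it assumes `load_data` is extended to also return `attack_ds` and `membership`, and `per_example_loss(model, x, y)` is a hypothetical helper that returns one loss value per canary.

import numpy as np

def one_run_audit(scores, membership, k):
    # scores[i]: loss of the trained model on canary i with its poisoned label.
    # Guess "member" for the k lowest-loss canaries and "non-member" for the
    # k highest-loss ones; abstain on the rest.
    order = np.argsort(scores)
    member_guesses = order[:k]
    nonmember_guesses = order[-k:]
    correct = membership[member_guesses].sum() + (~membership[nonmember_guesses]).sum()
    return int(correct), 2 * k

# Hypothetical usage, once load_data also returns attack_ds and membership:
# x_attack, y_attack = attack_ds.tensors
# scores = per_example_loss(model, x_attack.numpy(), y_attack.numpy())
# correct, total = one_run_audit(np.asarray(scores), membership, k=100)
# print(f"one-run audit: {correct}/{total} guesses correct")

The count of correct guesses is the quantity a one-run audit converts into a lower bound on the training run's privacy leakage, which is the usual payoff of wiring the poisoned canaries into the dataloader this way.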