diff --git a/src/main.py b/src/main.py
index 28c32c1..3d22fe4 100644
--- a/src/main.py
+++ b/src/main.py
@@ -150,7 +150,7 @@ def evaluate_model(params: PyTree, statics: PyTree, test_dl: torch.utils.data.Da
     return avg_loss, avg_acc
 
 
-def load_data():
+def load_data(m, key):
     normalise_data = torchvision.transforms.Compose([
         torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize((0.5,), (0.5,)),
@@ -168,6 +168,64 @@ def load_data():
         transform=normalise_data,
     )
+    rng = np.random.default_rng(np.asarray(key))  # seed from the JAX key so the poisoning/membership split is reproducible
+
+    if m > 0:
+        x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
+        y_train = train_ds.targets.numpy().astype(np.int64)
+
+        x = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
+        y = test_ds.targets.numpy().astype(np.int64)
+
+        attack_mask = np.full(len(test_ds), False)
+        attack_mask[:m] = True
+        attack_mask = rng.permutation(attack_mask)
+
+        x_attack = x[attack_mask]
+        y_attack = y[attack_mask]
+        x_test = x[~attack_mask]
+        y_test = y[~attack_mask]
+
+        # Intentionally mislabel all attacked points, this is the "poisoning"
+        for i in range(y_attack.shape[0]):
+            while True:
+                a = rng.integers(0, 10)
+                if a != y_attack[i]:
+                    y_attack[i] = a
+                    break
+
+        membership = rng.choice([True, False], size=m)
+
+        attack_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_attack),
+            torch.from_numpy(y_attack)
+        )
+        train_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
+            torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
+        )
+        test_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_test),
+            torch.from_numpy(y_test),
+        )
+    else:
+        # Materializing everything up front is much faster — presumably from pre-applying the transforms once
+        x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
+        y_train = train_ds.targets.numpy().astype(np.int64)
+        x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
+        y_test = test_ds.targets.numpy().astype(np.int64)
+
+        train_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_train),
+            torch.from_numpy(y_train),
+        )
+        test_ds = torch.utils.data.TensorDataset(
+            torch.from_numpy(x_test),
+            torch.from_numpy(y_test),
+        )
+
+    print(f"Length of train ds: {len(train_ds)}")
+    print(f"Length of test ds: {len(test_ds)}")
 
     train_dl = torch.utils.data.DataLoader(
         train_ds, batch_size=BATCH_SIZE, shuffle=True
@@ -178,12 +236,11 @@ def load_data():
     return train_dl, test_dl
 
 
-
 if __name__ == '__main__':
-    train_dl, test_dl = load_data()
-
     key = jax.random.PRNGKey(SEED)
-    key, key1 = jax.random.split(key, 2)
+    key, key1, key2 = jax.random.split(key, 3)
+
+    train_dl, test_dl = load_data(1000, key2)
 
     model = CNN(key1)
     optim = optax.adamw(LEARNING_RATE)