From ba4e063cd311494f35b4bd95cd2dcdbcaed227f6 Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Sun, 22 Dec 2024 19:31:49 -0700 Subject: [PATCH] Main: return poison data from loader --- src/main.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/main.py b/src/main.py index 3d22fe4..8f2fe41 100644 --- a/src/main.py +++ b/src/main.py @@ -200,29 +200,25 @@ def load_data(m, key): torch.from_numpy(x_attack), torch.from_numpy(y_attack) ) - train_ds = torch.utils.data.TensorDataset( - torch.from_numpy(np.concatenate([x_train, x_attack[membership]])), - torch.from_numpy(np.concatenate([y_train, y_attack[membership]])), - ) - test_ds = torch.utils.data.TensorDataset( - torch.from_numpy(x_test), - torch.from_numpy(y_test), - ) + x_train = np.concatenate([x_train, x_attack[membership]]) + y_train = np.concatenate([y_train, y_attack[membership]]) else: # Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms? x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))]) y_train = train_ds.targets.numpy().astype(np.int64) x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))]) y_test = test_ds.targets.numpy().astype(np.int64) + attack_ds = torch.utils.data.TensorDataset() + membership = np.empty(0) - train_ds = torch.utils.data.TensorDataset( - torch.from_numpy(x_train), - torch.from_numpy(y_train), - ) - test_ds = torch.utils.data.TensorDataset( - torch.from_numpy(x_test), - torch.from_numpy(y_test), - ) + train_ds = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), + torch.from_numpy(y_train), + ) + test_ds = torch.utils.data.TensorDataset( + torch.from_numpy(x_test), + torch.from_numpy(y_test), + ) print(f"Length of train ds: {len(train_ds)}") print(f"Length of test ds: {len(test_ds)}") @@ -233,14 +229,14 @@ def load_data(m, key): test_ds, batch_size=BATCH_SIZE, shuffle=True ) - return train_dl, test_dl + return train_dl, test_dl, attack_ds, membership if __name__ == '__main__': key = jax.random.PRNGKey(SEED) key, key1, key2 = jax.random.split(key, 3) - train_dl, test_dl = load_data(1000, key2) + train_dl, test_dl, attack_ds, membership = load_data(1000, key2) model = CNN(key1) optim = optax.adamw(LEARNING_RATE)