Main: return poison data from loader

This commit is contained in:
Akemi Izuko 2024-12-22 19:31:49 -07:00
parent 000d7ccff2
commit ba4e063cd3
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -200,20 +200,16 @@ def load_data(m, key):
torch.from_numpy(x_attack),
torch.from_numpy(y_attack)
)
train_ds = torch.utils.data.TensorDataset(
torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
)
test_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_test),
torch.from_numpy(y_test),
)
x_train = np.concatenate([x_train, x_attack[membership]])
y_train = np.concatenate([y_train, y_attack[membership]])
else:
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
y_train = train_ds.targets.numpy().astype(np.int64)
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
y_test = test_ds.targets.numpy().astype(np.int64)
attack_ds = torch.utils.data.TensorDataset()
membership = np.empty(0)
train_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_train),
@ -233,14 +229,14 @@ def load_data(m, key):
test_ds, batch_size=BATCH_SIZE, shuffle=True
)
return train_dl, test_dl
return train_dl, test_dl, attack_ds, membership
if __name__ == '__main__':
key = jax.random.PRNGKey(SEED)
key, key1, key2 = jax.random.split(key, 3)
train_dl, test_dl = load_data(1000, key2)
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
model = CNN(key1)
optim = optax.adamw(LEARNING_RATE)