Main: return poison data from loader

This commit is contained in:
Akemi Izuko 2024-12-22 19:31:49 -07:00
parent 000d7ccff2
commit ba4e063cd3
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -200,29 +200,25 @@ def load_data(m, key):
torch.from_numpy(x_attack), torch.from_numpy(x_attack),
torch.from_numpy(y_attack) torch.from_numpy(y_attack)
) )
train_ds = torch.utils.data.TensorDataset( x_train = np.concatenate([x_train, x_attack[membership]])
torch.from_numpy(np.concatenate([x_train, x_attack[membership]])), y_train = np.concatenate([y_train, y_attack[membership]])
torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
)
test_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_test),
torch.from_numpy(y_test),
)
else: else:
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms? # Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))]) x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
y_train = train_ds.targets.numpy().astype(np.int64) y_train = train_ds.targets.numpy().astype(np.int64)
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))]) x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
y_test = test_ds.targets.numpy().astype(np.int64) y_test = test_ds.targets.numpy().astype(np.int64)
attack_ds = torch.utils.data.TensorDataset()
membership = np.empty(0)
train_ds = torch.utils.data.TensorDataset( train_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(x_train),
torch.from_numpy(y_train), torch.from_numpy(y_train),
) )
test_ds = torch.utils.data.TensorDataset( test_ds = torch.utils.data.TensorDataset(
torch.from_numpy(x_test), torch.from_numpy(x_test),
torch.from_numpy(y_test), torch.from_numpy(y_test),
) )
print(f"Length of train ds: {len(train_ds)}") print(f"Length of train ds: {len(train_ds)}")
print(f"Length of test ds: {len(test_ds)}") print(f"Length of test ds: {len(test_ds)}")
@ -233,14 +229,14 @@ def load_data(m, key):
test_ds, batch_size=BATCH_SIZE, shuffle=True test_ds, batch_size=BATCH_SIZE, shuffle=True
) )
return train_dl, test_dl return train_dl, test_dl, attack_ds, membership
if __name__ == '__main__': if __name__ == '__main__':
key = jax.random.PRNGKey(SEED) key = jax.random.PRNGKey(SEED)
key, key1, key2 = jax.random.split(key, 3) key, key1, key2 = jax.random.split(key, 3)
train_dl, test_dl = load_data(1000, key2) train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
model = CNN(key1) model = CNN(key1)
optim = optax.adamw(LEARNING_RATE) optim = optax.adamw(LEARNING_RATE)