Main: return poison data from loader
This commit is contained in:
parent
000d7ccff2
commit
ba4e063cd3
1 changed files with 14 additions and 18 deletions
16
src/main.py
16
src/main.py
|
@ -200,20 +200,16 @@ def load_data(m, key):
|
|||
torch.from_numpy(x_attack),
|
||||
torch.from_numpy(y_attack)
|
||||
)
|
||||
train_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
|
||||
torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
|
||||
)
|
||||
test_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_test),
|
||||
torch.from_numpy(y_test),
|
||||
)
|
||||
x_train = np.concatenate([x_train, x_attack[membership]])
|
||||
y_train = np.concatenate([y_train, y_attack[membership]])
|
||||
else:
|
||||
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
|
||||
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
|
||||
y_train = train_ds.targets.numpy().astype(np.int64)
|
||||
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
|
||||
y_test = test_ds.targets.numpy().astype(np.int64)
|
||||
attack_ds = torch.utils.data.TensorDataset()
|
||||
membership = np.empty(0)
|
||||
|
||||
train_ds = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_train),
|
||||
|
@ -233,14 +229,14 @@ def load_data(m, key):
|
|||
test_ds, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
|
||||
return train_dl, test_dl
|
||||
return train_dl, test_dl, attack_ds, membership
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
key = jax.random.PRNGKey(SEED)
|
||||
key, key1, key2 = jax.random.split(key, 3)
|
||||
|
||||
train_dl, test_dl = load_data(1000, key2)
|
||||
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
|
||||
|
||||
model = CNN(key1)
|
||||
optim = optax.adamw(LEARNING_RATE)
|
||||
|
|
Loading…
Reference in a new issue