Main: return poison data from loader
This commit is contained in:
parent
000d7ccff2
commit
ba4e063cd3
1 changed files with 14 additions and 18 deletions
16
src/main.py
16
src/main.py
|
@ -200,20 +200,16 @@ def load_data(m, key):
|
||||||
torch.from_numpy(x_attack),
|
torch.from_numpy(x_attack),
|
||||||
torch.from_numpy(y_attack)
|
torch.from_numpy(y_attack)
|
||||||
)
|
)
|
||||||
train_ds = torch.utils.data.TensorDataset(
|
x_train = np.concatenate([x_train, x_attack[membership]])
|
||||||
torch.from_numpy(np.concatenate([x_train, x_attack[membership]])),
|
y_train = np.concatenate([y_train, y_attack[membership]])
|
||||||
torch.from_numpy(np.concatenate([y_train, y_attack[membership]])),
|
|
||||||
)
|
|
||||||
test_ds = torch.utils.data.TensorDataset(
|
|
||||||
torch.from_numpy(x_test),
|
|
||||||
torch.from_numpy(y_test),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
|
# Idk why, but it gets a LOT faster this way... maybe from pre-applying the transforms?
|
||||||
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
|
x_train = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
|
||||||
y_train = train_ds.targets.numpy().astype(np.int64)
|
y_train = train_ds.targets.numpy().astype(np.int64)
|
||||||
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
|
x_test = np.stack([test_ds[i][0].numpy() for i in range(len(test_ds))])
|
||||||
y_test = test_ds.targets.numpy().astype(np.int64)
|
y_test = test_ds.targets.numpy().astype(np.int64)
|
||||||
|
attack_ds = torch.utils.data.TensorDataset()
|
||||||
|
membership = np.empty(0)
|
||||||
|
|
||||||
train_ds = torch.utils.data.TensorDataset(
|
train_ds = torch.utils.data.TensorDataset(
|
||||||
torch.from_numpy(x_train),
|
torch.from_numpy(x_train),
|
||||||
|
@ -233,14 +229,14 @@ def load_data(m, key):
|
||||||
test_ds, batch_size=BATCH_SIZE, shuffle=True
|
test_ds, batch_size=BATCH_SIZE, shuffle=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dl, test_dl
|
return train_dl, test_dl, attack_ds, membership
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
key = jax.random.PRNGKey(SEED)
|
key = jax.random.PRNGKey(SEED)
|
||||||
key, key1, key2 = jax.random.split(key, 3)
|
key, key1, key2 = jax.random.split(key, 3)
|
||||||
|
|
||||||
train_dl, test_dl = load_data(1000, key2)
|
train_dl, test_dl, attack_ds, membership = load_data(1000, key2)
|
||||||
|
|
||||||
model = CNN(key1)
|
model = CNN(key1)
|
||||||
optim = optax.adamw(LEARNING_RATE)
|
optim = optax.adamw(LEARNING_RATE)
|
||||||
|
|
Loading…
Reference in a new issue