O1: insert attack points
This commit is contained in:
parent
e239602148
commit
2586c351d9
1 changed files with 28 additions and 5 deletions
|
@ -205,15 +205,39 @@ def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
|
|||
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
|
||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
|
||||
|
||||
train_x = preprocess_data(train_ds.data)
|
||||
test_x = preprocess_data(test_ds.data)
|
||||
train_y = torch.tensor(train_ds.targets)
|
||||
test_y = torch.tensor(test_ds.targets)
|
||||
train_x = train_ds.data
|
||||
test_x = test_ds.data
|
||||
train_y = np.array(train_ds.targets)
|
||||
test_y = np.array(test_ds.targets)
|
||||
|
||||
mask = np.full(len(test_x), False)
|
||||
mask[:m] = True
|
||||
mask = mask[np.random.permutation(len(test_ds))]
|
||||
S = np.random.choice([True, False], size=m)
|
||||
|
||||
attack_x = test_x[mask][S]
|
||||
attack_y = test_y[mask][S]
|
||||
|
||||
for i in range(len(attack_y)):
|
||||
while True:
|
||||
c = np.random.choice(range(10))
|
||||
if attack_y[i] != c:
|
||||
attack_y[i] = c
|
||||
break
|
||||
|
||||
train_x = np.concatenate([train_x, attack_x])
|
||||
train_y = np.concatenate([train_y, attack_y])
|
||||
|
||||
train_x = preprocess_data(train_x)
|
||||
test_x = preprocess_data(test_x)
|
||||
train_y = torch.tensor(train_y)
|
||||
test_y = torch.tensor(test_y)
|
||||
|
||||
train_dl = DataLoader(
|
||||
TensorDataset(train_x, train_y.long()),
|
||||
batch_size=train_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
num_workers=4
|
||||
)
|
||||
test_dl = DataLoader(
|
||||
|
@ -494,7 +518,6 @@ def train_fast(hp):
|
|||
)
|
||||
|
||||
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
|
||||
|
||||
return init_model, model
|
||||
|
||||
def train(hp, train_dl, test_dl):
|
||||
|
|
Loading…
Reference in a new issue