O1: insert attack points
This commit is contained in:
parent
e239602148
commit
2586c351d9
1 changed files with 28 additions and 5 deletions
|
@ -205,15 +205,39 @@ def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
|
||||||
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
|
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
|
||||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
|
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
|
||||||
|
|
||||||
train_x = preprocess_data(train_ds.data)
|
train_x = train_ds.data
|
||||||
test_x = preprocess_data(test_ds.data)
|
test_x = test_ds.data
|
||||||
train_y = torch.tensor(train_ds.targets)
|
train_y = np.array(train_ds.targets)
|
||||||
test_y = torch.tensor(test_ds.targets)
|
test_y = np.array(test_ds.targets)
|
||||||
|
|
||||||
|
mask = np.full(len(test_x), False)
|
||||||
|
mask[:m] = True
|
||||||
|
mask = mask[np.random.permutation(len(test_ds))]
|
||||||
|
S = np.random.choice([True, False], size=m)
|
||||||
|
|
||||||
|
attack_x = test_x[mask][S]
|
||||||
|
attack_y = test_y[mask][S]
|
||||||
|
|
||||||
|
for i in range(len(attack_y)):
|
||||||
|
while True:
|
||||||
|
c = np.random.choice(range(10))
|
||||||
|
if attack_y[i] != c:
|
||||||
|
attack_y[i] = c
|
||||||
|
break
|
||||||
|
|
||||||
|
train_x = np.concatenate([train_x, attack_x])
|
||||||
|
train_y = np.concatenate([train_y, attack_y])
|
||||||
|
|
||||||
|
train_x = preprocess_data(train_x)
|
||||||
|
test_x = preprocess_data(test_x)
|
||||||
|
train_y = torch.tensor(train_y)
|
||||||
|
test_y = torch.tensor(test_y)
|
||||||
|
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
TensorDataset(train_x, train_y.long()),
|
TensorDataset(train_x, train_y.long()),
|
||||||
batch_size=train_batch_size,
|
batch_size=train_batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
num_workers=4
|
num_workers=4
|
||||||
)
|
)
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
|
@ -494,7 +518,6 @@ def train_fast(hp):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
|
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
|
||||||
|
|
||||||
return init_model, model
|
return init_model, model
|
||||||
|
|
||||||
def train(hp, train_dl, test_dl):
|
def train(hp, train_dl, test_dl):
|
||||||
|
|
Loading…
Reference in a new issue