O1: insert attack points

This commit is contained in:
Akemi Izuko 2024-12-06 19:11:54 -07:00
parent e239602148
commit 2586c351d9
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -205,15 +205,39 @@ def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
train_x = preprocess_data(train_ds.data)
test_x = preprocess_data(test_ds.data)
train_y = torch.tensor(train_ds.targets)
test_y = torch.tensor(test_ds.targets)
train_x = train_ds.data
test_x = test_ds.data
train_y = np.array(train_ds.targets)
test_y = np.array(test_ds.targets)
mask = np.full(len(test_x), False)
mask[:m] = True
mask = mask[np.random.permutation(len(test_ds))]
S = np.random.choice([True, False], size=m)
attack_x = test_x[mask][S]
attack_y = test_y[mask][S]
for i in range(len(attack_y)):
while True:
c = np.random.choice(range(10))
if attack_y[i] != c:
attack_y[i] = c
break
train_x = np.concatenate([train_x, attack_x])
train_y = np.concatenate([train_y, attack_y])
train_x = preprocess_data(train_x)
test_x = preprocess_data(test_x)
train_y = torch.tensor(train_y)
test_y = torch.tensor(test_y)
train_dl = DataLoader(
TensorDataset(train_x, train_y.long()),
batch_size=train_batch_size,
shuffle=True,
drop_last=True,
num_workers=4
)
test_dl = DataLoader(
@ -494,7 +518,6 @@ def train_fast(hp):
)
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
return init_model, model
def train(hp, train_dl, test_dl):