From 2586c351d9734642e1267f361a7b463015f61b1b Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Fri, 6 Dec 2024 19:11:54 -0700 Subject: [PATCH] O1: insert attack points --- one_run_audit/audit.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index 233f681..fccd800 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -205,15 +205,39 @@ def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10): train_ds = CIFAR10(root=DATADIR, train=True, download=True) test_ds = CIFAR10(root=DATADIR, train=False, download=True) - train_x = preprocess_data(train_ds.data) - test_x = preprocess_data(test_ds.data) - train_y = torch.tensor(train_ds.targets) - test_y = torch.tensor(test_ds.targets) + train_x = train_ds.data + test_x = test_ds.data + train_y = np.array(train_ds.targets) + test_y = np.array(test_ds.targets) + + mask = np.full(len(test_x), False) + mask[:m] = True + mask = mask[np.random.permutation(len(test_ds))] + S = np.random.choice([True, False], size=m) + + attack_x = test_x[mask][S] + attack_y = test_y[mask][S] + + for i in range(len(attack_y)): + while True: + c = np.random.choice(range(10)) + if attack_y[i] != c: + attack_y[i] = c + break + + train_x = np.concatenate([train_x, attack_x]) + train_y = np.concatenate([train_y, attack_y]) + + train_x = preprocess_data(train_x) + test_x = preprocess_data(test_x) + train_y = torch.tensor(train_y) + test_y = torch.tensor(test_y) train_dl = DataLoader( TensorDataset(train_x, train_y.long()), batch_size=train_batch_size, shuffle=True, + drop_last=True, num_workers=4 ) test_dl = DataLoader( @@ -494,7 +518,6 @@ def train_fast(hp): ) train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler) - return init_model, model def train(hp, train_dl, test_dl):