diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py
index 3b0906b..5c35ca2 100644
--- a/one_run_audit/audit.py
+++ b/one_run_audit/audit.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from torch import optim
 from torch.optim.lr_scheduler import MultiStepLR
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import DataLoader, Subset, TensorDataset
 import torch.nn.functional as F
 from pathlib import Path
 from torchvision import transforms
@@ -50,15 +50,17 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     S[:m] = np.random.choice([True, False], size=m)  # Vector of determining if each point is in or out
     p = np.random.permutation(len(train_ds))
 
-    x_in = train_ds.data[S[p]]  # This is the training set
+    x_in = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])  # Applies transforms
+    x_in = x_in[S[p]]
+    y_in = np.array(train_ds.targets).astype(np.int64)
+    y_in = y_in[S[p]]
 
     x_m = np.full(len(train_ds), False)
     x_m[:m] = True
     x_m = x_m[p]  # These are the points being guessed at
 
-    #print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
-
-    train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
+    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
 
     return train_dl, test_dl, x_in, x_m, S[p]
@@ -84,7 +86,7 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
 
         scheduler.step()
 
-        if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
+        if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
             with torch.no_grad():
                 correct = 0
                 total = 0
@@ -136,14 +138,11 @@ def train(hp):
     )
 
     train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
-    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
+    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:hp['target_points']]}")
     print(f"Got x_in: {x_in.shape}")
-    print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}")
-    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
+    print(f"Got x_m: {x_m.shape}, sum={np.sum(x_m)}, x_m[:{hp['target_points']}] = {x_m[:hp['target_points']]}")
+    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:hp['target_points']]}")
     print(f"Got train dataloader: {len(train_dl)}")
-
-    exit(1)
-
     print(f"Training with {hp['epochs']} epochs")
 
     if hp['epsilon'] is not None:
@@ -197,6 +196,7 @@ def main():
     parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
     parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
+    parser.add_argument('--m', type=int, help='number of target points', required=True)
 
     args = parser.parse_args()
     if torch.cuda.is_available() and args.cuda:
@@ -207,7 +207,7 @@ def main():
         DEVICE = torch.device('cpu')
 
     hyperparams = {
-        "target_points": 100,
+        "target_points": args.m,
         "wrn_depth": 16,
         "wrn_width": 1,
         "epsilon": args.epsilon,
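
For reference, a minimal standalone sketch of the data path the patched get_dataloaders() now follows, assuming train_ds is a torchvision CIFAR-10 dataset with a ToTensor()-style transform (the real dataset/transform setup lives elsewhere in audit.py and may differ):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

m = 1000  # number of candidate target points (the new --m argument)
train_ds = datasets.CIFAR10('./data', train=True, download=True,
                            transform=transforms.ToTensor())

# S marks which of the first m candidate points are kept "in"; p shuffles which
# points of the dataset act as candidates.
S = np.full(len(train_ds), False)
S[:m] = np.random.choice([True, False], size=m)
p = np.random.permutation(len(train_ds))

# Materialise the transformed images once so the kept subset can be indexed directly.
x_all = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])
y_all = np.array(train_ds.targets, dtype=np.int64)
x_in, y_in = x_all[S[p]], y_all[S[p]]

# Wrapping (image, label) pairs in a TensorDataset is the key change: the old
# DataLoader(x_in, ...) iterated over raw, untransformed images with no labels.
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
train_dl = DataLoader(td, batch_size=128, shuffle=True, num_workers=4)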