diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index a619025..3b0906b 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -20,7 +20,7 @@ import warnings warnings.filterwarnings("ignore") -DEVICE = torch.device("cpu") +DEVICE = None def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): @@ -38,7 +38,6 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) - test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), @@ -47,16 +46,22 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform) test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform) - keep = np.full(len(train_ds), True) - keep[:m] = False - np.random.shuffle(keep) + S = np.full(len(train_ds), True) + S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out + p = np.random.permutation(len(train_ds)) - train_ds_p = Subset(train_ds, keep) - train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4) - train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4) + x_in = train_ds.data[S[p]] # This is the training set + + x_m = np.full(len(train_ds), False) + x_m[:m] = True + x_m = x_m[p] # These are the points being guessed at + + #print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}") + + train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4) test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4) - return train_dl, train_dl_p, test_dl + return train_dl, test_dl, x_in, x_m, S[p] def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler): @@ -112,9 +117,9 @@ def train(hp): output_features=16, strides=[1, 1, 2, 2], ) + model = model.to(DEVICE) model = ModuleValidator.fix(model) ModuleValidator.validate(model, strict=True) - model = model.to(DEVICE) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD( @@ -130,7 +135,14 @@ def train(hp): gamma=0.2 ) - train_dl, train_dl_p, test_dl = get_dataloaders() + train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size']) + print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}") + print(f"Got x_in: {x_in.shape}") + print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}") + print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}") + print(f"Got train dataloader: {len(train_dl)}") + + exit(1) print(f"Training with {hp['epochs']} epochs") @@ -151,13 +163,13 @@ def train(hp): with BatchMemoryManager( data_loader=train_loader, - max_physical_batch_size=1000, # Roughly 12gb vram, uses 9.4 + max_physical_batch_size=2000, # 1000 ~= 9.4GB vram optimizer=optimizer ) as memory_safe_data_loader: best_test_set_accuracy = train_no_cap( model, hp, - train_dl, + memory_safe_data_loader, test_dl, optimizer, criterion, @@ -195,6 +207,7 @@ def main(): DEVICE = torch.device('cpu') hyperparams = { + "target_points": 100, "wrn_depth": 16, "wrn_width": 1, "epsilon": args.epsilon,