O1: fix dataloader batch size
parent 0d67830f7e
commit e9af7cacf1
1 changed file with 26 additions and 13 deletions
```diff
@@ -20,7 +20,7 @@ import warnings
 warnings.filterwarnings("ignore")
 
 
-DEVICE = torch.device("cpu")
+DEVICE = None
 
 
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
```
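Replacing the hard-coded CPU device with `DEVICE = None` defers device selection to `main()`, which assigns `torch.device('cpu')` in the last hunk. For that assignment to reach the module-level name rather than create a local, `main()` needs a `global DEVICE` declaration, which is not visible in this diff. A minimal sketch of the pattern, with `train_step` as a hypothetical consumer:

```python
import torch

DEVICE = None  # placeholder, assigned once at startup


def train_step():
    # Reading the module-level DEVICE needs no declaration.
    return torch.zeros(1, device=DEVICE)


def main():
    global DEVICE  # without this, the assignment below would bind a local
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(train_step().device)


if __name__ == "__main__":
    main()
```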
```diff
@@ -38,7 +38,6 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-
     test_transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
```
```diff
@@ -47,16 +46,22 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
     test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
 
-    keep = np.full(len(train_ds), True)
-    keep[:m] = False
-    np.random.shuffle(keep)
+    S = np.full(len(train_ds), True)
+    S[:m] = np.random.choice([True, False], size=m)  # Vector determining whether each point is in or out
+    p = np.random.permutation(len(train_ds))
 
-    train_ds_p = Subset(train_ds, keep)
-    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
-    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    x_in = train_ds.data[S[p]]  # This is the training set
+
+    x_m = np.full(len(train_ds), False)
+    x_m[:m] = True
+    x_m = x_m[p]  # These are the points being guessed at
+
+    # print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
+
+    train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
 
-    return train_dl, train_dl_p, test_dl
+    return train_dl, test_dl, x_in, x_m, S[p]
 
 
 def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
```
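The rewritten split implements a membership-style experiment: every point outside the first `m` is always in the training set, each of the first `m` target points is included by an independent coin flip in `S`, a permutation `p` shuffles everything, and `x_m` records where the targets landed, so the returned `S[p]` indexed by `x_m` recovers each target's in/out bit. A minimal sketch of the bookkeeping with toy sizes (`n`, `m`, and the assertions are illustrative, not from the commit):

```python
import numpy as np

n, m = 10, 4  # toy sizes; the commit uses len(train_ds) and hp['target_points']

S = np.full(n, True)
S[:m] = np.random.choice([True, False], size=m)  # coin flip for the m target points
p = np.random.permutation(n)                     # shuffle all points

x_m = np.full(n, False)
x_m[:m] = True
x_m = x_m[p]       # True exactly where a target point landed after the shuffle
S_perm = S[p]      # per-point in/out bit after the shuffle

assert x_m.sum() == m                                 # all m targets accounted for
assert np.array_equal(np.sort(p[x_m]), np.arange(m))  # and they are exactly points 0..m-1
in_bits = S_perm[x_m]                                 # the coin flips, recovered post-shuffle
```

One caveat worth flagging: `x_in = train_ds.data[S[p]]` is a raw `uint8` NumPy array with no labels and no transforms, so `DataLoader(x_in, ...)` will yield unlabeled, untransformed image batches; wrapping the selected indices back into a dataset (e.g. `Subset(train_ds, np.flatnonzero(S[p]))`) would usually be needed before real training.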
```diff
@@ -112,9 +117,9 @@ def train(hp):
         output_features=16,
         strides=[1, 1, 2, 2],
     )
-    model = model.to(DEVICE)
     model = ModuleValidator.fix(model)
     ModuleValidator.validate(model, strict=True)
+    model = model.to(DEVICE)
 
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(
```
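Moving `model.to(DEVICE)` after `ModuleValidator.fix` is the safe ordering: `fix` swaps DP-incompatible layers for fresh replacements (BatchNorm is typically replaced with GroupNorm), and those new modules are constructed on the CPU, so a device move done beforehand would not cover them. A small standalone check of the swap (the toy `Sequential` model is illustrative):

```python
import torch.nn as nn
from opacus.validators import ModuleValidator

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
net = ModuleValidator.fix(net)            # BatchNorm2d is typically swapped for GroupNorm
ModuleValidator.validate(net, strict=True)
print(net)                                # inspect the replacements before any .to(DEVICE)
```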
```diff
@@ -130,7 +135,14 @@ def train(hp):
         gamma=0.2
     )
 
-    train_dl, train_dl_p, test_dl = get_dataloaders()
+    train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
+    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
+    print(f"Got x_in: {x_in.shape}")
+    print(f"Got x_m: {x_m.shape}, sum={np.sum(x_m)}, x_m[:8] = {x_m[:8]}")
+    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
+    print(f"Got train dataloader: {len(train_dl)}")
+
+    exit(1)
 
     print(f"Training with {hp['epochs']} epochs")
 
```
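The `print`/`exit(1)` block is a temporary sanity check: it dumps the shapes and sums coming out of `get_dataloaders` and aborts before any training happens. Written out as assertions, the invariants those prints should confirm look like this (a sketch meant to sit right after the `get_dataloaders` call; the 50000 assumes CIFAR-10's training split):

```python
import numpy as np

# Illustrative assertions, not part of the commit; S, x_in, x_m, hp as above.
assert S.shape == (50000,)                      # one in/out bit per training point
assert np.sum(x_m) == hp['target_points']       # exactly m positions marked as targets
assert x_in.shape[0] == np.sum(S)               # one training image per "in" point
assert S[x_m].shape == (hp['target_points'],)   # an in/out bit for every target
```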
```diff
@@ -151,13 +163,13 @@ def train(hp):
 
     with BatchMemoryManager(
         data_loader=train_loader,
-        max_physical_batch_size=1000,  # Roughly 12gb vram, uses 9.4
+        max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
        optimizer=optimizer
     ) as memory_safe_data_loader:
         best_test_set_accuracy = train_no_cap(
             model,
             hp,
-            train_dl,
+            memory_safe_data_loader,
             test_dl,
             optimizer,
             criterion,
```
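`BatchMemoryManager` is what makes this cap a pure memory knob: it splits each logical batch into micro-batches of at most `max_physical_batch_size` and only lets the wrapped optimizer apply a real update on logical-batch boundaries, so DP accounting sees the original batch size while VRAM sees the cap, and raising 1000 to 2000 just trades memory for fewer micro-steps. A minimal usage sketch (`model`, `criterion`, `DEVICE`, and a loader/optimizer pair already made private are assumed):

```python
from opacus.utils.batch_memory_manager import BatchMemoryManager

# train_loader and optimizer are assumed to come from PrivacyEngine.make_private(...).
with BatchMemoryManager(
    data_loader=train_loader,
    max_physical_batch_size=2000,  # bounds per-step memory, not the logical batch size
    optimizer=optimizer,
) as safe_loader:
    for images, labels in safe_loader:
        optimizer.zero_grad()
        loss = criterion(model(images.to(DEVICE)), labels.to(DEVICE))
        loss.backward()
        optimizer.step()  # a real update happens only at logical batch boundaries
```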
```diff
@@ -195,6 +207,7 @@ def main():
     DEVICE = torch.device('cpu')
 
     hyperparams = {
+        "target_points": 100,
         "wrn_depth": 16,
         "wrn_width": 1,
         "epsilon": args.epsilon,
```