O1: fix dataloader batch size

This commit is contained in:
Akemi Izuko 2024-12-03 13:01:38 -07:00
parent 0d67830f7e
commit e9af7cacf1
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -20,7 +20,7 @@ import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
DEVICE = torch.device("cpu") DEVICE = None
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
@ -38,7 +38,6 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]) ])
test_transform = transforms.Compose([ test_transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
@ -47,16 +46,22 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform) train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform) test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
keep = np.full(len(train_ds), True) S = np.full(len(train_ds), True)
keep[:m] = False S[:m] = np.random.choice([True, False], size=m) # Vector determining whether each point is in or out
np.random.shuffle(keep) p = np.random.permutation(len(train_ds))
train_ds_p = Subset(train_ds, keep) x_in = train_ds.data[S[p]] # This is the training set
train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4) x_m = np.full(len(train_ds), False)
x_m[:m] = True
x_m = x_m[p] # These are the points being guessed at
#print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4) test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
return train_dl, train_dl_p, test_dl return train_dl, test_dl, x_in, x_m, S[p]
def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler): def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
@ -112,9 +117,9 @@ def train(hp):
output_features=16, output_features=16,
strides=[1, 1, 2, 2], strides=[1, 1, 2, 2],
) )
model = model.to(DEVICE)
model = ModuleValidator.fix(model) model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=True) ModuleValidator.validate(model, strict=True)
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD( optimizer = optim.SGD(
@ -130,7 +135,14 @@ def train(hp):
gamma=0.2 gamma=0.2
) )
train_dl, train_dl_p, test_dl = get_dataloaders() train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
print(f"Got x_in: {x_in.shape}")
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}")
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
print(f"Got train dataloader: {len(train_dl)}")
exit(1)
print(f"Training with {hp['epochs']} epochs") print(f"Training with {hp['epochs']} epochs")
@ -151,13 +163,13 @@ def train(hp):
with BatchMemoryManager( with BatchMemoryManager(
data_loader=train_loader, data_loader=train_loader,
max_physical_batch_size=1000, # Roughly 12gb vram, uses 9.4 max_physical_batch_size=2000, # 1000 ~= 9.4GB vram
optimizer=optimizer optimizer=optimizer
) as memory_safe_data_loader: ) as memory_safe_data_loader:
best_test_set_accuracy = train_no_cap( best_test_set_accuracy = train_no_cap(
model, model,
hp, hp,
train_dl, memory_safe_data_loader,
test_dl, test_dl,
optimizer, optimizer,
criterion, criterion,
@ -195,6 +207,7 @@ def main():
DEVICE = torch.device('cpu') DEVICE = torch.device('cpu')
hyperparams = { hyperparams = {
"target_points": 100,
"wrn_depth": 16, "wrn_depth": 16,
"wrn_width": 1, "wrn_width": 1,
"epsilon": args.epsilon, "epsilon": args.epsilon,