O1: fix dataloader batch size
parent 0d67830f7e
commit e9af7cacf1
1 changed file with 26 additions and 13 deletions
@@ -20,7 +20,7 @@ import warnings
 warnings.filterwarnings("ignore")
 
 
-DEVICE = torch.device("cpu")
+DEVICE = None
 
 
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
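A note on the DEVICE change, since the surrounding code is outside this hunk: rebinding DEVICE inside main() (see the final hunk) only updates this module-level name if main() declares it global; otherwise the assignment creates a local and train() still sees None. A minimal sketch of the pitfall, in hypothetical code rather than the commit's:

import torch

DEVICE = None  # module-level default, as introduced by this commit

def main():
    global DEVICE                 # without this declaration, the line below
    DEVICE = torch.device("cpu")  # would bind a local and leave the global as None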
@@ -38,7 +38,6 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-
     test_transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
@@ -47,16 +46,22 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
     test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
 
-    keep = np.full(len(train_ds), True)
-    keep[:m] = False
-    np.random.shuffle(keep)
+    S = np.full(len(train_ds), True)
+    S[:m] = np.random.choice([True, False], size=m)  # Vector determining if each point is in or out
+    p = np.random.permutation(len(train_ds))
 
-    train_ds_p = Subset(train_ds, keep)
-    train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
-    train_dl_p = DataLoader(train_ds_p, batch_size=train_batch_size, shuffle=True, num_workers=4)
+    x_in = train_ds.data[S[p]]  # This is the training set
+    x_m = np.full(len(train_ds), False)
+    x_m[:m] = True
+    x_m = x_m[p]  # These are the points being guessed at
+
+    #print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
+
+    train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
 
-    return train_dl, train_dl_p, test_dl
+    return train_dl, test_dl, x_in, x_m, S[p]
 
 
 def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
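The heart of the commit is this new in/out bookkeeping: S records a random in/out decision for each of the first m candidate points, p scatters those candidates across the dataset, x_m remembers where they landed, and x_in is the subset actually trained on. A standalone sketch with toy sizes (names mirror the diff; data stands in for train_ds.data):

import numpy as np

N, m = 10, 4                               # toy sizes; the commit uses len(train_ds)
rng = np.random.default_rng(0)

S = np.full(N, True)                       # membership vector: True = point is "in"
S[:m] = rng.choice([True, False], size=m)  # random in/out decision per target point
p = rng.permutation(N)                     # scatter the targets across the dataset

x_m = np.full(N, False)
x_m[:m] = True
x_m = x_m[p]                               # mask of where the m targets landed

data = np.arange(N)                        # stand-in for train_ds.data
x_in = data[S[p]]                          # the points actually trained on

assert x_m.sum() == m
print(x_in)                                # training subset
print(S[p][x_m])                           # ground-truth in/out labels of the targets

One behavioral note: train_ds.data on a torchvision CIFAR10 is the raw uint8 array, so a DataLoader built from x_in bypasses train_transform; if normalization and augmentation still matter for this experiment, the transforms need to be reapplied.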
@@ -112,9 +117,9 @@ def train(hp):
         output_features=16,
         strides=[1, 1, 2, 2],
     )
+    model = model.to(DEVICE)
     model = ModuleValidator.fix(model)
     ModuleValidator.validate(model, strict=True)
-    model = model.to(DEVICE)
 
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(
@@ -130,7 +135,14 @@ def train(hp):
         gamma=0.2
     )
 
-    train_dl, train_dl_p, test_dl = get_dataloaders()
+    train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
+    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
+    print(f"Got x_in: {x_in.shape}")
+    print(f"Got x_m: {x_m.shape}, sum={np.sum(x_m)}, x_m[:8] = {x_m[:8]}")
+    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
+    print(f"Got train dataloader: {len(train_dl)}")
+
+    exit(1)
 
     print(f"Training with {hp['epochs']} epochs")
@@ -151,13 +163,13 @@ def train(hp):
 
     with BatchMemoryManager(
         data_loader=train_loader,
-        max_physical_batch_size=1000, # Roughly 12gb vram, uses 9.4
+        max_physical_batch_size=2000, # 1000 ~= 9.4GB vram
         optimizer=optimizer
     ) as memory_safe_data_loader:
         best_test_set_accuracy = train_no_cap(
             model,
             hp,
-            train_dl,
+            memory_safe_data_loader,
             test_dl,
             optimizer,
             criterion,
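This hunk also lands the fix named in the commit message: the inner loop previously received the raw train_dl, so max_physical_batch_size never took effect; passing memory_safe_data_loader lets Opacus feed physical batches of bounded size while gradient accumulation keeps the logical batch size intact. A minimal self-contained sketch of the pattern under Opacus 1.x (toy model and data; the script's PrivacyEngine setup is not visible in this diff, so the make_private values are illustrative):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

model = nn.Linear(8, 2)                    # toy model; the loader plumbing is the point
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
ds = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dl = DataLoader(ds, batch_size=32)   # logical batch size

privacy_engine = PrivacyEngine()
model, optimizer, train_dl = privacy_engine.make_private(
    module=model, optimizer=optimizer, data_loader=train_dl,
    noise_multiplier=1.0, max_grad_norm=1.0,  # illustrative DP-SGD values
)

with BatchMemoryManager(
    data_loader=train_dl, max_physical_batch_size=8, optimizer=optimizer
) as memory_safe_data_loader:
    for x, y in memory_safe_data_loader:   # iterate the wrapped loader, not train_dl
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()                   # a real step happens only at logical batch boundaries

Here the loop holds at most 8 samples in memory while the optimizer behaves as if the batch size were 32; scaled up, that is what keeps the 2000-sample cap inside VRAM without changing the effective batch size.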
@@ -195,6 +207,7 @@ def main():
     DEVICE = torch.device('cpu')
 
     hyperparams = {
+        "target_points": 100,
         "wrn_depth": 16,
         "wrn_width": 1,
         "epsilon": args.epsilon,