O1: fix inout dataloader

This commit is contained in:
Akemi Izuko 2024-12-03 16:53:33 -07:00
parent e9af7cacf1
commit 36deb4613b
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -6,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch.optim.lr_scheduler import MultiStepLR from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Subset from torch.utils.data import DataLoader, Subset, TensorDataset
import torch.nn.functional as F import torch.nn.functional as F
from pathlib import Path from pathlib import Path
from torchvision import transforms from torchvision import transforms
@ -50,15 +50,17 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out
p = np.random.permutation(len(train_ds)) p = np.random.permutation(len(train_ds))
x_in = train_ds.data[S[p]] # This is the training set x_in = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
x_in = x_in[S[p]]
y_in = np.array(train_ds.targets).astype(np.int64)
y_in = y_in[S[p]]
x_m = np.full(len(train_ds), False) x_m = np.full(len(train_ds), False)
x_m[:m] = True x_m[:m] = True
x_m = x_m[p] # These are the points being guessed at x_m = x_m[p] # These are the points being guessed at
#print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}") td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4) test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
return train_dl, test_dl, x_in, x_m, S[p] return train_dl, test_dl, x_in, x_m, S[p]
@ -84,7 +86,7 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
scheduler.step() scheduler.step()
if epoch % 20 == 0 or epoch == hp['epochs'] - 1: if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
with torch.no_grad(): with torch.no_grad():
correct = 0 correct = 0
total = 0 total = 0
@ -136,14 +138,11 @@ def train(hp):
) )
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size']) train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}") print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
print(f"Got x_in: {x_in.shape}") print(f"Got x_in: {x_in.shape}")
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}") print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}") print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
print(f"Got train dataloader: {len(train_dl)}") print(f"Got train dataloader: {len(train_dl)}")
exit(1)
print(f"Training with {hp['epochs']} epochs") print(f"Training with {hp['epochs']} epochs")
if hp['epsilon'] is not None: if hp['epsilon'] is not None:
@ -197,6 +196,7 @@ def main():
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True) parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
parser.add_argument('--cuda', type=int, help='gpu index', required=False) parser.add_argument('--cuda', type=int, help='gpu index', required=False)
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None) parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
parser.add_argument('--m', type=int, help='number of target points', required=True)
args = parser.parse_args() args = parser.parse_args()
if torch.cuda.is_available() and args.cuda: if torch.cuda.is_available() and args.cuda:
@ -207,7 +207,7 @@ def main():
DEVICE = torch.device('cpu') DEVICE = torch.device('cpu')
hyperparams = { hyperparams = {
"target_points": 100, "target_points": args.m,
"wrn_depth": 16, "wrn_depth": 16,
"wrn_width": 1, "wrn_width": 1,
"epsilon": args.epsilon, "epsilon": args.epsilon,