O1: fix inout dataloader
This commit is contained in:
parent
e9af7cacf1
commit
36deb4613b
1 changed files with 13 additions and 13 deletions
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torch.utils.data import DataLoader, Subset, TensorDataset
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
from torchvision import transforms
|
||||
|
@ -50,15 +50,17 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out
|
||||
p = np.random.permutation(len(train_ds))
|
||||
|
||||
x_in = train_ds.data[S[p]] # This is the training set
|
||||
x_in = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
|
||||
x_in = x_in[S[p]]
|
||||
y_in = np.array(train_ds.targets).astype(np.int64)
|
||||
y_in = y_in[S[p]]
|
||||
|
||||
x_m = np.full(len(train_ds), False)
|
||||
x_m[:m] = True
|
||||
x_m = x_m[p] # These are the points being guessed at
|
||||
|
||||
#print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
|
||||
|
||||
train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
|
||||
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||
|
||||
return train_dl, test_dl, x_in, x_m, S[p]
|
||||
|
@ -84,7 +86,7 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
|
|||
|
||||
scheduler.step()
|
||||
|
||||
if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
|
||||
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
|
||||
with torch.no_grad():
|
||||
correct = 0
|
||||
total = 0
|
||||
|
@ -136,14 +138,11 @@ def train(hp):
|
|||
)
|
||||
|
||||
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
|
||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
||||
print(f"Got x_in: {x_in.shape}")
|
||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}")
|
||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
|
||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
||||
print(f"Got train dataloader: {len(train_dl)}")
|
||||
|
||||
exit(1)
|
||||
|
||||
print(f"Training with {hp['epochs']} epochs")
|
||||
|
||||
if hp['epsilon'] is not None:
|
||||
|
@ -197,6 +196,7 @@ def main():
|
|||
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
|
||||
parser.add_argument('--cuda', type=int, help='gpu index', required=False)
|
||||
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
|
||||
parser.add_argument('--m', type=int, help='number of target points', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available() and args.cuda:
|
||||
|
@ -207,7 +207,7 @@ def main():
|
|||
DEVICE = torch.device('cpu')
|
||||
|
||||
hyperparams = {
|
||||
"target_points": 100,
|
||||
"target_points": args.m,
|
||||
"wrn_depth": 16,
|
||||
"wrn_width": 1,
|
||||
"epsilon": args.epsilon,
|
||||
|
|
Loading…
Reference in a new issue