O1: fix inout dataloader
This commit is contained in:
parent
e9af7cacf1
commit
36deb4613b
1 changed files with 13 additions and 13 deletions
|
@ -6,7 +6,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.optim.lr_scheduler import MultiStepLR
|
from torch.optim.lr_scheduler import MultiStepLR
|
||||||
from torch.utils.data import DataLoader, Subset
|
from torch.utils.data import DataLoader, Subset, TensorDataset
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
@ -50,15 +50,17 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||||
S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out
|
S[:m] = np.random.choice([True, False], size=m) # Vector of determining if each point is in or out
|
||||||
p = np.random.permutation(len(train_ds))
|
p = np.random.permutation(len(train_ds))
|
||||||
|
|
||||||
x_in = train_ds.data[S[p]] # This is the training set
|
x_in = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
|
||||||
|
x_in = x_in[S[p]]
|
||||||
|
y_in = np.array(train_ds.targets).astype(np.int64)
|
||||||
|
y_in = y_in[S[p]]
|
||||||
|
|
||||||
x_m = np.full(len(train_ds), False)
|
x_m = np.full(len(train_ds), False)
|
||||||
x_m[:m] = True
|
x_m[:m] = True
|
||||||
x_m = x_m[p] # These are the points being guessed at
|
x_m = x_m[p] # These are the points being guessed at
|
||||||
|
|
||||||
#print(f"Number of keep: {np.sum(S)}:{np.sum(~S)} for m={np.sum(x_m)}")
|
td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
|
||||||
|
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
train_dl = DataLoader(x_in, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
|
||||||
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
return train_dl, test_dl, x_in, x_m, S[p]
|
return train_dl, test_dl, x_in, x_m, S[p]
|
||||||
|
@ -84,7 +86,7 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
if epoch % 20 == 0 or epoch == hp['epochs'] - 1:
|
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
@ -136,14 +138,11 @@ def train(hp):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:8] = {S[:8]}")
|
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
||||||
print(f"Got x_in: {x_in.shape}")
|
print(f"Got x_in: {x_in.shape}")
|
||||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:8] = {x_m[:8]}")
|
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
||||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:8] = {S[x_m][:8]}")
|
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
||||||
print(f"Got train dataloader: {len(train_dl)}")
|
print(f"Got train dataloader: {len(train_dl)}")
|
||||||
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
print(f"Training with {hp['epochs']} epochs")
|
print(f"Training with {hp['epochs']} epochs")
|
||||||
|
|
||||||
if hp['epsilon'] is not None:
|
if hp['epsilon'] is not None:
|
||||||
|
@ -197,6 +196,7 @@ def main():
|
||||||
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
|
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
|
||||||
parser.add_argument('--cuda', type=int, help='gpu index', required=False)
|
parser.add_argument('--cuda', type=int, help='gpu index', required=False)
|
||||||
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
|
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
|
||||||
|
parser.add_argument('--m', type=int, help='number of target points', required=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if torch.cuda.is_available() and args.cuda:
|
if torch.cuda.is_available() and args.cuda:
|
||||||
|
@ -207,7 +207,7 @@ def main():
|
||||||
DEVICE = torch.device('cpu')
|
DEVICE = torch.device('cpu')
|
||||||
|
|
||||||
hyperparams = {
|
hyperparams = {
|
||||||
"target_points": 100,
|
"target_points": args.m,
|
||||||
"wrn_depth": 16,
|
"wrn_depth": 16,
|
||||||
"wrn_width": 1,
|
"wrn_width": 1,
|
||||||
"epsilon": args.epsilon,
|
"epsilon": args.epsilon,
|
||||||
|
|
Loading…
Reference in a new issue