O1: return target point labels

This commit is contained in:
parent d606245ad1
commit 4692502763

1 changed file with 47 additions and 12 deletions
@@ -47,24 +47,32 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
     test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
 
+    # Original dataset
+    x = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])  # Applies transforms
+
+    # Choose m points to randomly exclude at chance
     S = np.full(len(train_ds), True)
     S[:m] = np.random.choice([True, False], size=m)  # Vector determining if each point is in or out
     p = np.random.permutation(len(train_ds))
 
-    x_in = np.stack([train_ds[i][0].numpy() for i in range(len(train_ds))])  # Applies transforms
-    x_in = x_in[S[p]]
+    # Store the m points which could have been included/excluded
+    mask = np.full(len(train_ds), False)
+    mask[:m] = True
+    mask = mask[p]
+
+    x_m = x[mask]  # These are the points being guessed at
+    y_m = np.array(train_ds.targets)[mask].astype(np.int64)
+
+    # Remove excluded points from dataset
+    x_in = x[S[p]]
     y_in = np.array(train_ds.targets).astype(np.int64)
     y_in = y_in[S[p]]
 
-    x_m = np.full(len(train_ds), False)
-    x_m[:m] = True
-    x_m = x_m[p]  # These are the points being guessed at
-
     td = TensorDataset(torch.from_numpy(x_in), torch.from_numpy(y_in).long())
     train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
 
-    return train_dl, test_dl, x_in, x_m, S[p]
+    return train_dl, test_dl, x_in, x_m, y_m, S[p]
 
 
 def evaluate_on(model, dataloader):
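Note: the indexing in this hunk is easy to misread. The permutation p is applied to S and to the candidate mask, not to the images, so dataset point j is one of the m membership candidates iff p[j] < m, and it enters the training set iff S[p[j]] is true. A toy sanity check of that invariant (a minimal sketch with made-up sizes, not part of the commit):

import numpy as np

n, m = 10, 4                # hypothetical toy sizes
x = np.arange(n)            # stand-in for the stacked images
S = np.full(n, True)
S[:m] = np.random.choice([True, False], size=m)
p = np.random.permutation(n)

mask = np.full(n, False)
mask[:m] = True
mask = mask[p]              # marks the m candidate points

x_m = x[mask]               # the points being guessed at
x_in = x[S[p]]              # the points that actually enter training

in_out = S[p][mask]         # membership bit per candidate, aligned with x_m
for point, kept in zip(x_m, in_out):
    assert (point in x_in) == kept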
@@ -230,20 +238,47 @@ def main():
         hp['norm'],
     ))
 
-    train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
+    train_dl, test_dl, x_in, x_m, y_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
+    print(f"len train: {len(train_dl)}")
     print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
     print(f"Got x_in: {x_in.shape}")
-    print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
-    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
-    print(f"Got train dataloader: {len(train_dl)}")
+    print(f"Got x_m: {x_m.shape}")
+    print(f"Got y_m: {y_m.shape}")
+
+    for x, y in train_dl:
+        print(f"dl x shape: {x.shape}")
+        print(f"dl y shape: {y.shape}")
+        break
 
     model_init, model_trained = train(hp, train_dl, test_dl)
 
+    # torch.save(model_init.state_dict(), "data/init_model.pt")
+    # torch.save(model_trained.state_dict(), "data/trained_model.pt")
+
+    scores = list()
+    criterion = nn.CrossEntropyLoss()
+    with torch.no_grad():
+        model_init.eval()
+        x_m = torch.from_numpy(x_m).to(DEVICE)
+        y_m = torch.from_numpy(y_m).long().to(DEVICE)
+
+        for i in range(len(x_m)):
+            x_point = x_m[i].unsqueeze(0)
+            y_point = y_m[i].unsqueeze(0)
+
+            init_loss = criterion(model_init(x_point)[0], y_point)
+            trained_loss = criterion(model_trained(x_point)[0], y_point)
+
+            scores.append(init_loss - trained_loss)
+
+    print(len(scores))
+    print(scores[:10])
+
     correct, total = evaluate_on(model_init, train_dl)
     print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
     correct, total = evaluate_on(model_trained, test_dl)
     print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
 
-    torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt'))
-
 
 if __name__ == '__main__':
     main()
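With y_m available, main() now scores each candidate point by its loss drop (init_loss - trained_loss); a larger drop is evidence the point was trained on. One thing the new return signature does not yet expose is the per-candidate ground truth: the returned S[p] covers all 50000 points while scores has one entry per candidate, so aligning them would need the candidate mask from inside get_dataloaders (the old boolean x_m served that role). A self-contained sketch of turning such scores into membership guesses, on synthetic data (membership and the score distribution here are invented for illustration):

import numpy as np

rng = np.random.default_rng(0)
m = 1000
membership = rng.random(m) < 0.5                     # stand-in for S[p][mask]
# toy scores: included points get a slightly larger loss drop on average
scores = rng.normal(loc=membership.astype(float), scale=1.0)

guesses = scores > np.median(scores)                 # top half -> "was included"
print(f"guess accuracy over {m} candidates: {np.mean(guesses == membership):.3f}")

Thresholding at the median guesses exactly half the candidates as members, which matches the expected coin-flip split of S; a calibrated threshold over the score distribution would be the natural refinement.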