O1: save starting model
This commit is contained in:
parent
36deb4613b
commit
d606245ad1
1 changed files with 55 additions and 41 deletions
|
@ -2,6 +2,7 @@ import argparse
|
||||||
import equations
|
import equations
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import copy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
@ -66,12 +67,33 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||||
return train_dl, test_dl, x_in, x_m, S[p]
|
return train_dl, test_dl, x_in, x_m, S[p]
|
||||||
|
|
||||||
|
|
||||||
def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
|
def evaluate_on(model, dataloader):
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
for data in dataloader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(DEVICE)
|
||||||
|
labels = labels.to(DEVICE)
|
||||||
|
|
||||||
|
wrn_outputs = model(images)
|
||||||
|
outputs = wrn_outputs[0]
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
return correct, total
|
||||||
|
|
||||||
|
|
||||||
|
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
|
||||||
best_test_set_accuracy = 0
|
best_test_set_accuracy = 0
|
||||||
|
|
||||||
for epoch in range(hp['epochs']):
|
for epoch in range(hp['epochs']):
|
||||||
model.train()
|
model.train()
|
||||||
for i, data in enumerate(train_loader, 0):
|
for i, data in enumerate(train_dl, 0):
|
||||||
inputs, labels = data
|
inputs, labels = data
|
||||||
inputs = inputs.to(DEVICE)
|
inputs = inputs.to(DEVICE)
|
||||||
labels = labels.to(DEVICE)
|
labels = labels.to(DEVICE)
|
||||||
|
@ -87,30 +109,14 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
|
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
|
||||||
with torch.no_grad():
|
correct, total = evaluate_on(model, test_dl)
|
||||||
correct = 0
|
epoch_accuracy = round(100 * correct / total, 2)
|
||||||
total = 0
|
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
|
||||||
|
|
||||||
model.eval()
|
|
||||||
for data in test_loader:
|
|
||||||
images, labels = data
|
|
||||||
images = images.to(DEVICE)
|
|
||||||
labels = labels.to(DEVICE)
|
|
||||||
|
|
||||||
wrn_outputs = model(images)
|
|
||||||
outputs = wrn_outputs[0]
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
|
||||||
total += labels.size(0)
|
|
||||||
correct += (predicted == labels).sum().item()
|
|
||||||
|
|
||||||
epoch_accuracy = correct / total
|
|
||||||
epoch_accuracy = round(100 * epoch_accuracy, 2)
|
|
||||||
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
|
|
||||||
|
|
||||||
return best_test_set_accuracy
|
return best_test_set_accuracy
|
||||||
|
|
||||||
|
|
||||||
def train(hp):
|
def train(hp, train_dl, test_dl):
|
||||||
model = WideResNet(
|
model = WideResNet(
|
||||||
d=hp["wrn_depth"],
|
d=hp["wrn_depth"],
|
||||||
k=hp["wrn_width"],
|
k=hp["wrn_width"],
|
||||||
|
@ -123,6 +129,8 @@ def train(hp):
|
||||||
model = ModuleValidator.fix(model)
|
model = ModuleValidator.fix(model)
|
||||||
ModuleValidator.validate(model, strict=True)
|
ModuleValidator.validate(model, strict=True)
|
||||||
|
|
||||||
|
model_init = copy.deepcopy(model)
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(
|
optimizer = optim.SGD(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
|
@ -137,12 +145,6 @@ def train(hp):
|
||||||
gamma=0.2
|
gamma=0.2
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
|
||||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
|
||||||
print(f"Got x_in: {x_in.shape}")
|
|
||||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
|
||||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
|
||||||
print(f"Got train dataloader: {len(train_dl)}")
|
|
||||||
print(f"Training with {hp['epochs']} epochs")
|
print(f"Training with {hp['epochs']} epochs")
|
||||||
|
|
||||||
if hp['epsilon'] is not None:
|
if hp['epsilon'] is not None:
|
||||||
|
@ -186,7 +188,7 @@ def train(hp):
|
||||||
scheduler,
|
scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model_init, model
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -206,7 +208,7 @@ def main():
|
||||||
else:
|
else:
|
||||||
DEVICE = torch.device('cpu')
|
DEVICE = torch.device('cpu')
|
||||||
|
|
||||||
hyperparams = {
|
hp = {
|
||||||
"target_points": args.m,
|
"target_points": args.m,
|
||||||
"wrn_depth": 16,
|
"wrn_depth": 16,
|
||||||
"wrn_width": 1,
|
"wrn_width": 1,
|
||||||
|
@ -214,22 +216,34 @@ def main():
|
||||||
"delta": 1e-5,
|
"delta": 1e-5,
|
||||||
"norm": args.norm,
|
"norm": args.norm,
|
||||||
"batch_size": 4096,
|
"batch_size": 4096,
|
||||||
"epochs": 200,
|
"epochs": 20,
|
||||||
}
|
}
|
||||||
|
|
||||||
hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
|
hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
|
||||||
int(time.time()),
|
int(time.time()),
|
||||||
hyperparams['wrn_depth'],
|
hp['wrn_depth'],
|
||||||
hyperparams['wrn_width'],
|
hp['wrn_width'],
|
||||||
hyperparams['batch_size'],
|
hp['batch_size'],
|
||||||
hyperparams['epochs'],
|
hp['epochs'],
|
||||||
hyperparams['epsilon'],
|
hp['epsilon'],
|
||||||
hyperparams['delta'],
|
hp['delta'],
|
||||||
hyperparams['norm'],
|
hp['norm'],
|
||||||
))
|
))
|
||||||
|
|
||||||
model = train(hyperparams)
|
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||||
torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))
|
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
||||||
|
print(f"Got x_in: {x_in.shape}")
|
||||||
|
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
||||||
|
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
||||||
|
print(f"Got train dataloader: {len(train_dl)}")
|
||||||
|
model_init, model_trained = train(hp, train_dl, test_dl)
|
||||||
|
|
||||||
|
correct, total = evaluate_on(model_init, train_dl)
|
||||||
|
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
|
||||||
|
correct, total = evaluate_on(model_trained, test_dl)
|
||||||
|
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
|
||||||
|
|
||||||
|
torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt'))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue