O1: save starting model
This commit is contained in:
parent
36deb4613b
commit
d606245ad1
1 changed files with 55 additions and 41 deletions
|
@ -2,6 +2,7 @@ import argparse
|
|||
import equations
|
||||
import numpy as np
|
||||
import time
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
|
@ -66,12 +67,33 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
return train_dl, test_dl, x_in, x_m, S[p]
|
||||
|
||||
|
||||
def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
|
||||
def evaluate_on(model, dataloader):
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
|
||||
for data in dataloader:
|
||||
images, labels = data
|
||||
images = images.to(DEVICE)
|
||||
labels = labels.to(DEVICE)
|
||||
|
||||
wrn_outputs = model(images)
|
||||
outputs = wrn_outputs[0]
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
return correct, total
|
||||
|
||||
|
||||
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
|
||||
best_test_set_accuracy = 0
|
||||
|
||||
for epoch in range(hp['epochs']):
|
||||
model.train()
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
for i, data in enumerate(train_dl, 0):
|
||||
inputs, labels = data
|
||||
inputs = inputs.to(DEVICE)
|
||||
labels = labels.to(DEVICE)
|
||||
|
@ -87,30 +109,14 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
|
|||
scheduler.step()
|
||||
|
||||
if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
|
||||
with torch.no_grad():
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
model.eval()
|
||||
for data in test_loader:
|
||||
images, labels = data
|
||||
images = images.to(DEVICE)
|
||||
labels = labels.to(DEVICE)
|
||||
|
||||
wrn_outputs = model(images)
|
||||
outputs = wrn_outputs[0]
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
epoch_accuracy = correct / total
|
||||
epoch_accuracy = round(100 * epoch_accuracy, 2)
|
||||
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
|
||||
correct, total = evaluate_on(model, test_dl)
|
||||
epoch_accuracy = round(100 * correct / total, 2)
|
||||
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
|
||||
|
||||
return best_test_set_accuracy
|
||||
|
||||
|
||||
def train(hp):
|
||||
def train(hp, train_dl, test_dl):
|
||||
model = WideResNet(
|
||||
d=hp["wrn_depth"],
|
||||
k=hp["wrn_width"],
|
||||
|
@ -123,6 +129,8 @@ def train(hp):
|
|||
model = ModuleValidator.fix(model)
|
||||
ModuleValidator.validate(model, strict=True)
|
||||
|
||||
model_init = copy.deepcopy(model)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
|
@ -137,12 +145,6 @@ def train(hp):
|
|||
gamma=0.2
|
||||
)
|
||||
|
||||
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
||||
print(f"Got x_in: {x_in.shape}")
|
||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
||||
print(f"Got train dataloader: {len(train_dl)}")
|
||||
print(f"Training with {hp['epochs']} epochs")
|
||||
|
||||
if hp['epsilon'] is not None:
|
||||
|
@ -186,7 +188,7 @@ def train(hp):
|
|||
scheduler,
|
||||
)
|
||||
|
||||
return model
|
||||
return model_init, model
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -206,7 +208,7 @@ def main():
|
|||
else:
|
||||
DEVICE = torch.device('cpu')
|
||||
|
||||
hyperparams = {
|
||||
hp = {
|
||||
"target_points": args.m,
|
||||
"wrn_depth": 16,
|
||||
"wrn_width": 1,
|
||||
|
@ -214,22 +216,34 @@ def main():
|
|||
"delta": 1e-5,
|
||||
"norm": args.norm,
|
||||
"batch_size": 4096,
|
||||
"epochs": 200,
|
||||
"epochs": 20,
|
||||
}
|
||||
|
||||
hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
|
||||
hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
|
||||
int(time.time()),
|
||||
hyperparams['wrn_depth'],
|
||||
hyperparams['wrn_width'],
|
||||
hyperparams['batch_size'],
|
||||
hyperparams['epochs'],
|
||||
hyperparams['epsilon'],
|
||||
hyperparams['delta'],
|
||||
hyperparams['norm'],
|
||||
hp['wrn_depth'],
|
||||
hp['wrn_width'],
|
||||
hp['batch_size'],
|
||||
hp['epochs'],
|
||||
hp['epsilon'],
|
||||
hp['delta'],
|
||||
hp['norm'],
|
||||
))
|
||||
|
||||
model = train(hyperparams)
|
||||
torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))
|
||||
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
|
||||
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
|
||||
print(f"Got x_in: {x_in.shape}")
|
||||
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
|
||||
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
|
||||
print(f"Got train dataloader: {len(train_dl)}")
|
||||
model_init, model_trained = train(hp, train_dl, test_dl)
|
||||
|
||||
correct, total = evaluate_on(model_init, train_dl)
|
||||
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
|
||||
correct, total = evaluate_on(model_trained, test_dl)
|
||||
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
|
||||
|
||||
torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in a new issue