O1: save starting model

This commit is contained in:
Akemi Izuko 2024-12-03 19:43:05 -07:00
parent 36deb4613b
commit d606245ad1
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -2,6 +2,7 @@ import argparse
import equations import equations
import numpy as np import numpy as np
import time import time
import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
@ -66,12 +67,33 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
return train_dl, test_dl, x_in, x_m, S[p] return train_dl, test_dl, x_in, x_m, S[p]
def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler): def evaluate_on(model, dataloader):
correct = 0
total = 0
with torch.no_grad():
model.eval()
for data in dataloader:
images, labels = data
images = images.to(DEVICE)
labels = labels.to(DEVICE)
wrn_outputs = model(images)
outputs = wrn_outputs[0]
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return correct, total
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
best_test_set_accuracy = 0 best_test_set_accuracy = 0
for epoch in range(hp['epochs']): for epoch in range(hp['epochs']):
model.train() model.train()
for i, data in enumerate(train_loader, 0): for i, data in enumerate(train_dl, 0):
inputs, labels = data inputs, labels = data
inputs = inputs.to(DEVICE) inputs = inputs.to(DEVICE)
labels = labels.to(DEVICE) labels = labels.to(DEVICE)
@ -87,30 +109,14 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, sch
scheduler.step() scheduler.step()
if epoch % 10 == 0 or epoch == hp['epochs'] - 1: if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
with torch.no_grad(): correct, total = evaluate_on(model, test_dl)
correct = 0 epoch_accuracy = round(100 * correct / total, 2)
total = 0
model.eval()
for data in test_loader:
images, labels = data
images = images.to(DEVICE)
labels = labels.to(DEVICE)
wrn_outputs = model(images)
outputs = wrn_outputs[0]
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
epoch_accuracy = correct / total
epoch_accuracy = round(100 * epoch_accuracy, 2)
print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%") print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
return best_test_set_accuracy return best_test_set_accuracy
def train(hp): def train(hp, train_dl, test_dl):
model = WideResNet( model = WideResNet(
d=hp["wrn_depth"], d=hp["wrn_depth"],
k=hp["wrn_width"], k=hp["wrn_width"],
@ -123,6 +129,8 @@ def train(hp):
model = ModuleValidator.fix(model) model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=True) ModuleValidator.validate(model, strict=True)
model_init = copy.deepcopy(model)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD( optimizer = optim.SGD(
model.parameters(), model.parameters(),
@ -137,12 +145,6 @@ def train(hp):
gamma=0.2 gamma=0.2
) )
train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
print(f"Got x_in: {x_in.shape}")
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
print(f"Got train dataloader: {len(train_dl)}")
print(f"Training with {hp['epochs']} epochs") print(f"Training with {hp['epochs']} epochs")
if hp['epsilon'] is not None: if hp['epsilon'] is not None:
@ -186,7 +188,7 @@ def train(hp):
scheduler, scheduler,
) )
return model return model_init, model
def main(): def main():
@ -206,7 +208,7 @@ def main():
else: else:
DEVICE = torch.device('cpu') DEVICE = torch.device('cpu')
hyperparams = { hp = {
"target_points": args.m, "target_points": args.m,
"wrn_depth": 16, "wrn_depth": 16,
"wrn_width": 1, "wrn_width": 1,
@ -214,22 +216,34 @@ def main():
"delta": 1e-5, "delta": 1e-5,
"norm": args.norm, "norm": args.norm,
"batch_size": 4096, "batch_size": 4096,
"epochs": 200, "epochs": 20,
} }
hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format( hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
int(time.time()), int(time.time()),
hyperparams['wrn_depth'], hp['wrn_depth'],
hyperparams['wrn_width'], hp['wrn_width'],
hyperparams['batch_size'], hp['batch_size'],
hyperparams['epochs'], hp['epochs'],
hyperparams['epsilon'], hp['epsilon'],
hyperparams['delta'], hp['delta'],
hyperparams['norm'], hp['norm'],
)) ))
model = train(hyperparams) train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt')) print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
print(f"Got x_in: {x_in.shape}")
print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
print(f"Got train dataloader: {len(train_dl)}")
model_init, model_trained = train(hp, train_dl, test_dl)
correct, total = evaluate_on(model_init, train_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
correct, total = evaluate_on(model_trained, test_dl)
print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt'))
if __name__ == '__main__': if __name__ == '__main__':
main() main()