O1: save starting model

Author: Akemi Izuko
Date: 2024-12-03 19:43:05 -07:00
Parent: 36deb4613b
Commit: d606245ad1
Signed by: akemi (GPG key ID: 8DE0764E1809E9FC)


@@ -2,6 +2,7 @@ import argparse
 import equations
 import numpy as np
 import time
+import copy
 import torch
 import torch.nn as nn
 from torch import optim
@@ -66,12 +67,33 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     return train_dl, test_dl, x_in, x_m, S[p]
 
 
-def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
+def evaluate_on(model, dataloader):
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        model.eval()
+
+        for data in dataloader:
+            images, labels = data
+            images = images.to(DEVICE)
+            labels = labels.to(DEVICE)
+
+            wrn_outputs = model(images)
+            outputs = wrn_outputs[0]
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    return correct, total
+
+
+def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
     best_test_set_accuracy = 0
 
     for epoch in range(hp['epochs']):
         model.train()
-        for i, data in enumerate(train_loader, 0):
+        for i, data in enumerate(train_dl, 0):
             inputs, labels = data
             inputs = inputs.to(DEVICE)
             labels = labels.to(DEVICE)
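The new evaluate_on helper indexes the model output with [0] because this WideResNet returns a sequence of outputs with the class logits first. As a point of reference only, here is a minimal self-contained sketch of the same pattern on a hypothetical two-headed toy model; none of the names below come from this repository:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    DEVICE = torch.device('cpu')

    class ToyTwoHead(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 3)

        def forward(self, x):
            logits = self.fc(x)
            return logits, x  # logits sit at index 0, like the WideResNet outputs here

    def evaluate_on(model, dataloader):
        correct, total = 0, 0
        with torch.no_grad():
            model.eval()
            for images, labels in dataloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)[0]            # take the logits head
                _, predicted = torch.max(outputs, 1)  # predicted class per sample
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct, total

    dl = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,))), batch_size=16)
    correct, total = evaluate_on(ToyTwoHead().to(DEVICE), dl)
    print(f"toy accuracy: {correct}/{total}")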
@@ -87,30 +109,14 @@ def train_no_cap(model, hp, train_loader, test_loader, optimizer, criterion, scheduler):
         scheduler.step()
 
         if epoch % 10 == 0 or epoch == hp['epochs'] - 1:
-            with torch.no_grad():
-                correct = 0
-                total = 0
-
-                model.eval()
-
-                for data in test_loader:
-                    images, labels = data
-                    images = images.to(DEVICE)
-                    labels = labels.to(DEVICE)
-                    wrn_outputs = model(images)
-                    outputs = wrn_outputs[0]
-                    _, predicted = torch.max(outputs.data, 1)
-                    total += labels.size(0)
-                    correct += (predicted == labels).sum().item()
-
-            epoch_accuracy = correct / total
-            epoch_accuracy = round(100 * epoch_accuracy, 2)
-            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
+            correct, total = evaluate_on(model, test_dl)
+            epoch_accuracy = round(100 * correct / total, 2)
+            print(f"Epoch {epoch+1}/{hp['epochs']}: {epoch_accuracy}%")
 
     return best_test_set_accuracy
 
 
-def train(hp):
+def train(hp, train_dl, test_dl):
     model = WideResNet(
         d=hp["wrn_depth"],
         k=hp["wrn_width"],
@@ -123,6 +129,8 @@ def train(hp):
     model = ModuleValidator.fix(model)
     ModuleValidator.validate(model, strict=True)
 
+    model_init = copy.deepcopy(model)
+
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(
         model.parameters(),
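The copy.deepcopy snapshot above is taken before the optimizer ever steps, so model_init keeps the untouched starting weights while model continues to be trained in place. A small illustrative sketch of that behaviour on a throwaway linear layer (toy code, not from this repository):

    import copy
    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    net_init = copy.deepcopy(net)  # independent snapshot of the starting weights

    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()  # net's weights move; net_init's stay fixed

    print(torch.equal(net.weight, net_init.weight))  # expected: False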
@@ -137,12 +145,6 @@ def train(hp):
         gamma=0.2
     )
 
-    train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
-    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
-    print(f"Got x_in: {x_in.shape}")
-    print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
-    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
-    print(f"Got train dataloader: {len(train_dl)}")
     print(f"Training with {hp['epochs']} epochs")
 
     if hp['epsilon'] is not None:
@@ -186,7 +188,7 @@ def train(hp):
         scheduler,
     )
 
-    return model
+    return model_init, model
 
 
 def main():
@@ -206,7 +208,7 @@ def main():
     else:
         DEVICE = torch.device('cpu')
 
-    hyperparams = {
+    hp = {
         "target_points": args.m,
         "wrn_depth": 16,
         "wrn_width": 1,
@@ -214,22 +216,34 @@ def main():
         "delta": 1e-5,
         "norm": args.norm,
         "batch_size": 4096,
-        "epochs": 200,
+        "epochs": 20,
     }
 
-    hyperparams['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
+    hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
         int(time.time()),
-        hyperparams['wrn_depth'],
-        hyperparams['wrn_width'],
-        hyperparams['batch_size'],
-        hyperparams['epochs'],
-        hyperparams['epsilon'],
-        hyperparams['delta'],
-        hyperparams['norm'],
+        hp['wrn_depth'],
+        hp['wrn_width'],
+        hp['batch_size'],
+        hp['epochs'],
+        hp['epsilon'],
+        hp['delta'],
+        hp['norm'],
     ))
 
-    model = train(hyperparams)
-    torch.save(model.state_dict(), hyperparams['logfile'].with_suffix('.pt'))
+    train_dl, test_dl, x_in, x_m, S = get_dataloaders(hp['target_points'], hp['batch_size'])
+    print(f"Got vector S: {S.shape}, sum={np.sum(S)}, S[:{hp['target_points']}] = {S[:8]}")
+    print(f"Got x_in: {x_in.shape}")
+    print(f"Got x_m: {x_m.shape}, sum={np.sum(S)}, x_m[:{hp['target_points']}] = {x_m[:8]}")
+    print(f"S @ x_m: sum={np.sum(S[x_m])}, S[x_m][:{hp['target_points']}] = {S[x_m][:8]}")
+    print(f"Got train dataloader: {len(train_dl)}")
+
+    model_init, model_trained = train(hp, train_dl, test_dl)
+
+    correct, total = evaluate_on(model_init, train_dl)
+    print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
+    correct, total = evaluate_on(model_trained, test_dl)
+    print(f"Done model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
+    torch.save(model_trained.state_dict(), hp['logfile'].with_suffix('.pt'))
 
 if __name__ == '__main__':
     main()
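main() now persists only the trained weights via state_dict; that checkpoint can later be restored into a freshly built network with matching parameter names. Below is a generic round-trip sketch of the same torch.save / load_state_dict pattern on a toy model. For the real WideResNet here, the same constructor arguments and the ModuleValidator.fix pass from train() would presumably be needed first so the state-dict keys line up; the toy modules and file name below are assumptions for illustration only.

    import torch
    import torch.nn as nn
    from pathlib import Path

    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
    ckpt = Path('toy_model.pt')
    torch.save(net.state_dict(), ckpt)  # same pattern as the torch.save in main()

    reloaded = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
    reloaded.load_state_dict(torch.load(ckpt))  # keys must match the architecture
    reloaded.eval()

    x = torch.randn(4, 8)
    print(torch.allclose(net(x), reloaded(x)))  # expected: True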