dawn-bench-models/pytorch/CIFAR10/benchmark/train.py

337 lines
11 KiB
Python
Raw Normal View History

2017-08-17 12:43:17 -06:00
import os
import re
import json
from functools import reduce
from datetime import datetime
from collections import OrderedDict
import click
import torch
import progressbar
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets as dset
from benchmark.models import resnet, densenet
# Per-channel mean and standard deviation handed to ``transforms.Normalize``
# for both the training and the test pipeline.
# NOTE(review): presumably the CIFAR-10 training-set RGB statistics -- confirm.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

# Registry mapping the ``--model`` command-line choice to a callable that
# builds the network. Every value is invoked with no arguments
# (``MODELS[name]()``), which is why constructors that need arguments are
# wrapped in a lambda.
MODELS = {
    # "Deep Residual Learning for Image Recognition"
    'resnet20': resnet.ResNet20,
    'resnet32': resnet.ResNet32,
    'resnet44': resnet.ResNet44,
    'resnet56': resnet.ResNet56,
    'resnet110': resnet.ResNet110,
    'resnet1202': resnet.ResNet1202,

    # "Wide Residual Networks"
    'wrn-40-4': resnet.WRN_40_4,
    'wrn-16-8': resnet.WRN_16_8,
    'wrn-28-10': resnet.WRN_28_10,

    # Based on "Identity Mappings in Deep Residual Networks"
    'preact20': resnet.PreActResNet20,
    'preact56': resnet.PreActResNet56,
    'preact164-basic': resnet.PreActResNet164Basic,

    # "Identity Mappings in Deep Residual Networks"
    'preact110': resnet.PreActResNet110,
    'preact164': resnet.PreActResNet164,
    'preact1001': resnet.PreActResNet1001,

    # "Aggregated Residual Transformations for Deep Neural Networks"
    # The two positional arguments are fixed here; the optional ``_``
    # parameter is ignored by the lambda.
    'resnext29-8-64': lambda _=None: resnet.ResNeXt29(8, 64),
    'resnext29-16-64': lambda _=None: resnet.ResNeXt29(16, 64),

    # "Densely Connected Convolutional Networks"
    'densenetbc100': densenet.DenseNetBC100,
    'densenetbc250': densenet.DenseNetBC250,
    'densenetbc190': densenet.DenseNetBC190,

    # Kuangliu/pytorch-cifar
    'resnet18': resnet.ResNet18,
    'resnet50': resnet.ResNet50,
    'resnet101': resnet.ResNet101,
    'resnet152': resnet.ResNet152,
}
def count_parameters(model):
    """Return the total number of scalar values across a model's parameters.

    Args:
        model: Object exposing ``parameters()``, which yields items whose
            ``size()`` returns an iterable of dimension lengths.

    Returns:
        int: Sum over all parameters of the product of their dimensions.
        A parameter with an empty ``size()`` (zero-dimensional) counts as
        one value, and a model without parameters yields ``0``.
    """
    total = 0
    for parameter in model.parameters():
        # Start the product at 1 so that an empty size (a scalar parameter)
        # contributes a single element instead of raising.
        elements = 1
        for dimension in parameter.size():
            elements *= dimension
        total += elements
    return total
def correct(outputs, targets, top=(1, )):
    """Count the samples whose target is among the k highest-scoring classes.

    Args:
        outputs: Score tensor with samples along dim 0 and classes along
            dim 1 (``topk`` is taken over dim 1).
        targets: Tensor holding one class index per sample.
        top: Sequence of ``k`` values to evaluate.

    Returns:
        list[int]: For each ``k`` in ``top`` (same order), the number of
        samples whose target appears in the top-``k`` predictions.
    """
    _, predictions = outputs.topk(max(top), dim=1, largest=True, sorted=True)
    targets = targets.view(-1, 1).expand_as(predictions)
    # Each row holds at most one match. Widen the comparison mask to int64
    # before accumulating so the counts cannot wrap in a narrow mask dtype.
    hits = predictions.eq(targets).cpu().long()
    # After the cumulative sum, column k-1 is 1 when the target is within
    # the first k predictions; summing over samples gives the top-k count.
    corrects = hits.cumsum(1).sum(0)
    # Flatten so the lookup works whether or not ``sum(0)`` kept the reduced
    # dimension, and convert to plain ints for arithmetic and CSV output.
    per_k = corrects.data.view(-1)
    return [int(per_k[k - 1]) for k in top]
def save_result(result, path):
    """Append ``result`` as one comma-separated row to the file at ``path``.

    When ``path`` does not exist before the call, a header row made of the
    keys of ``result`` is written first. Keys and values are converted with
    ``str`` and joined with commas; no quoting or escaping is applied.

    Args:
        result: Mapping of column name to value (iterated via ``items()``).
        path: Destination file, opened in append mode.
    """
    # Decide about the header before opening: append mode creates the file.
    needs_header = not os.path.exists(path)
    with open(path, mode='a') as out:
        columns = list(result.items())
        if needs_header:
            header = ",".join(str(name) for name, _ in columns)
            out.write(header + '\n')
        row = ",".join(str(value) for _, value in columns)
        out.write(row + '\n')
def run(epoch, model, loader, criterion=None, optimizer=None, top=(1, 5),
        use_cuda=False, tracking=None, max_value=None, train=True):
    """Make one pass over ``loader``, training or evaluating ``model``.

    Args:
        epoch: Epoch number, only recorded in the tracking rows.
        model: Network; switched with ``train()``/``eval()`` and called on
            each batch of inputs.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function; required when ``train`` is true.
        optimizer: Optimizer stepped once per batch; required when
            ``train`` is true.
        top: Sequence of ``k`` values for which top-k accuracy is tracked.
        use_cuda: Move every batch to the GPU with ``.cuda()``.
        tracking: Path of a CSV file receiving one row per batch, or
            ``None`` to disable tracking.
        max_value: Expected number of batches for the progress bar; a falsy
            value means the length is unknown.
        train: Train the model when true, only evaluate it otherwise.

    Returns:
        tuple: ``(accuracy, batches)`` where ``accuracy`` is the fraction of
        correctly classified samples for ``top[0]`` and ``batches`` is the
        number of batches processed. An empty loader gives ``(0.0, 0)``.

    Raises:
        AssertionError: If training is requested without a criterion or
            optimizer.
    """
    assert criterion is not None or not train, 'Need criterion to train model'
    assert optimizer is not None or not train, 'Need optimizer to train model'
    max_value = max_value or progressbar.UnknownLength
    bar = progressbar.ProgressBar(max_value=max_value)
    total = 0
    batches = 0
    # Pre-seed every k so the summary below also works for an empty loader.
    correct_counts = {k: 0 for k in top}
    if train:
        model.train()
    else:
        model.eval()

    start = datetime.now()
    for batch_index, (inputs, targets) in enumerate(loader):
        # NOTE(review): ``volatile`` is the legacy switch for skipping
        # autograd bookkeeping during evaluation; newer PyTorch releases
        # ignore it with a warning in favour of ``torch.no_grad()``.
        inputs = Variable(inputs, requires_grad=False, volatile=not train)
        targets = Variable(targets, requires_grad=False, volatile=not train)

        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        outputs = model(inputs)

        if train:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = targets.size(0)
        top_correct = correct(outputs, targets, top=top)
        total += batch_size
        for k, count in zip(top, top_correct):
            correct_counts[k] += count

        end = datetime.now()
        if tracking is not None:
            result = OrderedDict()
            result['timestamp'] = datetime.now()
            result['batch_duration'] = end - start
            result['epoch'] = epoch
            result['batch'] = batch_index
            result['batch_size'] = batch_size
            for i, k in enumerate(top):
                result['top{}_correct'.format(k)] = top_correct[i]
            if train:
                # Flatten first so the scalar can be read whether the loss
                # is a one-element or a zero-dimensional tensor.
                result['loss'] = float(loss.data.view(-1)[0])
            save_result(result, tracking)

        batches = batch_index + 1
        bar.update(batches)
        # Restart the timer here so the next row's duration includes the
        # time spent fetching the next batch from the loader.
        start = datetime.now()

    print()
    message = 'Training accuracy of' if train else 'Test accuracy of'
    for k in top:
        accuracy = correct_counts[k] / total if total else 0.0
        message += ' top-{}: {}'.format(k, accuracy)
    print(message)

    top_accuracy = (1. * correct_counts[top[0]]) / total if total else 0.0
    return top_accuracy, batches
def save(model, directory, epoch, accuracy, use_cuda=False, filename=None):
    """Write a checkpoint holding the model, its epoch and its accuracy.

    Args:
        model: Network to store. When ``use_cuda`` is true its ``module``
            attribute is stored instead of the object itself.
        directory: Directory the checkpoint file is written into.
        epoch: Epoch number stored in the checkpoint; also used for the
            default file name.
        accuracy: Accuracy value stored in the checkpoint.
        use_cuda: Whether ``model`` is a wrapper exposing the real network
            as ``model.module``.
        filename: File name inside ``directory``; a falsy value selects
            ``checkpoint_<epoch>.t7``.
    """
    if use_cuda:
        network = model.module
    else:
        network = model

    if not filename:
        filename = 'checkpoint_{}.t7'.format(epoch)
    destination = os.path.join(directory, filename)

    checkpoint = dict(model=network, epoch=epoch, accuracy=accuracy)
    torch.save(checkpoint, destination)
def save_config(config, run_dir):
    """Dump ``config`` as JSON into ``run_dir``.

    The file is named ``config_<timestamp>.json`` after
    ``config['timestamp']`` and ends with a trailing newline.

    Args:
        config: JSON-serialisable mapping that contains a ``'timestamp'`` key.
        run_dir: Existing directory the file is created in.
    """
    config_name = "config_{}.json".format(config['timestamp'])
    destination = os.path.join(run_dir, config_name)
    with open(destination, 'w') as handle:
        json.dump(config, handle)
        handle.write('\n')
def load(path):
    """Read a checkpoint and return its model, epoch and accuracy.

    Args:
        path: Location of a checkpoint file holding the keys ``'model'``,
            ``'epoch'`` and ``'accuracy'``.

    Returns:
        tuple: ``(model, epoch, accuracy)`` as stored in the checkpoint.

    Raises:
        AssertionError: If ``path`` does not exist.
    """
    assert os.path.exists(path)
    # NOTE: ``torch.load`` unpickles the file, so only checkpoints from a
    # trusted source should be passed in.
    checkpoint = torch.load(path)
    return checkpoint['model'], checkpoint['epoch'], checkpoint['accuracy']
def latest_file(model):
    """Return the highest-numbered checkpoint of the most recent run.

    Run directories are looked up under ``./run/<model>`` and the entry that
    sorts last by name is taken as the most recent one. Inside it, only
    files named exactly ``checkpoint_<number>.t7`` are considered.

    Args:
        model: Model name used as the sub-directory of ``./run``.

    Returns:
        str: Path to the checkpoint file with the largest number.

    Raises:
        AssertionError: If there is no run directory for ``model`` or the
            latest run contains no numbered checkpoint.
    """
    restore = f'./run/{model}'
    timestamps = sorted(os.listdir(restore))
    assert len(timestamps) > 0, f'no runs found in {restore}'
    run_dir = os.path.join(restore, timestamps[-1])

    # Match the whole name with an escaped dot so that files such as
    # ``checkpoint_3.t7.bak`` or ``checkpoint_3xt7`` are not mistaken for
    # checkpoints.
    pattern = re.compile(r'checkpoint_(\d+)\.t7')
    max_checkpoint = -1
    max_checkpoint_file = None
    for filename in os.listdir(run_dir):
        match = pattern.fullmatch(filename)
        if match is None:
            continue
        num = int(match.group(1))
        if num > max_checkpoint:
            max_checkpoint = num
            max_checkpoint_file = filename

    assert max_checkpoint != -1, f'no numbered checkpoint in {run_dir}'
    return os.path.join(run_dir, max_checkpoint_file)
@click.command()
@click.option('--dataset-dir', default='./data/cifar10')
@click.option('--checkpoint', '-c', type=click.Choice(['best', 'all', 'last']),
              default='last')
@click.option('--restore', '-r')
@click.option('--tracking/--no-tracking', default=True)
@click.option('--cuda/--no-cuda', default=True)
@click.option('--epochs', '-e', default=200)
@click.option('--batch-size', '-b', default=32)
@click.option('--learning-rate', '-l', default=1e-3)
@click.option('--sgd', 'optimizer', flag_value='sgd')
@click.option('--adam', 'optimizer', flag_value='adam', default=True)
@click.option('--augmentation/--no-augmentation', default=True)
@click.option('--num-workers', type=int)
@click.option('--weight-decay', default=5e-4)
@click.option('--model', '-m', type=click.Choice(list(MODELS)),
              default='resnet20')
def main(dataset_dir, checkpoint, restore, tracking, cuda, epochs,
         batch_size, learning_rate, optimizer, augmentation, num_workers,
         weight_decay, model):
    """Train a model on CIFAR-10 and evaluate it after every epoch.

    Results, configuration and checkpoints are written to a run directory
    under ./run/<model>/<timestamp>, or next to the restored checkpoint
    when --restore is given ('latest' picks the newest run of the model).
    """
    # NOTE(review): ``utcnow()`` is naive, so ``timestamp()`` interprets it
    # as local time; the value is only used as a run identifier here.
    timestamp = "{:.0f}".format(datetime.utcnow().timestamp())
    # Snapshot the command-line arguments (plus ``timestamp``) now, before
    # ``model``, ``optimizer`` and friends are rebound to objects below.
    config = dict(locals())

    use_cuda = cuda and torch.cuda.is_available()
    if use_cuda:
        num_workers = num_workers or torch.cuda.device_count()
    else:
        num_workers = num_workers or 1
    print(f"using {num_workers} workers for data loading")

    print("Preparing data:")
    if augmentation:
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
    else:
        transform_train = []
    transform_train = transforms.Compose(transform_train + [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    trainset = dset.CIFAR10(root=dataset_dir, train=True, download=True,
                            transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        pin_memory=use_cuda)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    testset = dset.CIFAR10(root=dataset_dir, train=False, download=True,
                           transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=use_cuda)

    if restore is not None:
        if restore == 'latest':
            restore = latest_file(model)
        print(f'Restoring model from {restore}')
        model, start_epoch, best_accuracy = load(restore)
        # The checkpoint stores the epoch it finished; continue with the next.
        start_epoch += 1
        print('Starting accuracy is {}'.format(best_accuracy))
        run_dir = os.path.split(restore)[0]
    else:
        print(f'Building {model} model')
        best_accuracy = -1
        start_epoch = 1
        run_dir = f"./run/{model}/{timestamp}"
        model = MODELS[model]()

    # ``exist_ok`` avoids the check-then-create race of a separate
    # ``os.path.exists`` test.
    os.makedirs(run_dir, exist_ok=True)
    save_config(config, run_dir)

    print(model)
    print("{} parameters".format(count_parameters(model)))
    print(f"Run directory set to {run_dir}")

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as description_file:
        description_file.write(str(model))

    if tracking:
        train_results_file = os.path.join(run_dir, 'train_results.csv')
        test_results_file = os.path.join(run_dir, 'test_results.csv')
    else:
        train_results_file = None
        test_results_file = None

    if use_cuda:
        print('Copying model to GPU')
        model.cuda()
        model = torch.nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count())))
    criterion = nn.CrossEntropyLoss()

    # Other parameters?
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              momentum=0.9,
                              weight_decay=weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(optimizer))

    # Batch counts are unknown until the first pass; ``run`` reports them so
    # later epochs can show a progress bar with a known length.
    train_max_value = None
    test_max_value = None
    end_epoch = start_epoch + epochs
    for epoch in range(start_epoch, end_epoch):
        print('Epoch {} of {}'.format(epoch, end_epoch - 1))
        train_acc, train_max_value = run(epoch, model, train_loader, criterion,
                                         optimizer, use_cuda=use_cuda,
                                         tracking=train_results_file,
                                         max_value=train_max_value, train=True)

        # Feed the batch count of the previous evaluation back in, as the
        # training call above does.
        test_acc, test_max_value = run(epoch, model, test_loader,
                                       use_cuda=use_cuda,
                                       tracking=test_results_file,
                                       max_value=test_max_value, train=False)

        # The best model is always kept, whatever ``--checkpoint`` says.
        if test_acc > best_accuracy:
            print('New best model!')
            save(model, run_dir, epoch, test_acc, use_cuda=use_cuda,
                 filename='checkpoint_best_model.t7')
            best_accuracy = test_acc

        last_epoch = epoch == (end_epoch - 1)
        if checkpoint == 'all' or (checkpoint == 'last' and last_epoch):
            save(model, run_dir, epoch, test_acc, use_cuda=use_cuda)
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()