# dawn-bench-models/pytorch/CIFAR10/benchmark/cifar10/train.py
# (Captured from a repository file viewer: 376 lines, 14 KiB, Python.)

import os
from datetime import datetime
from collections import OrderedDict
import click
import torch
import tqdm
import numpy as np
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import datasets
from benchmark import utils
from benchmark.yellowfin import YFOptimizer
from benchmark.cifar10.models import resnet, densenet
# Per-channel normalisation statistics, passed to transforms.Normalize for
# both the training and the test images in train().
# NOTE(review): these STD values are the ones widely copied between CIFAR-10
# scripts; some sources report (0.2470, 0.2435, 0.2616) for the training set.
# Confirm before changing -- existing checkpoints were trained with these.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def _resnext29(*config):
    """Return a factory that builds ``resnet.ResNeXt29(*config)``.

    The factory accepts (and ignores) one optional argument, so it can be
    called as ``MODELS[arch]()`` exactly like the other registry entries.
    """
    def build(_=None):
        return resnet.ResNeXt29(*config)
    return build


# Registry of the architectures selectable with --arch. Each value is a
# callable that train() invokes with no arguments to build the network.
MODELS = {
    # "Deep Residual Learning for Image Recognition"
    'resnet20': resnet.ResNet20,
    'resnet32': resnet.ResNet32,
    'resnet44': resnet.ResNet44,
    'resnet56': resnet.ResNet56,
    'resnet110': resnet.ResNet110,
    'resnet1202': resnet.ResNet1202,

    # "Wide Residual Networks"
    'wrn-40-4': resnet.WRN_40_4,
    'wrn-16-8': resnet.WRN_16_8,
    'wrn-28-10': resnet.WRN_28_10,

    # Based on "Identity Mappings in Deep Residual Networks"
    'preact8': resnet.PreActResNet8,
    'preact14': resnet.PreActResNet14,
    'preact20': resnet.PreActResNet20,
    'preact56': resnet.PreActResNet56,
    'preact164-basic': resnet.PreActResNet164Basic,

    # "Identity Mappings in Deep Residual Networks"
    'preact110': resnet.PreActResNet110,
    'preact164': resnet.PreActResNet164,
    'preact1001': resnet.PreActResNet1001,

    # Based on "Deep Networks with Stochastic Depth"
    'stochastic56': resnet.StochasticResNet56,
    'stochastic56-08': resnet.StochasticResNet56_08,
    'stochastic110': resnet.StochasticResNet110,
    'stochastic1202': resnet.StochasticResNet1202,

    # "Aggregated Residual Transformations for Deep Neural Networks"
    'resnext29-8-64': _resnext29(8, 64),
    'resnext29-16-64': _resnext29(16, 64),

    # "Densely Connected Convolutional Networks"
    'densenetbc100': densenet.DenseNetBC100,
    'densenetbc250': densenet.DenseNetBC250,
    'densenetbc190': densenet.DenseNetBC190,

    # Kuangliu/pytorch-cifar
    'resnet18': resnet.ResNet18,
    'resnet50': resnet.ResNet50,
    'resnet101': resnet.ResNet101,
    'resnet152': resnet.ResNet152,
}
def correct(outputs, targets, top=(1, )):
    """Count how many samples have their target among the top-k predictions.

    Args:
        outputs: Model scores with one row per sample and classes along dim 1.
        targets: Class index of each sample (reshaped to one column here).
        top: The k values to count hits for.

    Returns:
        A list with one Python int per k in ``top`` (same order): the number
        of samples whose target is among the k highest-scoring classes.
    """
    _, predictions = outputs.topk(max(top), dim=1, largest=True, sorted=True)
    targets = targets.view(-1, 1).expand_as(predictions)

    # Column j of the comparison marks samples whose target is the (j+1)-th
    # prediction; summing over the batch and then taking a running sum over j
    # gives the top-(j+1) hit counts. view(-1) keeps the result 1-D whether or
    # not this PyTorch version's sum() keeps the reduced dimension.
    hits = predictions.eq(targets).cpu().int().sum(0).view(-1).cumsum(0)
    # .data unwraps a legacy Variable; on modern tensors it is a no-op view.
    hits = hits.data
    return [int(hits[k - 1]) for k in top]
def run(epoch, model, loader, criterion=None, optimizer=None, top=(1, 5),
        use_cuda=False, tracking=None, train=True, half=False):
    """Make one pass over ``loader``, either training or evaluating ``model``.

    Args:
        epoch: Epoch number; used only in the progress bar and tracking rows.
        model: Network to run; switched to train() or eval() mode per ``train``.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function; required when ``train`` is True.
        optimizer: Optimizer stepped once per batch; required when ``train``.
        top: The k values for which top-k accuracy is tracked.
        use_cuda: Move each batch to the GPU before the forward pass.
        tracking: If not None, one result row per batch is passed to
            ``utils.save_result`` together with this value.
        train: Update the model (True) or only evaluate it (False).
        half: Cast inputs to half precision; only applied with ``use_cuda``.

    Returns:
        The ``utils.AverageMeter`` average, over the pass, of the per-batch
        accuracy in percent for the first k in ``top``.
    """
    accuracies = [utils.AverageMeter() for _ in top]

    assert criterion is not None or not train, 'Need criterion to train model'
    assert optimizer is not None or not train, 'Need optimizer to train model'
    loader = tqdm.tqdm(loader)
    if train:
        model.train()
        losses = utils.AverageMeter()
    else:
        model.eval()

    start = datetime.now()
    for batch_index, (inputs, targets) in enumerate(loader):
        # `volatile` is the legacy (pre-0.4) PyTorch way to skip autograd
        # bookkeeping during evaluation.
        inputs = Variable(inputs, requires_grad=False, volatile=not train)
        targets = Variable(targets, requires_grad=False, volatile=not train)
        batch_size = targets.size(0)
        # correct() accumulates hit counts as int32, whose max is 2**31 - 1.
        assert batch_size < 2**31, 'Size is too large! correct will overflow'

        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()
            if half:
                inputs = inputs.half()

        outputs = model(inputs)

        if train:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Works whether the loss is a 1-element tensor (old PyTorch) or a
            # 0-dim tensor (newer PyTorch, where loss.data[0] raises).
            loss_value = float(loss.data.view(-1)[0])
            losses.update(loss_value, batch_size)

        top_correct = correct(outputs, targets, top=top)
        for i, count in enumerate(top_correct):
            accuracies[i].update(count * (100. / batch_size), batch_size)

        end = datetime.now()
        if tracking is not None:
            result = OrderedDict()
            result['timestamp'] = datetime.now()
            result['batch_duration'] = end - start
            result['epoch'] = epoch
            result['batch'] = batch_index
            result['batch_size'] = batch_size
            for i, k in enumerate(top):
                result['top{}_correct'.format(k)] = top_correct[i]
                result['top{}_accuracy'.format(k)] = accuracies[i].val
            if train:
                result['loss'] = loss_value
            utils.save_result(result, tracking)

        desc = 'Epoch {} {}'.format(epoch, '(Train):' if train else '(Val): ')
        if train:
            desc += ' Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses)
        for k, acc in zip(top, accuracies):
            desc += ' Prec@{} {acc.val:.3f} ({acc.avg:.3f})'.format(k, acc=acc)
        loader.set_description(desc)
        # Restart the clock so batch_duration excludes logging time.
        start = datetime.now()

    message = 'Training accuracy of' if train else 'Validation accuracy of'
    for i, k in enumerate(top):
        message += ' top-{}: {}'.format(k, accuracies[i].avg)
    print(message)

    return accuracies[0].avg
@click.command()
@click.option('--dataset-dir', default='./data/cifar10')
@click.option('--checkpoint', '-c', type=click.Choice(['best', 'all', 'last']),
              default='last')
@click.option('--restore', '-r')
@click.option('--tracking/--no-tracking', default=True)
@click.option('--cuda/--no-cuda', default=True)
@click.option('--epochs', '-e', default=200)
@click.option('--batch-size', '-b', default=32)
@click.option('--learning-rate', '-l', default=1e-3)
@click.option('--lr-factor', default=1.0, help='only for yellowfin')
@click.option('--momentum', default=0.9)
@click.option('--optimizer', '-o', type=click.Choice(['sgd', 'adam', 'yellowfin']),
              default='sgd')
@click.option('--augmentation/--no-augmentation', default=True)
@click.option('device_ids', '--device', '-d', multiple=True, type=int)
@click.option('--num-workers', type=int)
@click.option('--weight-decay', default=5e-4)
@click.option('--validation', '-v', default=0.0)
@click.option('--evaluate', is_flag=True)
@click.option('--shuffle/--no-shuffle', default=True)
@click.option('--half', is_flag=True)
@click.option('--arch', '-a', type=click.Choice(list(MODELS)),
              default='resnet20')
def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs,
          batch_size, learning_rate, lr_factor, momentum, optimizer, augmentation,
          device_ids, num_workers, weight_decay, validation, evaluate, shuffle,
          half, arch):
    """Train a CIFAR-10 model, or only evaluate it with --evaluate.

    Checkpoints, the run configuration and (with --tracking) per-batch CSV
    results go to ./run/<arch>/<timestamp>, or to the directory of the
    checkpoint given with --restore.
    """
    # POSIX timestamp naming a fresh run directory. datetime.now() is used on
    # purpose: .timestamp() treats naive datetimes as local time, so
    # utcnow().timestamp() would be off by the machine's UTC offset.
    timestamp = "{:.0f}".format(datetime.now().timestamp())
    local_timestamp = str(datetime.now())
    # Snapshot of the CLI arguments plus the two timestamps; must be taken
    # before any other local is defined.
    config = dict(locals())

    use_cuda = cuda and torch.cuda.is_available()

    # create model
    model = MODELS[arch]()

    # create optimizer (the `optimizer` name is rebound from str to object)
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optimizer == 'yellowfin':
        optimizer = YFOptimizer(model.parameters(), lr=learning_rate,
                                mu=momentum, weight_decay=weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(optimizer))

    if restore is not None:
        if restore == 'latest':
            restore = utils.latest_file(arch)
        print(f'Restoring model from {restore}')
        assert os.path.exists(restore)
        restored_state = torch.load(restore)
        assert restored_state['arch'] == arch

        model.load_state_dict(restored_state['model'])

        if 'optimizer' in restored_state:
            optimizer.load_state_dict(restored_state['optimizer'])
            # The restored state carries the old learning rate; honour the
            # one given on the command line instead.
            if not isinstance(optimizer, YFOptimizer):
                for group in optimizer.param_groups:
                    group['lr'] = learning_rate

        best_accuracy = restored_state['accuracy']
        start_epoch = restored_state['epoch'] + 1
        run_dir = os.path.split(restore)[0]
    else:
        best_accuracy = 0.0
        start_epoch = 1
        run_dir = f"./run/{arch}/{timestamp}"

    print('Starting accuracy is {}'.format(best_accuracy))

    # run_dir is '' when restoring from a checkpoint in the current directory.
    if run_dir:
        os.makedirs(run_dir, exist_ok=True)
    utils.save_config(config, run_dir)

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))
    print(f"Run directory set to {run_dir}")

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    if tracking:
        train_results_file = os.path.join(run_dir, 'train_results.csv')
        valid_results_file = os.path.join(run_dir, 'valid_results.csv')
        test_results_file = os.path.join(run_dir, 'test_results.csv')
    else:
        train_results_file = None
        valid_results_file = None
        test_results_file = None

    # create loss
    criterion = nn.CrossEntropyLoss()

    if use_cuda:
        print('Copying model to GPU')
        model = model.cuda()
        criterion = criterion.cuda()
        if half:
            model = model.half()
            criterion = criterion.half()
        device_ids = device_ids or list(range(torch.cuda.device_count()))
        model = torch.nn.DataParallel(
            model, device_ids=device_ids)
        num_workers = num_workers or len(device_ids)
    else:
        num_workers = num_workers or 1
        if half:
            print('Half precision (16-bit floating point) only works on GPU')
    print(f"using {num_workers} workers for data loading")

    # load data
    print("Preparing data:")
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_dir, train=False, download=True,
                         transform=transform_test),
        batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=use_cuda)

    if evaluate:
        print("Only running evaluation of model on test dataset")
        # half must be forwarded: a half-precision model needs half inputs.
        run(start_epoch - 1, model, test_loader, use_cuda=use_cuda,
            tracking=test_results_file, train=False, half=half)
        return

    if augmentation:
        augment_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
    else:
        augment_steps = []

    transform_train = transforms.Compose(augment_steps + [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    train_dataset = datasets.CIFAR10(root=dataset_dir, train=True,
                                     download=True, transform=transform_train)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    assert 0 <= validation < 1, "Validation must be in [0, 1)"
    split = num_train - int(validation * num_train)

    if shuffle:
        np.random.shuffle(indices)

    train_indices = indices[:split]
    valid_indices = indices[split:]

    print('Using {} examples for training'.format(len(train_indices)))
    print('Using {} examples for validation'.format(len(valid_indices)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=SubsetRandomSampler(train_indices),
        batch_size=batch_size, num_workers=num_workers, pin_memory=use_cuda)
    if validation != 0:
        # The held-out split comes from the training images but must be read
        # without the random crop/flip augmentation used for training.
        valid_dataset = datasets.CIFAR10(root=dataset_dir, train=True,
                                         download=True,
                                         transform=transform_test)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, sampler=SubsetRandomSampler(valid_indices),
            batch_size=batch_size, num_workers=num_workers,
            pin_memory=use_cuda)
    else:
        print('Using test dataset for validation')
        valid_loader = test_loader

    end_epoch = start_epoch + epochs

    # YellowFin doesn't have param_groups causing AttributeError
    if not isinstance(optimizer, YFOptimizer):
        for group in optimizer.param_groups:
            if 'lr' in group:
                print('Learning rate set to {}'.format(group['lr']))
                assert group['lr'] == learning_rate
    else:
        print(f"set lr_factor to {lr_factor}")
        optimizer.set_lr_factor(lr_factor)

    for epoch in range(start_epoch, end_epoch):
        run(epoch, model, train_loader, criterion, optimizer,
            use_cuda=use_cuda, tracking=train_results_file, train=True,
            half=half)

        valid_acc = run(epoch, model, valid_loader, use_cuda=use_cuda,
                        tracking=valid_results_file, train=False, half=half)

        is_best = valid_acc > best_accuracy
        last_epoch = epoch == (end_epoch - 1)
        save_epoch = checkpoint == 'all' or (checkpoint == 'last' and last_epoch)
        if not (is_best or save_epoch):
            continue

        state = {
            'epoch': epoch,
            'arch': arch,
            # Unwrap DataParallel (only applied on GPU) so the weights load
            # into a bare model.
            'model': (model.module if use_cuda else model).state_dict(),
            'accuracy': valid_acc,
            'optimizer': optimizer.state_dict()
        }
        if is_best:
            print('New best model!')
            filename = os.path.join(run_dir, 'checkpoint_best_model.t7')
            print(f'Saving checkpoint to {filename}')
            best_accuracy = valid_acc
            torch.save(state, filename)
        if save_epoch:
            filename = os.path.join(run_dir, f'checkpoint_{epoch}.t7')
            print(f'Saving checkpoint to {filename}')
            torch.save(state, filename)
# Script entry point: `train` is a click command, so calling it parses the
# command-line arguments and runs training (or evaluation).
if __name__ == '__main__':
    train()