import os
import re
import json
from functools import reduce
from datetime import datetime
from collections import OrderedDict

import click
import torch
import progressbar
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets as dset

from benchmark.models import resnet, densenet

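# NOTE: written against the pre-0.4 PyTorch API (Variable, volatile,
# loss.data[0]); newer PyTorch releases deprecate or remove these calls.

# Channel-wise mean and std of the CIFAR-10 training images.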
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

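# Model zoo: the keys are the values accepted by the --model option below;
# each entry is a callable that builds its network without required arguments.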
MODELS = {
    # "Deep Residual Learning for Image Recognition"
    'resnet20': resnet.ResNet20,
    'resnet32': resnet.ResNet32,
    'resnet44': resnet.ResNet44,
    'resnet56': resnet.ResNet56,
    'resnet110': resnet.ResNet110,
    'resnet1202': resnet.ResNet1202,

    # "Wide Residual Networks"
    'wrn-40-4': resnet.WRN_40_4,
    'wrn-16-8': resnet.WRN_16_8,
    'wrn-28-10': resnet.WRN_28_10,

    # Based on "Identity Mappings in Deep Residual Networks"
    'preact20': resnet.PreActResNet20,
    'preact56': resnet.PreActResNet56,
    'preact164-basic': resnet.PreActResNet164Basic,

    # "Identity Mappings in Deep Residual Networks"
    'preact110': resnet.PreActResNet110,
    'preact164': resnet.PreActResNet164,
    'preact1001': resnet.PreActResNet1001,

    # "Aggregated Residual Transformations for Deep Neural Networks"
    'resnext29-8-64': lambda _=None: resnet.ResNeXt29(8, 64),
    'resnext29-16-64': lambda _=None: resnet.ResNeXt29(16, 64),

    # "Densely Connected Convolutional Networks"
    'densenetbc100': densenet.DenseNetBC100,
    'densenetbc190': densenet.DenseNetBC190,
    'densenetbc250': densenet.DenseNetBC250,

    # Kuangliu/pytorch-cifar
    'resnet18': resnet.ResNet18,
    'resnet50': resnet.ResNet50,
    'resnet101': resnet.ResNet101,
    'resnet152': resnet.ResNet152,
}


def count_parameters(model):
    # Multiply each parameter tensor's dimensions together and sum the
    # products: the total number of scalar parameters.
    c = map(lambda p: reduce(lambda x, y: x * y, p.size()), model.parameters())
    return sum(c)


def correct(outputs, targets, top=(1,)):
    # For each k in `top`, count how many targets appear among the k
    # highest-scoring predictions (e.g. top=(1, 5) -> [top-1, top-5] counts).
    _, predictions = outputs.topk(max(top), dim=1, largest=True, sorted=True)
    targets = targets.view(-1, 1).expand_as(predictions)
    # Cast to long so large batches cannot overflow the byte tensor produced
    # by eq(); keepdim=True keeps the (1, max(top)) shape the indexing expects.
    corrects = predictions.eq(targets).long().cpu().cumsum(1).sum(0, keepdim=True)
    tops = [corrects.data[0][k - 1] for k in top]
    return tops


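# save_result appends one CSV row per call, writing a header first if the
# file does not exist yet; callers must keep a stable key order across calls
# (run() below passes an OrderedDict for exactly this reason).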
def save_result(result, path):
    write_heading = not os.path.exists(path)
    with open(path, mode='a') as out:
        if write_heading:
            out.write(",".join([str(k) for k, v in result.items()]) + '\n')
        out.write(",".join([str(v) for k, v in result.items()]) + '\n')


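# Run a single pass over `loader`. In train mode this also backpropagates and
# steps the optimizer; in eval mode it only measures accuracy. Returns the
# top-(top[0]) accuracy and the number of batches processed, which callers
# reuse as the next epoch's progress-bar length.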
def run(epoch, model, loader, criterion=None, optimizer=None, top=(1, 5),
        use_cuda=False, tracking=None, max_value=None, train=True):

    assert criterion is not None or not train, 'Need criterion to train model'
    assert optimizer is not None or not train, 'Need optimizer to train model'
    max_value = max_value or progressbar.UnknownLength
    bar = progressbar.ProgressBar(max_value=max_value)
    total = 0
    correct_counts = {}
    if train:
        model.train()
    else:
        model.eval()

    start = datetime.now()
    for batch_index, (inputs, targets) in enumerate(loader):
        # volatile=True disables graph construction during evaluation (the
        # pre-0.4 equivalent of torch.no_grad()).
        inputs = Variable(inputs, requires_grad=False, volatile=not train)
        targets = Variable(targets, requires_grad=False, volatile=not train)

        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        outputs = model(inputs)

        if train:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = targets.size(0)
        top_correct = correct(outputs, targets, top=top)
        total += batch_size
        for k, count in zip(top, top_correct):
            correct_counts[k] = correct_counts.get(k, 0) + count

        end = datetime.now()
        if tracking is not None:
            result = OrderedDict()
            result['timestamp'] = datetime.now()
            result['batch_duration'] = end - start
            result['epoch'] = epoch
            result['batch'] = batch_index
            result['batch_size'] = batch_size
            for i, k in enumerate(top):
                result['top{}_correct'.format(k)] = top_correct[i]
            if train:
                result['loss'] = loss.data[0]
            save_result(result, tracking)

        bar.update(batch_index + 1)
        start = datetime.now()

    print()
    message = 'Training accuracy' if train else 'Test accuracy'
    for k in top:
        accuracy = correct_counts[k] / total
        message += ' top-{}: {}'.format(k, accuracy)
    print(message)
    return correct_counts[top[0]] / total, batch_index + 1


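# Checkpoints store the full module object rather than a state_dict, so the
# model's source package must be importable when the checkpoint is loaded.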
def save(model, directory, epoch, accuracy, use_cuda=False, filename=None):
    state = {
        # Unwrap DataParallel so the checkpoint holds the bare module.
        'model': model.module if use_cuda else model,
        'epoch': epoch,
        'accuracy': accuracy
    }

    filename = filename or 'checkpoint_{}.t7'.format(epoch)
    torch.save(state, os.path.join(directory, filename))


def save_config(config, run_dir):
    path = os.path.join(run_dir, "config_{}.json".format(config['timestamp']))
    with open(path, 'w') as config_file:
        json.dump(config, config_file)
        config_file.write('\n')


def load(path):
    assert os.path.exists(path)
    state = torch.load(path)
    model = state['model']
    epoch = state['epoch']
    accuracy = state['accuracy']
    return model, epoch, accuracy


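# Runs are laid out as ./run/<model>/<timestamp>/checkpoint_<epoch>.t7;
# latest_file picks the newest timestamp directory for the given model and,
# within it, the checkpoint with the highest epoch number.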
def latest_file(model):
    restore = f'./run/{model}'
    timestamps = sorted(os.listdir(restore))
    assert len(timestamps) > 0
    run_dir = os.path.join(restore, timestamps[-1])
    files = os.listdir(run_dir)

    max_checkpoint = -1
    max_checkpoint_file = None
    for filename in files:
        if re.search(r'checkpoint_\d+\.t7', filename):
            num = int(re.search(r'\d+', filename).group())

            if num > max_checkpoint:
                max_checkpoint = num
                max_checkpoint_file = filename

    assert max_checkpoint != -1
    return os.path.join(run_dir, max_checkpoint_file)


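# Example invocation (illustrative; assumes this file is saved as train.py):
#   python train.py --model resnet20 --batch-size 128 --sgd --epochs 200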
@click.command()
@click.option('--dataset-dir', default='./data/cifar10')
@click.option('--checkpoint', '-c', type=click.Choice(['best', 'all', 'last']),
              default='last')
@click.option('--restore', '-r')
@click.option('--tracking/--no-tracking', default=True)
@click.option('--cuda/--no-cuda', default=True)
@click.option('--epochs', '-e', default=200)
@click.option('--batch-size', '-b', default=32)
@click.option('--learning-rate', '-l', default=1e-3)
@click.option('--sgd', 'optimizer', flag_value='sgd')
@click.option('--adam', 'optimizer', flag_value='adam', default=True)
@click.option('--augmentation/--no-augmentation', default=True)
@click.option('--num-workers', type=int)
@click.option('--weight-decay', default=5e-4)
@click.option('--model', '-m', type=click.Choice(MODELS.keys()),
              default='resnet20')
def main(dataset_dir, checkpoint, restore, tracking, cuda, epochs,
         batch_size, learning_rate, optimizer, augmentation, num_workers,
         weight_decay, model):
    timestamp = "{:.0f}".format(datetime.utcnow().timestamp())
    # Snapshot the CLI arguments (plus the timestamp) before any other locals
    # are defined; this is what save_config writes out.
    config = {k: v for k, v in locals().items()}

    use_cuda = cuda and torch.cuda.is_available()
    if use_cuda:
        # Default to one loader worker per visible GPU.
        num_workers = num_workers or torch.cuda.device_count()
    else:
        num_workers = num_workers or 1

    print(f"Using {num_workers} workers for data loading")

    print("Preparing data:")

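    # Standard CIFAR-10 augmentation: pad-and-crop plus horizontal flips at
    # train time; both splits are normalized with the dataset statistics.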
    if augmentation:
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
    else:
        transform_train = []

    transform_train = transforms.Compose(transform_train + [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    trainset = dset.CIFAR10(root=dataset_dir, train=True, download=True,
                            transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        pin_memory=use_cuda)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    testset = dset.CIFAR10(root=dataset_dir, train=False, download=True,
                           transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=use_cuda)

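    # Either resume from a checkpoint (--restore PATH, or --restore latest to
    # pick the newest one for this model) or build a fresh model.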
    if restore is not None:
        if restore == 'latest':
            restore = latest_file(model)
        print(f'Restoring model from {restore}')
        model, start_epoch, best_accuracy = load(restore)
        start_epoch += 1
        print('Starting accuracy is {}'.format(best_accuracy))
        run_dir = os.path.split(restore)[0]
    else:
        print(f'Building {model} model')
        best_accuracy = -1
        start_epoch = 1
        run_dir = f"./run/{model}/{timestamp}"
        model = MODELS[model]()

    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    save_config(config, run_dir)

    print(model)
    print("{} parameters".format(count_parameters(model)))
    print(f"Run directory set to {run_dir}")

    # Save model text description
    with open(os.path.join(run_dir, 'model.txt'), 'w') as file:
        file.write(str(model))

    if tracking:
        train_results_file = os.path.join(run_dir, 'train_results.csv')
        test_results_file = os.path.join(run_dir, 'test_results.csv')
    else:
        train_results_file = None
        test_results_file = None

    if use_cuda:
        print('Copying model to GPU')
        model.cuda()
        model = torch.nn.DataParallel(
            model, device_ids=range(torch.cuda.device_count()))
    criterion = nn.CrossEntropyLoss()

    # TODO: expose other optimizer hyperparameters as options? Note that
    # weight decay is currently only applied when using SGD.
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              momentum=0.9,
                              weight_decay=weight_decay)
    else:
        raise NotImplementedError("Unknown optimizer: {}".format(optimizer))

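    # Main loop: train one epoch, evaluate, checkpoint the best model, and
    # feed each pass's batch count back in as the next progress-bar length.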
    train_max_value = None
    test_max_value = None
    end_epoch = start_epoch + epochs
    for epoch in range(start_epoch, end_epoch):
        print('Epoch {} of {}'.format(epoch, end_epoch - 1))
        train_acc, train_max_value = run(epoch, model, train_loader,
                                         criterion, optimizer,
                                         use_cuda=use_cuda,
                                         tracking=train_results_file,
                                         max_value=train_max_value,
                                         train=True)

        test_acc, test_max_value = run(epoch, model, test_loader,
                                       use_cuda=use_cuda,
                                       tracking=test_results_file,
                                       max_value=test_max_value,
                                       train=False)

        if test_acc > best_accuracy:
            print('New best model!')
            save(model, run_dir, epoch, test_acc, use_cuda=use_cuda,
                 filename='checkpoint_best_model.t7')
            best_accuracy = test_acc

        last_epoch = epoch == (end_epoch - 1)
        if checkpoint == 'all' or (checkpoint == 'last' and last_epoch):
            save(model, run_dir, epoch, test_acc, use_cuda=use_cuda)


if __name__ == '__main__':
    main()