diff --git a/pytorch/CIFAR10/benchmark/cifar10/train.py b/pytorch/CIFAR10/benchmark/cifar10/train.py index 31e7bc9..a1cf614 100644 --- a/pytorch/CIFAR10/benchmark/cifar10/train.py +++ b/pytorch/CIFAR10/benchmark/cifar10/train.py @@ -210,7 +210,8 @@ def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs, assert restored_state['arch'] == arch model.load_state_dict(restored_state['model']) - optimizer.load_state_dict(restored_state['optimizer']) + if 'optimizer' in restored_state: + optimizer.load_state_dict(restored_state['optimizer']) if not isinstance(optimizer, YFOptimizer): for group in optimizer.param_groups: group['lr'] = learning_rate @@ -225,7 +226,7 @@ def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs, print('Starting accuracy is {}'.format(best_accuracy)) - if not os.path.exists(run_dir): + if not os.path.exists(run_dir) and run_dir != '': os.makedirs(run_dir) utils.save_config(config, run_dir)