Make pytorch evaluate run without optimizer in checkpoint
This commit is contained in:
parent
339261c19f
commit
b7f332f460
1 changed files with 3 additions and 2 deletions
|
@ -210,7 +210,8 @@ def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs,
|
||||||
assert restored_state['arch'] == arch
|
assert restored_state['arch'] == arch
|
||||||
|
|
||||||
model.load_state_dict(restored_state['model'])
|
model.load_state_dict(restored_state['model'])
|
||||||
optimizer.load_state_dict(restored_state['optimizer'])
|
if 'optimizer' in restored_state:
|
||||||
|
optimizer.load_state_dict(restored_state['optimizer'])
|
||||||
if not isinstance(optimizer, YFOptimizer):
|
if not isinstance(optimizer, YFOptimizer):
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group['lr'] = learning_rate
|
group['lr'] = learning_rate
|
||||||
|
@ -225,7 +226,7 @@ def train(dataset_dir, checkpoint, restore, tracking, cuda, epochs,
|
||||||
|
|
||||||
print('Starting accuracy is {}'.format(best_accuracy))
|
print('Starting accuracy is {}'.format(best_accuracy))
|
||||||
|
|
||||||
if not os.path.exists(run_dir):
|
if not os.path.exists(run_dir) and run_dir != '':
|
||||||
os.makedirs(run_dir)
|
os.makedirs(run_dir)
|
||||||
utils.save_config(config, run_dir)
|
utils.save_config(config, run_dir)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue