# dawn-bench-models/tensorflow/CIFAR10/resnet/resnet_main.py
#
# 303 lines
# 11 KiB
# Python
# Raw Permalink Normal View History
#
# 2017-08-17 12:43:17 -06:00
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet Train/Eval module.
"""
import os
import six
import subprocess
import sys
import time
import cifar_input
import numpy as np
import resnet_model
import tensorflow as tf
# Command-line flags. `FLAGS` is the parsed-flag accessor read by train(),
# evaluate() and main() below.
flags = tf.app.flags
FLAGS = flags.FLAGS

# What to run.
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
flags.DEFINE_string('mode', 'train', 'train or eval.')
flags.DEFINE_string('model', '', 'model to train.')
flags.DEFINE_string(
    'data_format', 'NHWC',
    'Data layout to use: NHWC (TF native)\n'
    'or NCHW (cuDNN native).')

# Input data.
flags.DEFINE_string(
    'train_data_path', '', 'Filepattern for training data.')
flags.DEFINE_string(
    'eval_data_path', '', 'Filepattern for eval data')
flags.DEFINE_integer('image_size', 32, 'Image side length.')

# Output locations.
flags.DEFINE_string(
    'train_dir', '', 'Directory to keep training outputs.')
flags.DEFINE_string(
    'eval_dir', '', 'Directory to keep eval outputs.')

# Evaluation behaviour.
flags.DEFINE_integer(
    'eval_batch_count', 50, 'Number of batches to eval.')
flags.DEFINE_bool(
    'eval_once', False, 'Whether evaluate the model only once.')

# Checkpoint / log locations.
flags.DEFINE_string(
    'log_root', '',
    'Should be a parent directory of FLAGS.train_dir/eval_dir.')
flags.DEFINE_string(
    'checkpoint_dir', '', 'Directory to store the checkpoints')

# Hardware and model options.
flags.DEFINE_integer(
    'num_gpus', 0, 'Number of gpus used for training. (0 or 1)')
flags.DEFINE_bool(
    'use_bottleneck', False, 'Use bottleneck module or not.')
flags.DEFINE_bool('time_inference', False, 'Time inference.')

# -1 means "pick the per-mode default" (resolved in main()).
flags.DEFINE_integer('batch_size', -1, 'Batch size to use.')
def train(hps):
    """Build the training graph and run the training loop.

    The loop runs for a fixed 181 * num_steps_per_epoch steps. Two custom
    session hooks are installed: one feeds the learning rate according to a
    step-based schedule, the other writes a checkpoint plus a timing entry
    to FLAGS.checkpoint_dir at every epoch boundary.

    Args:
        hps: Hyper-parameter object handed to resnet_model.ResNet. Its
            batch_size and data_format fields are also used here to build
            the input pipeline.
    """
    images, labels = cifar_input.build_input(
        FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode,
        hps.data_format)
    model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()

    # Report the trainable-parameter count, then the float-op profile.
    model_analyzer = tf.contrib.tfprof.model_analyzer
    param_stats = model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)

    model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=model_analyzer.FLOAT_OPS_OPTIONS)

    # Fraction of the batch whose arg-max prediction equals the arg-max label.
    truth = tf.argmax(model.labels, axis=1)
    predictions = tf.argmax(model.predictions, axis=1)
    precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))

    summary_hook = tf.train.SummarySaverHook(
        save_steps=100,
        output_dir=FLAGS.train_dir,
        summary_op=tf.summary.merge([model.summaries,
                                     tf.summary.scalar('Precision', precision)]))

    num_steps_per_epoch = 391  # TODO: Don't hardcode this.

    logging_hook = tf.train.LoggingTensorHook(
        tensors={'step': model.global_step,
                 'loss': model.cost,
                 'precision': precision},
        every_n_iter=100)

    class _LearningRateSetterHook(tf.train.SessionRunHook):
        """Sets learning_rate based on global step."""

        def begin(self):
            self._lrn_rate = 0.01

        def before_run(self, run_context):
            return tf.train.SessionRunArgs(
                model.global_step,  # Asks for global step value.
                feed_dict={model.lrn_rate: self._lrn_rate})  # Sets learning rate

        def after_run(self, run_context, run_values):
            # The rate chosen here is fed on the *next* run call.
            train_step = run_values.results
            if train_step < num_steps_per_epoch:
                self._lrn_rate = 0.01
            elif train_step < (91 * num_steps_per_epoch):
                self._lrn_rate = 0.1
            elif train_step < (136 * num_steps_per_epoch):
                self._lrn_rate = 0.01
            elif train_step < (181 * num_steps_per_epoch):
                self._lrn_rate = 0.001
            else:
                self._lrn_rate = 0.0001

    class _SaverHook(tf.train.SessionRunHook):
        """Saves a checkpoint and logs elapsed time at each epoch boundary."""

        def begin(self):
            self.saver = tf.train.Saver(max_to_keep=10000)
            # Start from an empty checkpoint directory. File-system calls are
            # used instead of a shell command line so that paths containing
            # spaces or shell metacharacters are handled verbatim.
            if FLAGS.checkpoint_dir:
                if tf.gfile.Exists(FLAGS.checkpoint_dir):
                    tf.gfile.DeleteRecursively(FLAGS.checkpoint_dir)
                tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
            self.f = open(os.path.join(FLAGS.checkpoint_dir, "times.log"), 'w')

        def after_create_session(self, sess, coord):
            self.sess = sess
            self.start_time = time.time()

        def before_run(self, run_context):
            return tf.train.SessionRunArgs(
                model.global_step  # Asks for global step value.
            )

        def after_run(self, run_context, run_values):
            train_step = run_values.results
            if train_step % num_steps_per_epoch != 0:
                return

            end_time = time.time()
            # Integer division keeps the epoch index an int on Python 3.
            epoch = train_step // num_steps_per_epoch
            # One zero-padded sub-directory per epoch, e.g. "00012".
            directory = os.path.join(FLAGS.checkpoint_dir, "%05d" % epoch)
            tf.gfile.MakeDirs(directory)
            self.saver.save(self.sess, os.path.join(directory, 'model.ckpt'),
                            global_step=train_step)
            self.f.write("Step: %d\tTime: %s\n" % (train_step, end_time - self.start_time))
            # Flush so the timing log survives an abnormal termination.
            self.f.flush()
            print("Saved checkpoint after %d epoch(s) to %s..." % (epoch, directory))
            sys.stdout.flush()
            # Restart the clock after saving so that checkpoint I/O is not
            # counted in the next interval.
            self.start_time = time.time()

        def end(self, sess):
            self.f.close()

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.log_root,
            hooks=[logging_hook, _LearningRateSetterHook()],
            chief_only_hooks=[summary_hook, _SaverHook()],
            # Checkpoints are written by _SaverHook, so no save interval is
            # requested from the session itself.
            save_checkpoint_secs=None,
            # Since we provide a SummarySaverHook, we need to disable the
            # default one. To do that we set both summary intervals to None.
            save_summaries_steps=None,
            save_summaries_secs=None,
            config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
        for _ in range(num_steps_per_epoch * 181):
            mon_sess.run(model.train_op)
def _count_topk_hits(truth, predictions):
    """Count top-1 and top-5 hits for one batch of predictions.

    Args:
        truth: Sequence of per-example label vectors; the arg-max of each
            vector is taken as the true class index.
        predictions: Sequence of per-example score vectors, aligned with
            `truth`.

    Returns:
        A tuple (top1_hits, top5_hits, num_examples).
    """
    top1_hits, top5_hits, num_examples = 0, 0, 0
    for example_truth, example_scores in zip(truth, predictions):
        true_class = np.argmax(example_truth)
        # Sort once; the last entry is the top-1 class and the last five
        # entries are the top-5 classes.
        ranked = np.argsort(example_scores)
        top1_hits += int(true_class == ranked[-1])
        top5_hits += int(true_class in ranked[-5:])
        num_examples += 1
    return top1_hits, top5_hits, num_examples


def evaluate(hps):
    """Eval loop.

    Repeatedly restores the latest checkpoint found under FLAGS.log_root and
    runs FLAGS.eval_batch_count batches through the model. With
    --time_inference only the elapsed time is printed; otherwise top-1 and
    top-5 hit rates are computed, printed and written as summaries to
    FLAGS.eval_dir. The loop stops when no checkpoint is found or when
    --eval_once is set; otherwise it sleeps 60 seconds between rounds.

    Args:
        hps: Hyper-parameter object handed to resnet_model.ResNet. Its
            batch_size and data_format fields are also used to build the
            input pipeline.

    Raises:
        ValueError: If precision is requested (no --time_inference) but
            --eval_batch_count is smaller than 1.
    """
    if not FLAGS.time_inference and FLAGS.eval_batch_count < 1:
        # Without at least one batch the precision would be 0 / 0.
        raise ValueError(
            '--eval_batch_count must be at least 1 to compute precision.')

    images, labels = cifar_input.build_input(
        FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode,
        hps.data_format)
    model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    tf.train.start_queue_runners(sess)

    best_precision = 0.0
    while True:
        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            tf.logging.error('Cannot restore checkpoint: %s', e)
            continue
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
            break
        tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
        saver.restore(sess, ckpt_state.model_checkpoint_path)

        # The step number is the suffix after the last '-' of the checkpoint
        # file name; fall back to 0 when that suffix is not numeric.
        step_suffix = ckpt_state.model_checkpoint_path.split('/')[-1].split('-')[-1]
        global_step = int(step_suffix) if step_suffix.isdigit() else 0

        total_prediction, correct_prediction, correct_prediction_top5 = 0, 0, 0
        start_time = time.time()
        for _ in six.moves.range(FLAGS.eval_batch_count):
            (summaries, unused_loss, predictions, truth, train_step) = sess.run(
                [model.summaries, model.cost, model.predictions,
                 model.labels, model.global_step])

            # Skip the Python-side bookkeeping when only timing inference.
            if not FLAGS.time_inference:
                top1_hits, top5_hits, num_examples = _count_topk_hits(
                    truth, predictions)
                correct_prediction += top1_hits
                correct_prediction_top5 += top5_hits
                total_prediction += num_examples

        if FLAGS.time_inference:
            print("Time for inference: %.4f" % (time.time() - start_time))
        else:
            precision = 1.0 * correct_prediction / total_prediction
            precision_top5 = 1.0 * correct_prediction_top5 / total_prediction
            best_precision = max(precision, best_precision)

            precision_summ = tf.Summary()
            precision_summ.value.add(
                tag='Precision', simple_value=precision)
            summary_writer.add_summary(precision_summ, train_step)
            best_precision_summ = tf.Summary()
            best_precision_summ.value.add(
                tag='Best Precision', simple_value=best_precision)
            summary_writer.add_summary(best_precision_summ, train_step)
            summary_writer.add_summary(summaries, train_step)
            # NOTE: the value printed as "Recall @ 5" is the top-5 hit rate;
            # the label is kept unchanged for anything parsing this output.
            print('Precision @ 1 = %.4f, Recall @ 5 = %.4f, Global step = %d' %
                  (precision, precision_top5, global_step))
            summary_writer.flush()

        if FLAGS.eval_once:
            break

        time.sleep(60)
def main(_):
    """Validate the flags, build the hyper-parameters and dispatch by mode.

    Args:
        _: Unused positional argument supplied by tf.app.run.

    Raises:
        Exception: If --model is empty or is not one of resnet20, resnet56,
            resnet164.
        ValueError: If --num_gpus is not 0 or 1, if --mode is not 'train' or
            'eval', or if --dataset is not 'cifar10' or 'cifar100'.
    """
    if FLAGS.model == '':
        raise Exception('--model must be specified.')

    if FLAGS.num_gpus == 0:
        dev = '/cpu:0'
    elif FLAGS.num_gpus == 1:
        dev = '/gpu:0'
    else:
        raise ValueError('Only support 0 or 1 gpu.')

    if FLAGS.model == 'resnet20':
        num_residual_units = 3
    elif FLAGS.model == 'resnet56':
        num_residual_units = 9
    elif FLAGS.model == 'resnet164' and FLAGS.use_bottleneck:
        num_residual_units = 18
    elif FLAGS.model == 'resnet164' and not FLAGS.use_bottleneck:
        num_residual_units = 27
    else:
        raise Exception("Invalid model -- only resnet20, resnet56 and resnet164 supported")

    # Reject unknown modes up front: otherwise the default batch size below
    # would be left undefined, or the run would silently do nothing.
    default_batch_sizes = {'train': 128, 'eval': 100}
    if FLAGS.mode not in default_batch_sizes:
        raise ValueError(
            "Invalid mode %r -- only 'train' and 'eval' supported" % FLAGS.mode)
    # A batch size of -1 means "use the per-mode default".
    if FLAGS.batch_size == -1:
        batch_size = default_batch_sizes[FLAGS.mode]
    else:
        batch_size = FLAGS.batch_size

    classes_per_dataset = {'cifar10': 10, 'cifar100': 100}
    if FLAGS.dataset not in classes_per_dataset:
        raise ValueError(
            "Invalid dataset %r -- only 'cifar10' and 'cifar100' supported"
            % FLAGS.dataset)
    num_classes = classes_per_dataset[FLAGS.dataset]

    data_format = FLAGS.data_format

    hps = resnet_model.HParams(batch_size=batch_size,
                               num_classes=num_classes,
                               min_lrn_rate=0.0001,
                               lrn_rate=0.1,
                               num_residual_units=num_residual_units,
                               use_bottleneck=FLAGS.use_bottleneck,
                               weight_decay_rate=0.0005,
                               relu_leakiness=0.1,
                               optimizer='mom',
                               data_format=data_format)

    with tf.device(dev):
        if FLAGS.mode == 'train':
            train(hps)
        elif FLAGS.mode == 'eval':
            evaluate(hps)
if __name__ == "__main__":
    # Script entry point: enable INFO-level TensorFlow logging, then let
    # tf.app.run parse the flags and invoke main().
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)