# dawn-bench-models/tensorflow/CIFAR10/resnet/resnet_model.py
# (282 lines, 10 KiB; revision of 2017-08-17 12:43:17 -06:00)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet model.
Related papers:
https://arxiv.org/pdf/1603.05027v2.pdf
https://arxiv.org/pdf/1512.03385v1.pdf
https://arxiv.org/pdf/1605.07146v1.pdf
"""
from collections import namedtuple
import numpy as np
import tensorflow as tf
import six
from tensorflow.python.training import moving_averages
# Hyperparameter bundle handed to ResNet as ``hps``.
#   optimizer:   'sgd' or 'mom'.
#   data_format: 'NHWC' or 'NCHW'.
HParams = namedtuple('HParams', [
    'batch_size',
    'num_classes',
    'min_lrn_rate',
    'lrn_rate',
    'num_residual_units',
    'use_bottleneck',
    'weight_decay_rate',
    'relu_leakiness',
    'optimizer',
    'data_format',
])
class ResNet(object):
    """ResNet model built from pre-activation residual units in three stages.

    Call ``build_graph()`` to add the network, its loss and (in 'train' mode)
    its training op to the current default graph.
    """

    def __init__(self, hps, images, labels, mode):
        """ResNet constructor.

        Args:
            hps: Hyperparameters (an ``HParams`` namedtuple).
            images: Batches of images. [batch_size, image_size, image_size, 3]
                for NHWC; presumably channels-first when ``hps.data_format``
                is 'NCHW' -- confirm against the input pipeline.
            labels: Batches of labels. [batch_size, num_classes]
            mode: One of 'train' and 'eval'.
        """
        self.hps = hps
        self._images = images
        self.labels = labels
        self.mode = mode
        # Extra ops grouped into train_op; nothing in this class appends to it.
        self._extra_train_ops = []

    def build_graph(self):
        """Build a whole graph for the model.

        Creates (or reuses) the global step and builds the network and loss.
        It adds the training op when ``mode == 'train'`` and merges every
        summary into ``self.summaries``.
        """
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        self._build_model()
        if self.mode == 'train':
            self._build_train_op()
        self.summaries = tf.summary.merge_all()

    def _stride_arr(self, stride):
        """Map a stride scalar to the stride array for tf.nn.conv2d.

        The spatial stride goes on the H/W axes of ``hps.data_format``.

        Raises:
            ValueError: if ``hps.data_format`` is neither 'NHWC' nor 'NCHW'.
        """
        if self.hps.data_format == 'NHWC':
            return [1, stride, stride, 1]
        elif self.hps.data_format == 'NCHW':
            return [1, 1, stride, stride]
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError("Invalid data_format")

    def _build_model(self):
        """Build the core model within the graph.

        Sets ``self.predictions`` (softmax of the logits) and ``self.cost``
        (mean softmax cross-entropy plus the L2 weight-decay term).
        """
        with tf.variable_scope('init'):
            x = self._images
            # 3x3 conv from 3 input channels to 16 filters.
            x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))

        # For each stage, the stride of its first unit and whether that
        # unit applies BN+ReLU before splitting off the shortcut.
        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        if self.hps.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            filters = [16, 16, 32, 64]
            # Uncomment the following line to use the WRN-28-10 wide residual
            # network. It is more memory efficient than a very deep residual
            # network and has comparably good performance.
            # https://arxiv.org/pdf/1605.07146v1.pdf
            # filters = [16, 160, 320, 640]
            # Also set hps.num_residual_units to 4.

        # Scopes are named 'unit_<stage>_<index>' exactly as before, so the
        # variable names (and hence checkpoints) are unchanged.
        for stage, (stride, activate) in enumerate(
                zip(strides, activate_before_residual), start=1):
            in_filter, out_filter = filters[stage - 1], filters[stage]
            with tf.variable_scope('unit_%d_0' % stage):
                x = res_func(x, in_filter, out_filter,
                             self._stride_arr(stride), activate)
            for i in six.moves.range(1, self.hps.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (stage, i)):
                    x = res_func(x, out_filter, out_filter,
                                 self._stride_arr(1), False)

        with tf.variable_scope('unit_last'):
            x = self._batch_norm('final_bn', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)

        with tf.variable_scope('logit'):
            logits = self._fully_connected(x, self.hps.num_classes)
            self.predictions = tf.nn.softmax(logits)

        with tf.variable_scope('costs'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=self.labels)
            self.cost = tf.reduce_mean(xent, name='xent')
            self.cost += self._decay()
            tf.summary.scalar('cost', self.cost)

    def _build_train_op(self):
        """Build training specific ops for the graph.

        Sets ``self.lrn_rate`` (a constant) and ``self.train_op``.

        Raises:
            ValueError: if ``hps.optimizer`` is not 'sgd' or 'mom'.
        """
        # Validate before adding any ops, so an unknown optimizer fails
        # clearly instead of leaving `optimizer` unbound below.
        if self.hps.optimizer not in ('sgd', 'mom'):
            raise ValueError('Unknown optimizer: %r' % (self.hps.optimizer,))

        self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
        tf.summary.scalar('learning_rate', self.lrn_rate)

        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(self.cost, trainable_variables)

        if self.hps.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
        else:  # 'mom'
            optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

        apply_op = optimizer.apply_gradients(
            list(zip(grads, trainable_variables)),
            global_step=self.global_step, name='train_step')

        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)

    def _batch_norm(self, name, x):
        """Batch normalization via tf.contrib.layers.batch_norm.

        Runs in training mode when ``self.mode == 'train'``. Per the
        tf.contrib.layers.batch_norm docs, ``updates_collections=None`` makes
        the moving-statistics updates run in place. No separate update ops
        therefore need to be added to the train op.
        """
        with tf.variable_scope(name) as scope:
            output = tf.contrib.layers.batch_norm(
                x,
                decay=0.9,
                epsilon=0.001,
                data_format=self.hps.data_format,
                scope=scope,
                is_training=(self.mode == 'train'),
                fused=True,
                updates_collections=None)
        return output

    def _residual(self, x, in_filter, out_filter, stride,
                  activate_before_residual=False):
        """Residual unit with 2 sub layers (BN-ReLU-3x3 conv, twice).

        Args:
            x: Input tensor in ``hps.data_format`` layout.
            in_filter: Number of input channels.
            out_filter: Number of output channels.
            stride: Stride array (from ``_stride_arr``) for the first conv.
            activate_before_residual: If True, the shortcut is taken after the
                initial BN+ReLU; otherwise it is taken from the raw input.

        Returns:
            The unit's output tensor with ``out_filter`` channels.
        """
        if activate_before_residual:
            with tf.variable_scope('shared_activation'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_only_activation'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                # Parameter-free shortcut: average-pool with the unit's stride,
                # then zero-pad the channel axis up to out_filter.
                orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID',
                                        data_format=self.hps.data_format)
                pad_total = out_filter - in_filter
                # Put the odd leftover channel (if any) after, so the padded
                # width always equals out_filter.
                channel_pad = [pad_total // 2, pad_total - pad_total // 2]
                if self.hps.data_format == 'NHWC':
                    paddings = [[0, 0], [0, 0], [0, 0], channel_pad]
                elif self.hps.data_format == 'NCHW':
                    paddings = [[0, 0], channel_pad, [0, 0], [0, 0]]
                else:
                    raise ValueError("Invalid data_format")
                orig_x = tf.pad(orig_x, paddings)
            x += orig_x

        tf.logging.debug('image after unit %s', x.get_shape())
        return x

    def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                             activate_before_residual=False):
        """Bottleneck residual unit with 3 sub layers (1x1, 3x3, 1x1 convs).

        The inner convs use ``out_filter // 4`` channels. When the channel
        count changes, the shortcut is a strided 1x1 projection conv.

        Args:
            x: Input tensor in ``hps.data_format`` layout.
            in_filter: Number of input channels.
            out_filter: Number of output channels.
            stride: Stride array (from ``_stride_arr``) for the first conv.
            activate_before_residual: If True, the shortcut is taken after the
                initial BN+ReLU; otherwise it is taken from the raw input.

        Returns:
            The unit's output tensor with ``out_filter`` channels.
        """
        # Floor division: `/` would give a float channel count on Python 3.
        bottleneck_filter = out_filter // 4

        if activate_before_residual:
            with tf.variable_scope('common_bn_relu'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_bn_relu'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 1, in_filter, bottleneck_filter, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, bottleneck_filter, bottleneck_filter,
                           [1, 1, 1, 1])

        with tf.variable_scope('sub3'):
            x = self._batch_norm('bn3', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv3', x, 1, bottleneck_filter, out_filter,
                           [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                orig_x = self._conv('project', orig_x, 1, in_filter,
                                    out_filter, stride)
            x += orig_x

        tf.logging.debug('image after unit %s', x.get_shape())
        return x

    def _decay(self):
        """L2 weight decay loss.

        Returns ``hps.weight_decay_rate`` times the sum of ``tf.nn.l2_loss``
        over every trainable variable whose last name component is 'DW'. Those
        are the weights created by ``_conv`` and ``_fully_connected``. Returns
        a zero constant when no such variable exists.
        """
        costs = [tf.nn.l2_loss(var) for var in tf.trainable_variables()
                 if var.op.name.split('/')[-1] == 'DW']
        if not costs:
            # tf.add_n rejects an empty list.
            return tf.constant(0.0, tf.float32)
        return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """Convolution with a 'DW' kernel created under variable scope ``name``.

        The kernel [filter_size, filter_size, in_filters, out_filters] is
        initialized from a normal distribution with
        stddev = sqrt(2 / (filter_size * filter_size * out_filters)). The conv
        uses 'SAME' padding in ``hps.data_format``.
        """
        with tf.variable_scope(name):
            n = filter_size * filter_size * out_filters
            kernel = tf.get_variable(
                'DW', [filter_size, filter_size, in_filters, out_filters],
                tf.float32, initializer=tf.random_normal_initializer(
                    stddev=np.sqrt(2.0 / n)))
            return tf.nn.conv2d(x, kernel, strides, padding='SAME',
                                data_format=self.hps.data_format)

    def _relu(self, x, leakiness=0.0):
        """Relu, with optional leaky support (``leakiness * x`` where x < 0)."""
        return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

    def _fully_connected(self, x, out_dim):
        """FullyConnected layer for final output.

        Flattens all non-batch axes of ``x`` and returns ``x @ DW + biases``.
        The batch axis is left dynamic, so any batch size works, not only
        ``hps.batch_size``.
        """
        # The non-batch dims must be statically known to size the weights.
        in_dim = int(np.prod(x.get_shape().as_list()[1:]))
        x = tf.reshape(x, [-1, in_dim])
        w = tf.get_variable(
            'DW', [in_dim, out_dim],
            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
        b = tf.get_variable('biases', [out_dim],
                            initializer=tf.constant_initializer())
        return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(self, x):
        """Average a rank-4 tensor over its spatial (H, W) axes.

        Raises:
            ValueError: if ``hps.data_format`` is neither 'NHWC' nor 'NCHW'.
        """
        assert x.get_shape().ndims == 4
        if self.hps.data_format == 'NHWC':
            return tf.reduce_mean(x, [1, 2])
        elif self.hps.data_format == 'NCHW':
            return tf.reduce_mean(x, [2, 3])
        # Fail loudly rather than returning None for an unknown layout.
        raise ValueError("Invalid data_format")