# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import numpy as np
from six.moves import xrange
import tensorflow as tf
import time

from differential_privacy.multiple_teachers import utils

FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('dropout_seed', 123, """seed for dropout.""")
tf.app.flags.DEFINE_integer('batch_size', 128, """Nb of images in a batch.""")
tf.app.flags.DEFINE_integer('epochs_per_decay', 350, """Nb epochs per decay""")
tf.app.flags.DEFINE_integer('learning_rate', 5, """100 * learning rate""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """see TF doc""")


# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.


def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(stddev=stddev))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
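
# Usage note (illustrative, mirroring the fully connected layers below): a
# call such as
#   weights = _variable_with_weight_decay('weights', [dim, 384],
#                                         stddev=0.04, wd=0.004)
# returns the variable and also registers 0.004 * tf.nn.l2_loss(weights) in
# the 'losses' collection, which loss_fun() later sums into the total loss.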


def inference(images, dropout=False):
  """Build the CNN model.

  Args:
    images: Images returned from distorted_inputs() or inputs().
    dropout: Boolean controlling whether to use dropout or not

  Returns:
    Logits
  """
  if FLAGS.dataset == 'mnist':
    first_conv_shape = [5, 5, 1, 64]
  else:
    first_conv_shape = [5, 5, 3, 64]

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=first_conv_shape,
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv1 = tf.nn.dropout(conv1, 0.3, seed=FLAGS.dropout_seed)

  # pool1
  pool1 = tf.nn.max_pool(conv1,
                         ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1],
                         padding='SAME',
                         name='pool1')

  # norm1
  norm1 = tf.nn.lrn(pool1,
                    4,
                    bias=1.0,
                    alpha=0.001 / 9.0,
                    beta=0.75,
                    name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[5, 5, 64, 128],
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [128], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv2 = tf.nn.dropout(conv2, 0.3, seed=FLAGS.dropout_seed)

  # norm2
  norm2 = tf.nn.lrn(conv2,
                    4,
                    bias=1.0,
                    alpha=0.001 / 9.0,
                    beta=0.75,
                    name='norm2')

  # pool2
  pool2 = tf.nn.max_pool(norm2,
                         ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1],
                         padding='SAME',
                         name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay('weights',
                                          shape=[dim, 384],
                                          stddev=0.04,
                                          wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    if dropout:
      local3 = tf.nn.dropout(local3, 0.5, seed=FLAGS.dropout_seed)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay('weights',
                                          shape=[384, 192],
                                          stddev=0.04,
                                          wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    if dropout:
      local4 = tf.nn.dropout(local4, 0.5, seed=FLAGS.dropout_seed)

  # compute logits
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights',
                                          [192, FLAGS.nb_labels],
                                          stddev=1 / 192.0,
                                          wd=0.0)
    biases = _variable_on_cpu('biases',
                              [FLAGS.nb_labels],
                              tf.constant_initializer(0.0))
    logits = tf.add(tf.matmul(local4, weights), biases, name=scope.name)

  return logits
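
# Shape walk-through for `inference` (illustrative, assuming the default
# batch_size=128 and 28x28x1 MNIST inputs): conv1 keeps 28x28x64, pool1
# halves it to 14x14x64, conv2 gives 14x14x128 and pool2 gives 7x7x128, so
# local3 flattens each example to 7*7*128 = 6272 features before the two
# fully connected layers (384 then 192 units) and the final logits layer.
# With 32x32x3 inputs the flattened size is 8*8*128 = 8192 instead.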


def inference_deeper(images, dropout=False):
  """Build a deeper CNN model.

  Args:
    images: Images returned from distorted_inputs() or inputs().
    dropout: Boolean controlling whether to use dropout or not

  Returns:
    Logits
  """
  if FLAGS.dataset == 'mnist':
    first_conv_shape = [3, 3, 1, 96]
  else:
    first_conv_shape = [3, 3, 3, 96]

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=first_conv_shape,
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 96],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)

  # conv3
  with tf.variable_scope('conv3') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 96],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv2, kernel, [1, 2, 2, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv3 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv3 = tf.nn.dropout(conv3, 0.5, seed=FLAGS.dropout_seed)

  # conv4
  with tf.variable_scope('conv4') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv4 = tf.nn.relu(bias, name=scope.name)

  # conv5
  with tf.variable_scope('conv5') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 192, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv5 = tf.nn.relu(bias, name=scope.name)

  # conv6
  with tf.variable_scope('conv6') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 192, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv5, kernel, [1, 2, 2, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv6 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv6 = tf.nn.dropout(conv6, 0.5, seed=FLAGS.dropout_seed)

  # conv7
  with tf.variable_scope('conv7') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[5, 5, 192, 192],
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv6, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv7 = tf.nn.relu(bias, name=scope.name)

  # local1
  with tf.variable_scope('local1') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(conv7, [FLAGS.batch_size, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay('weights',
                                          shape=[dim, 192],
                                          stddev=0.05,
                                          wd=0)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local1 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

  # local2
  with tf.variable_scope('local2') as scope:
    weights = _variable_with_weight_decay('weights',
                                          shape=[192, 192],
                                          stddev=0.05,
                                          wd=0)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local2 = tf.nn.relu(tf.matmul(local1, weights) + biases, name=scope.name)
    if dropout:
      local2 = tf.nn.dropout(local2, 0.5, seed=FLAGS.dropout_seed)

  # compute logits
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights',
                                          [192, FLAGS.nb_labels],
                                          stddev=0.05,
                                          wd=0.0)
    biases = _variable_on_cpu('biases',
                              [FLAGS.nb_labels],
                              tf.constant_initializer(0.0))
    logits = tf.add(tf.matmul(local2, weights), biases, name=scope.name)

  return logits
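
# In `inference_deeper`, spatial downsampling comes from the stride-2
# convolutions (conv3 and conv6) rather than from pooling, so a 28x28
# MNIST input reaches conv7 at 7x7x192 and local1 flattens each example
# to 7*7*192 = 9408 features (8*8*192 = 12288 for 32x32 inputs).
# Illustrative shape note, assuming the SAME padding used above.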


def loss_fun(logits, labels):
  """Add L2Loss to all the trainable variables.

  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor of shape
        [batch_size]

  Returns:
    Loss tensor of type float.
  """

  # Calculate the cross entropy between labels and predictions
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=labels, name='cross_entropy_per_example')

  # Calculate the average cross entropy loss across the batch.
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

  # Add to TF collection for losses
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')


def moving_av(total_loss):
  """Generates moving averages for all losses.

  Args:
    total_loss: Total loss from loss_fun().

  Returns:
    loss_averages_op: op for generating moving averages of losses.
  """
  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  return loss_averages_op


def train_op_fun(total_loss, global_step):
  """Train model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss_fun().
    global_step: Integer Variable counting the number of training steps
        processed.

  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  nb_ex_per_train_epoch = int(60000 / FLAGS.nb_teachers)

  num_batches_per_epoch = nb_ex_per_train_epoch / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * FLAGS.epochs_per_decay)

  initial_learning_rate = float(FLAGS.learning_rate) / 100.0

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(initial_learning_rate,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = moving_av(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op
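
# Learning-rate schedule, worked out for the default flags (illustrative;
# `nb_teachers` is defined by the calling scripts, here taken to be 10):
#   initial_learning_rate = 5 / 100.0 = 0.05
#   nb_ex_per_train_epoch = 60000 / 10 = 6000
#   num_batches_per_epoch = 6000 / 128 ~= 46.9
#   decay_steps = int(46.9 * 350) = 16406
# so the learning rate is multiplied by LEARNING_RATE_DECAY_FACTOR (0.1)
# roughly every 16k steps.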


def _input_placeholder():
  """This helper function declares a TF placeholder for the graph input data.

  :return: TF placeholder for the graph input data
  """
  if FLAGS.dataset == 'mnist':
    image_size = 28
    num_channels = 1
  else:
    image_size = 32
    num_channels = 3

  # Declare data placeholder
  train_node_shape = (FLAGS.batch_size, image_size, image_size, num_channels)
  return tf.placeholder(tf.float32, shape=train_node_shape)


def train(images, labels, ckpt_path, dropout=False):
  """This function contains the loop that actually trains the model.

  :param images: a numpy array with the input data
  :param labels: a numpy array with the output labels
  :param ckpt_path: a path (including name) where model checkpoints are saved
  :param dropout: Boolean, whether to use dropout or not
  :return: True if everything went well
  """

  # Check training data
  assert len(images) == len(labels)
  assert images.dtype == np.float32
  assert labels.dtype == np.int32

  # Set default TF graph
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Create a placeholder to hold labels
    train_labels_shape = (FLAGS.batch_size,)
    train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)

    print("Done Initializing Training Placeholders")

    # Build a Graph that computes the logits predictions from the placeholder
    if FLAGS.deeper:
      logits = inference_deeper(train_data_node, dropout=dropout)
    else:
      logits = inference(train_data_node, dropout=dropout)

    # Calculate loss
    loss = loss_fun(logits, train_labels_node)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = train_op_fun(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    print("Graph constructed and saver created")

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Create and init sessions
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))  #NOLINT(long-line)
    sess.run(init)

    print("Session ready, beginning training loop")

    # Initialize the number of batches
    data_length = len(images)
    nb_batches = math.ceil(data_length / FLAGS.batch_size)

    for step in xrange(FLAGS.max_steps):
      # for debug, save start time
      start_time = time.time()

      # Current batch number
      batch_nb = step % nb_batches

      # Current batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare dictionary to feed the session with
      feed_dict = {train_data_node: images[start:end],
                   train_labels_node: labels[start:end]}

      # Run training step
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      # Compute duration of training step
      duration = time.time() - start_time

      # Sanity check
      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Echo loss once in a while
      if step % 100 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, ckpt_path, global_step=step)

  return True
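
# Note on checkpoint paths: because `saver.save` is called with
# `global_step=step`, the files written to disk are named
# `<ckpt_path>-<step>` (e.g. `/tmp/model.ckpt-0`, `/tmp/model.ckpt-2999`).
# `softmax_preds` below must therefore be given one of those suffixed
# paths, not the bare `ckpt_path` passed to `train`.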


def softmax_preds(images, ckpt_path, return_logits=False):
  """Compute softmax activations (probabilities) with the model saved in the
  path specified as an argument.

  :param images: a np array of images
  :param ckpt_path: a TF model checkpoint
  :param return_logits: if set to True, return logits instead of probabilities
  :return: probabilities (or logits if return_logits is set to True)
  """
  # Compute nb samples and deduce nb of batches
  data_length = len(images)
  nb_batches = math.ceil(len(images) / FLAGS.batch_size)

  # Declare data placeholder
  train_data_node = _input_placeholder()

  # Build a Graph that computes the logits predictions from the placeholder
  if FLAGS.deeper:
    logits = inference_deeper(train_data_node)
  else:
    logits = inference(train_data_node)

  if return_logits:
    # We are returning the logits directly (no need to apply softmax)
    output = logits
  else:
    # Add softmax predictions to graph: will return probabilities
    output = tf.nn.softmax(logits)

  # Restore the moving average version of the learned variables for eval.
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
  variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  # Will hold the result
  preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)

  # Create TF session
  with tf.Session() as sess:
    # Restore TF session from checkpoint file
    saver.restore(sess, ckpt_path)

    # Parse data by batch
    for batch_nb in xrange(0, int(nb_batches + 1)):
      # Compute batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare feed dictionary
      feed_dict = {train_data_node: images[start:end]}

      # Run session ([0] because run returns a batch with len 1st dim == 1)
      preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]

  # Reset graph to allow multiple calls
  tf.reset_default_graph()

  return preds
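

# Minimal smoke-test sketch, assuming the TF 1.x API this module targets.
# The flags declared inside the guard (`dataset`, `nb_labels`, `nb_teachers`,
# `deeper`, `max_steps`) are normally defined by the companion training
# scripts; they are declared here only so the sketch is self-contained, and
# the data is random stand-in data rather than a real dataset.
if __name__ == '__main__':
  tf.app.flags.DEFINE_string('dataset', 'mnist', 'Dataset name')
  tf.app.flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')
  tf.app.flags.DEFINE_integer('nb_teachers', 1, 'Number of teacher models')
  tf.app.flags.DEFINE_boolean('deeper', False, 'Use the deeper CNN')
  tf.app.flags.DEFINE_integer('max_steps', 10, 'Number of training steps')

  # Random stand-in data: 256 MNIST-shaped images and labels.
  fake_images = np.random.rand(256, 28, 28, 1).astype(np.float32)
  fake_labels = np.random.randint(0, 10, size=256).astype(np.int32)

  # Train for a few steps and write checkpoints under /tmp.
  ckpt = '/tmp/deep_cnn_example.ckpt'
  assert train(fake_images, fake_labels, ckpt)

  # The last checkpoint is suffixed with the step at which it was saved.
  last_ckpt = ckpt + '-' + str(FLAGS.max_steps - 1)
  probs = softmax_preds(fake_images, last_ckpt)
  print('Predictions shape:', probs.shape)  # (256, 10)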