diff --git a/tutorials/mnist_dpsgd_tutorial_keras.py b/tutorials/mnist_dpsgd_tutorial_keras.py
index 0638412..b3e0845 100644
--- a/tutorials/mnist_dpsgd_tutorial_keras.py
+++ b/tutorials/mnist_dpsgd_tutorial_keras.py
@@ -11,37 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training a CNN on MNIST with Keras and the DP SGD optimizer.
-
-**************************** PLEASE READ ME ************************************
-
-A modification to Keras needed for this tutorial to work as it is currently
-written is *being* pushed. While this modification is in the works, you can
-make this tutorial work by making the following change to the TensorFlow source
-code (disabling the reduction of the loss used to compile a model):
-
-Diff for file: tensorflow/python/keras/engine/training_utils.py
-
-```
-+ from tensorflow.python.ops.losses import losses_impl
-
-  def get_loss_function():
-
-    ...
-
--   return losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
-+   return losses.LossFunctionWrapper(loss_fn,
-+                                     name=loss_fn.__name__,
-+                                     reduction=losses_impl.Reduction.NONE)
-```
-
-This allows the DP-SGD optimizer to have access to the loss defined per
-example rather than the mean of the loss for the entire minibatch. This is
-needed to compute gradients for each microbatch contained in a minibatch.
-
-**************************** END OF PLEASE READ ME *****************************
-
-"""
+"""Training a CNN on MNIST with Keras and the DP SGD optimizer."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -138,16 +108,15 @@ def main(unused_argv):
         FLAGS.microbatches,
         learning_rate=FLAGS.learning_rate,
         unroll_microbatches=True)
+    # Compute vector of per-example loss rather than its mean over a minibatch.
+    loss = tf.keras.losses.CategoricalCrossentropy(
+        from_logits=True, reduction=tf.losses.Reduction.NONE)
   else:
     optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
-
-  def keras_loss_fn(labels, logits):
-    """This removes the mandatory named arguments for this loss fn."""
-    return tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
-                                                      logits=logits)
+    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
 
   # Compile model with Keras
-  model.compile(optimizer=optimizer, loss=keras_loss_fn, metrics=['accuracy'])
+  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
 
   # Train model with Keras
   model.fit(train_data, train_labels,
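
For context on the change above: with `reduction=tf.losses.Reduction.NONE`, the compiled loss yields one value per example instead of the minibatch mean, which is what the DP optimizer needs to compute clipped, noised gradients per microbatch. Below is a minimal standalone sketch of that difference, not part of the diff; the label/logit values are made up for illustration, and it assumes a TensorFlow version whose Keras losses accept the `reduction` argument, as the diff itself does.

```
import tensorflow as tf

# Per-example loss: reduction=NONE keeps one loss value per example, which a
# DP-SGD optimizer needs to form per-microbatch gradients.
vector_loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.losses.Reduction.NONE)

# Default (mean) reduction, as used for the non-private baseline.
scalar_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Illustrative one-hot labels and logits for a batch of two examples.
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
logits = tf.constant([[2.0, 0.0], [0.5, 1.5]])

print(vector_loss(labels, logits))  # shape [2]: one loss per example
print(scalar_loss(labels, logits))  # scalar: mean loss over the minibatch
```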