fix vector loss issue with Keras by instantiating a loss object

PiperOrigin-RevId: 239483918
Nicolas Papernot 2019-03-20 15:08:42 -07:00 committed by A. Unique TensorFlower
parent 0ebd134d99
commit 3c1e9994eb


@@ -11,37 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training a CNN on MNIST with Keras and the DP SGD optimizer.
-
-**************************** PLEASE READ ME ************************************
-
-A modification to Keras needed for this tutorial to work as it is currently
-written is *being* pushed. While this modification is in the works, you can
-make this tutorial work by making the following change to the TensorFlow source
-code (disabling the reduction of the loss used to compile a model):
-
-Diff for file: tensorflow/python/keras/engine/training_utils.py
-
-```
-+ from tensorflow.python.ops.losses import losses_impl
-  def get_loss_function():
-    ...
--    return losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
-+    return losses.LossFunctionWrapper(loss_fn,
-+                                      name=loss_fn.__name__,
-+                                      reduction=losses_impl.Reduction.NONE)
-```
-
-This allows the DP-SGD optimizer to have access to the loss defined per
-example rather than the mean of the loss for the entire minibatch. This is
-needed to compute gradients for each microbatch contained in a minibatch.
-
-**************************** END OF PLEASE READ ME *****************************
-"""
+"""Training a CNN on MNIST with Keras and the DP SGD optimizer."""
 from __future__ import absolute_import
 from __future__ import division
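The removed notice explained why the workaround existed: DP-SGD needs the loss defined per example, not the minibatch mean, so that it can clip and noise gradients for each microbatch. The shape difference is easy to see in isolation; the following sketch uses made-up labels and logits for illustration, and assumes a TensorFlow build where `tf.losses.Reduction` is available (as in the diff below):

```
import numpy as np
import tensorflow as tf

# Toy minibatch: 4 examples, 10 classes (values are made up for illustration).
labels = tf.one_hot([3, 1, 4, 1], depth=10)
logits = tf.constant(np.random.randn(4, 10), dtype=tf.float32)

# The default reduction collapses the minibatch loss to a single scalar mean.
mean_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
print(mean_loss(labels, logits).shape)  # () -- one scalar for the minibatch

# Reduction.NONE keeps one loss value per example; DP-SGD needs this vector
# to compute, clip, and noise gradients per microbatch.
vector_loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.losses.Reduction.NONE)
print(vector_loss(labels, logits).shape)  # (4,) -- one loss per example
```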
@@ -138,16 +108,15 @@ def main(unused_argv):
         FLAGS.microbatches,
         learning_rate=FLAGS.learning_rate,
         unroll_microbatches=True)
+    # Compute vector of per-example loss rather than its mean over a minibatch.
+    loss = tf.keras.losses.CategoricalCrossentropy(
+        from_logits=True, reduction=tf.losses.Reduction.NONE)
   else:
     optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
-
-  def keras_loss_fn(labels, logits):
-    """This removes the mandatory named arguments for this loss fn."""
-    return tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
-                                                      logits=logits)
+    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
 
   # Compile model with Keras
-  model.compile(optimizer=optimizer, loss=keras_loss_fn, metrics=['accuracy'])
+  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
 
   # Train model with Keras
   model.fit(train_data, train_labels,
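Put together, the new pattern compiles the model with an instantiated loss object whose reduction is disabled, so Keras hands the DP optimizer the per-example loss vector end to end. A minimal sketch follows; the `privacy.optimizers.dp_optimizer` import path, the stand-in model, and the hyperparameter values are assumptions in place of the tutorial's CNN and FLAGS:

```
import tensorflow as tf
# Assumed import path for the DP optimizer in the tensorflow/privacy
# repository of this era; adjust to match your checkout.
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer

# Stand-in model; the tutorial builds a small CNN instead.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10),
])

# Hypothetical hyperparameters standing in for the tutorial's FLAGS.
optimizer = DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=256,
    learning_rate=0.15,
    unroll_microbatches=True)

# The loss object is instantiated with reduction disabled, so the DP
# optimizer sees a vector of per-example losses rather than their mean.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.losses.Reduction.NONE)

model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
```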