fix vector loss issue with Keras by instantiating a loss object
PiperOrigin-RevId: 239483918
parent 0ebd134d99
commit 3c1e9994eb
1 changed file with 6 additions and 37 deletions
@@ -11,37 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training a CNN on MNIST with Keras and the DP SGD optimizer.
-
-**************************** PLEASE READ ME ************************************
-
-A modification to Keras needed for this tutorial to work as it is currently
-written is *being* pushed. While this modification is in the works, you can
-make this tutorial work by making the following change to the TensorFlow source
-code (disabling the reduction of the loss used to compile a model):
-
-Diff for file: tensorflow/python/keras/engine/training_utils.py
-
-```
-+ from tensorflow.python.ops.losses import losses_impl
-
-  def get_loss_function():
-
-  ...
-
--   return losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
-+   return losses.LossFunctionWrapper(loss_fn,
-+                                     name=loss_fn.__name__,
-+                                     reduction=losses_impl.Reduction.NONE)
-```
-
-This allows the DP-SGD optimizer to have access to the loss defined per
-example rather than the mean of the loss for the entire minibatch. This is
-needed to compute gradients for each microbatch contained in a minibatch.
-
-**************************** END OF PLEASE READ ME *****************************
-
-"""
+"""Training a CNN on MNIST with Keras and the DP SGD optimizer."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -138,16 +108,15 @@ def main(unused_argv):
         FLAGS.microbatches,
         learning_rate=FLAGS.learning_rate,
         unroll_microbatches=True)
+    # Compute vector of per-example loss rather than its mean over a minibatch.
+    loss = tf.keras.losses.CategoricalCrossentropy(
+        from_logits=True, reduction=tf.losses.Reduction.NONE)
   else:
     optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
-
-  def keras_loss_fn(labels, logits):
-    """This removes the mandatory named arguments for this loss fn."""
-    return tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
-                                                      logits=logits)
+    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
 
   # Compile model with Keras
-  model.compile(optimizer=optimizer, loss=keras_loss_fn, metrics=['accuracy'])
+  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
 
   # Train model with Keras
   model.fit(train_data, train_labels,
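For context, the sketch below shows the pattern this commit introduces, outside the diff. It is a minimal illustration using TF 1.x-era APIs, assuming the tutorial's DPGradientDescentGaussianOptimizer; the import path for the privacy package has moved between releases, and the model and hyperparameter values here are stand-ins, not part of the commit.

```python
# Minimal sketch of the pattern this commit introduces (TF 1.x-era APIs).
# The import path below reflects the repo at the time of this commit and
# may differ in later releases; the model and flag values are illustrative.
import tensorflow as tf
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10),  # logits for the 10 MNIST classes
])

optimizer = DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=250,
    learning_rate=0.15,
    unroll_microbatches=True)

# reduction=NONE makes Keras hand the optimizer a vector with one loss
# value per example instead of the minibatch mean, so the optimizer can
# split the batch into microbatches and clip each gradient separately.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.losses.Reduction.NONE)

model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
```

With the default reduction, compile() would give the optimizer only the scalar minibatch mean, which is exactly what the removed workaround patched Keras to avoid; instantiating the loss object with Reduction.NONE achieves the same effect through the public API.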