Update per-class descriptions for DP Keras Model classes.

PiperOrigin-RevId: 367515250
This commit is contained in:
Steve Chien 2021-04-08 15:06:12 -07:00 committed by A. Unique TensorFlower
parent 121982deb1
commit 3c64cce796

View file

@ -20,7 +20,7 @@ def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
class DPModelClass(cls):
"""A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""
"""A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""
def __init__(
self,
@ -36,7 +36,7 @@ def make_dp_model_class(cls):
gradients).
noise_multiplier: Ratio of the standard deviation to the clipping
norm.
use_xla: If True, compiles train_step to XLA.
use_xla: If `True`, compiles train_step to XLA.
"""
super(DPModelClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
@ -87,6 +87,8 @@ def make_dp_model_class(cls):
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
return DPModelClass