diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index d644203..7e7c6f4 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -20,7 +20,7 @@ def make_dp_model_class(cls): """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it.""" class DPModelClass(cls): - """A DP version of `cls`, which should be a subclass of `tf.keras.Model`.""" + """A DP version of `cls`, which should be a subclass of `tf.keras.Model`.""" def __init__( self, @@ -36,7 +36,7 @@ def make_dp_model_class(cls): gradients). noise_multiplier: Ratio of the standard deviation to the clipping norm. - use_xla: If True, compiles train_step to XLA. + use_xla: If `True`, compiles train_step to XLA. """ super(DPModelClass, self).__init__(*args, **kwargs) self._l2_norm_clip = l2_norm_clip @@ -87,6 +87,8 @@ def make_dp_model_class(cls): self.compiled_metrics.update_state(y, y_pred) return {m.name: m.result() for m in self.metrics} + DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__) + return DPModelClass