forked from 626_privacy/tensorflow_privacy
Update per-class descriptions for DP Keras Model classes.
PiperOrigin-RevId: 367515250
This commit is contained in:
parent
121982deb1
commit
3c64cce796
1 changed files with 4 additions and 2 deletions
|
@ -36,7 +36,7 @@ def make_dp_model_class(cls):
|
|||
gradients).
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping
|
||||
norm.
|
||||
use_xla: If True, compiles train_step to XLA.
|
||||
use_xla: If `True`, compiles train_step to XLA.
|
||||
"""
|
||||
super(DPModelClass, self).__init__(*args, **kwargs)
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
|
@ -87,6 +87,8 @@ def make_dp_model_class(cls):
|
|||
self.compiled_metrics.update_state(y, y_pred)
|
||||
return {m.name: m.result() for m in self.metrics}
|
||||
|
||||
DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
|
||||
|
||||
return DPModelClass
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue