Update per-class descriptions for DP Keras Model classes.
PiperOrigin-RevId: 367515250
This commit is contained in:
parent
121982deb1
commit
3c64cce796
1 changed files with 4 additions and 2 deletions
|
@ -36,7 +36,7 @@ def make_dp_model_class(cls):
|
||||||
gradients).
|
gradients).
|
||||||
noise_multiplier: Ratio of the standard deviation to the clipping
|
noise_multiplier: Ratio of the standard deviation to the clipping
|
||||||
norm.
|
norm.
|
||||||
use_xla: If True, compiles train_step to XLA.
|
use_xla: If `True`, compiles train_step to XLA.
|
||||||
"""
|
"""
|
||||||
super(DPModelClass, self).__init__(*args, **kwargs)
|
super(DPModelClass, self).__init__(*args, **kwargs)
|
||||||
self._l2_norm_clip = l2_norm_clip
|
self._l2_norm_clip = l2_norm_clip
|
||||||
|
@ -87,6 +87,8 @@ def make_dp_model_class(cls):
|
||||||
self.compiled_metrics.update_state(y, y_pred)
|
self.compiled_metrics.update_state(y, y_pred)
|
||||||
return {m.name: m.result() for m in self.metrics}
|
return {m.name: m.result() for m in self.metrics}
|
||||||
|
|
||||||
|
DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
|
||||||
|
|
||||||
return DPModelClass
|
return DPModelClass
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue