forked from 626_privacy/tensorflow_privacy
Update docstring for DPModel class.
PiperOrigin-RevId: 382855055
This commit is contained in:
parent
45c935832a
commit
beed219d20
1 changed files with 36 additions and 1 deletions
|
@ -20,7 +20,42 @@ def make_dp_model_class(cls):
|
||||||
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
|
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
|
||||||
|
|
||||||
class DPModelClass(cls): # pylint: disable=empty-docstring
|
class DPModelClass(cls): # pylint: disable=empty-docstring
|
||||||
__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
|
__doc__ = ("""DP subclass of `{base_model}`.
|
||||||
|
|
||||||
|
This can be used as a differentially private replacement for
|
||||||
|
{base_model}. This class implements DP-SGD using the standard
|
||||||
|
Gaussian mechanism.
|
||||||
|
|
||||||
|
When instantiating this class, you need to supply several
|
||||||
|
DP-related arguments followed by the standard arguments for
|
||||||
|
`{short_base_model}`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create Model instance.
|
||||||
|
model = {dp_model_class}(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True,
|
||||||
|
<standard arguments>)
|
||||||
|
```
|
||||||
|
|
||||||
|
You should use your {dp_model_class} instance with a standard instance
|
||||||
|
of `tf.keras.Optimizer` as the optimizer, and a standard reduced loss.
|
||||||
|
You do not need to use a differentially private optimizer.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Use a standard (non-DP) optimizer.
|
||||||
|
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
|
||||||
|
|
||||||
|
# Use a standard reduced loss.
|
||||||
|
loss = tf.keras.losses.MeanSquaredError()
|
||||||
|
|
||||||
|
model.compile(optimizer=optimizer, loss=loss)
|
||||||
|
model.fit(train_data, train_labels, epochs=1, batch_size=32)
|
||||||
|
```
|
||||||
|
|
||||||
|
""").format(base_model='tf.keras.' + cls.__name__,
|
||||||
|
short_base_model=cls.__name__,
|
||||||
|
dp_model_class='DP' + cls.__name__)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in a new issue