Update docstrings for all DP optimizer classes.

PiperOrigin-RevId: 382811363
This commit is contained in:
Steve Chien 2021-07-02 14:18:24 -07:00 committed by A. Unique TensorFlower
parent c192a4166b
commit 45c935832a
6 changed files with 242 additions and 31 deletions

View file

@ -51,11 +51,13 @@ def _hide_layer_and_module_methods():
"""Hide methods and properties defined in the base classes of keras layers.""" """Hide methods and properties defined in the base classes of keras layers."""
# __dict__ only sees attributes defined in *this* class, not on parent classes # __dict__ only sees attributes defined in *this* class, not on parent classes
# Needed to ignore redudant subclass documentation # Needed to ignore redudant subclass documentation
model_contents = list(tf.keras.Model.__dict__.items())
layer_contents = list(tf.keras.layers.Layer.__dict__.items()) layer_contents = list(tf.keras.layers.Layer.__dict__.items())
model_contents = list(tf.keras.Model.__dict__.items())
module_contents = list(tf.Module.__dict__.items()) module_contents = list(tf.Module.__dict__.items())
optimizer_contents = list(tf.compat.v1.train.Optimizer.__dict__.items())
for name, obj in model_contents + layer_contents + module_contents + optimizer_contents:
for name, obj in model_contents + layer_contents + module_contents:
if name == '__init__': if name == '__init__':
continue continue

View file

@ -80,6 +80,10 @@ else:
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import make_vectorized_keras_optimizer_class from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import make_vectorized_keras_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD

View file

@ -49,7 +49,50 @@ def make_optimizer_class(cls):
cls.__name__) cls.__name__)
class DPOptimizerClass(cls): # pylint: disable=empty-docstring class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = ('DP subclass of `tf.compat.v1.train.{}`.').format(cls.__name__) __doc__ = ("""Differentially private subclass of `{base_class}`.
You can use this as a differentially private replacement for
`{base_class}`. Note that you must ensure
that any loss processed by this optimizer comes in vector
form.
This is the fully general form of the optimizer that allows you
to define your own privacy mechanism. If you are planning to use
the standard Gaussian mechanism, it is simpler to use the more
specific `{gaussian_class}` class instead.
When instantiating this optimizer, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_class}`.
Examples:
```python
# Create GaussianSumQuery.
dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.5)
# Create optimizer.
opt = {dp_class}(dp_sum_query, 1, False, <standard arguments>)
```
When using the optimizer, be sure to pass in the loss as a
rank-one tensor with one entry for each example.
```python
# Compute loss as a tensor. Do not call tf.reduce_mean as you
# would with a standard optimizer.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
train_op = opt.minimize(loss, global_step=global_step)
```
""").format(
base_class='tf.compat.v1.train.' + cls.__name__,
gaussian_class='DP' +
cls.__name__.replace('Optimizer', 'GaussianOptimizer'),
short_base_class=cls.__name__,
dp_class='DP' + cls.__name__)
def __init__( def __init__(
self, self,
@ -233,9 +276,41 @@ def make_gaussian_optimizer_class(cls):
""" """
class DPGaussianOptimizerClass(make_optimizer_class(cls)): # pylint: disable=empty-docstring class DPGaussianOptimizerClass(make_optimizer_class(cls)): # pylint: disable=empty-docstring
__doc__ = ( __doc__ = ("""DP subclass of `{}`.
'DP subclass of `tf.compat.v1.train.{}` using Gaussian averaging.'
).format(cls.__name__) You can use this as a differentially private replacement for
`tf.compat.v1.train.{}`. This optimizer implements DP-SGD using
the standard Gaussian mechanism.
When instantiating this optimizer, you need to supply several
DP-related arguments followed by the standard arguments for
`{}`.
Examples:
```python
# Create optimizer.
opt = {}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
<standard arguments>)
```
When using the optimizer, be sure to pass in the loss as a
rank-one tensor with one entry for each example.
```python
# Compute loss as a tensor. Do not call tf.reduce_mean as you
# would with a standard optimizer.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
train_op = opt.minimize(loss, global_step=global_step)
```
""").format(
'tf.compat.v1.train.' + cls.__name__,
cls.__name__,
cls.__name__,
'DP' + cls.__name__.replace('Optimizer', 'GaussianOptimizer'))
def __init__( def __init__(
self, self,

View file

@ -35,17 +35,64 @@ def make_keras_optimizer_class(cls):
""" """
class DPOptimizerClass(cls): # pylint: disable=empty-docstring class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = """Differentially private subclass of given class `tf.keras.optimizers.{}. __doc__ = """Differentially private subclass of class `{base_class}`.
The class tf.keras.optimizers.Optimizer has two methods to compute You can use this as a differentially private replacement for
gradients, `_compute_gradients` and `get_gradients`. The first works `{base_class}`. This optimizer implements DP-SGD using
with eager execution, while the second runs in graph mode and is used the standard Gaussian mechanism.
by canned estimators.
Internally, DPOptimizerClass stores hyperparameters both individually When instantiating this optimizer, you need to supply several
and encapsulated in a `GaussianSumQuery` object for these two use cases. DP-related arguments followed by the standard arguments for
However, this should be invisible to users of this class. `{short_base_class}`.
""".format(cls.__name__)
Examples:
```python
# Create optimizer.
opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
<standard arguments>)
```
When using the optimizer, be sure to pass in the loss as a
rank-one tensor with one entry for each example.
The optimizer can be used directly via its `minimize` method, or
through a Keras `Model`.
```python
# Compute loss as a tensor by using tf.losses.Reduction.NONE.
# Compute vector of per-example loss rather than its mean over a minibatch.
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Use optimizer in a Keras model.
opt.minimize(loss, var_list=[var])
```
```python
# Compute loss as a tensor by using tf.losses.Reduction.NONE.
# Compute vector of per-example loss rather than its mean over a minibatch.
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Use optimizer in a Keras model.
model = tf.keras.Sequential(...)
model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])
model.fit(...)
```
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
# The class tf.keras.optimizers.Optimizer has two methods to compute
# gradients, `_compute_gradients` and `get_gradients`. The first works
# with eager execution, while the second runs in graph mode and is used
# by canned estimators.
# Internally, DPOptimizerClass stores hyperparameters both individually
# and encapsulated in a `GaussianSumQuery` object for these two use cases.
# However, this should be invisible to users of this class.
def __init__( def __init__(
self, self,

View file

@ -48,17 +48,61 @@ def make_vectorized_keras_optimizer_class(cls):
""" """
class DPOptimizerClass(cls): # pylint: disable=empty-docstring class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = """Vectorized differentially private subclass of given class `tf.keras.optimizers.{}. __doc__ = """Vectorized differentially private subclass of given class
`{base_class}`.
The class tf.keras.optimizers.Optimizer has two methods to compute You can use this as a differentially private replacement for
gradients, `_compute_gradients` and `get_gradients`. The first works `{base_class}`. This optimizer implements DP-SGD using
with eager execution, while the second runs in graph mode and is used the standard Gaussian mechanism. It differs from `{dp_keras_class}` in that
by canned estimators. it attempts to vectorize the gradient computation and clipping of
microbatches.
Internally, DPOptimizerClass stores hyperparameters both individually When instantiating this optimizer, you need to supply several
and encapsulated in a `GaussianSumQuery` object for these two use cases. DP-related arguments followed by the standard arguments for
However, this should be invisible to users of this class. `{short_base_class}`.
""".format(cls.__name__)
Examples:
```python
# Create optimizer.
opt = {dp_vectorized_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
<standard arguments>)
```
When using the optimizer, be sure to pass in the loss as a
rank-one tensor with one entry for each example.
The optimizer can be used directly via its `minimize` method, or
through a Keras `Model`.
```python
# Compute loss as a tensor by using tf.losses.Reduction.NONE.
# Compute vector of per-example loss rather than its mean over a minibatch.
# (Side note: Always verify that the output shape when using
# tf.losses.Reduction.NONE-- it can sometimes be surprising.
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Use optimizer in a Keras model.
opt.minimize(loss, var_list=[var])
```
```python
# Compute loss as a tensor by using tf.losses.Reduction.NONE.
# Compute vector of per-example loss rather than its mean over a minibatch.
loss = tf.keras.losses.CategoricalCrossentropy(
from_logits=True, reduction=tf.losses.Reduction.NONE)
# Use optimizer in a Keras model.
model = tf.keras.Sequential(...)
model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])
model.fit(...)
```
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
dp_keras_class='DPKeras' + cls.__name__,
short_base_class=cls.__name__,
dp_vectorized_keras_class='VectorizedDPKeras' + cls.__name__)
def __init__( def __init__(
self, self,

View file

@ -47,10 +47,44 @@ def make_vectorized_optimizer_class(cls):
cls.__name__) cls.__name__)
class DPOptimizerClass(cls): # pylint: disable=empty-docstring class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = ( __doc__ = ("""Vectorized DP subclass of `{base_class}` using Gaussian
'Vectorized DP subclass of `tf.compat.v1.train.{}` using Gaussian ' averaging.
'averaging.'
).format(cls.__name__) You can use this as a differentially private replacement for
`{base_class}`. This optimizer implements DP-SGD using
the standard Gaussian mechanism. It differs from `{dp_class}` in that
it attempts to vectorize the gradient computation and clipping of
microbatches.
When instantiating this optimizer, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_class}`.
Examples:
```python
# Create optimizer.
opt = {dp_vectorized_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
<standard arguments>)
```
When using the optimizer, be sure to pass in the loss as a
rank-one tensor with one entry for each example.
```python
# Compute loss as a tensor. Do not call tf.reduce_mean as you
# would with a standard optimizer.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
train_op = opt.minimize(loss, global_step=global_step)
```
""").format(
base_class='tf.compat.v1.train.' + cls.__name__,
dp_class='DP' +
cls.__name__.replace('Optimizer', 'GaussianOptimizer'),
short_base_class=cls.__name__,
dp_vectorized_class='VectorizedDP' + cls.__name__)
def __init__( def __init__(
self, self,
@ -153,6 +187,11 @@ def make_vectorized_optimizer_class(cls):
return DPOptimizerClass return DPOptimizerClass
VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer) VectorizedDPAdagradOptimizer = make_vectorized_optimizer_class(AdagradOptimizer)
VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer) VectorizedDPAdamOptimizer = make_vectorized_optimizer_class(AdamOptimizer)
VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer) VectorizedDPSGDOptimizer = make_vectorized_optimizer_class(
GradientDescentOptimizer)
VectorizedDPAdagrad = VectorizedDPAdagradOptimizer
VectorizedDPAdam = VectorizedDPAdamOptimizer
VectorizedDPSGD = VectorizedDPSGDOptimizer