More docstring updates in preparation for api docs generation.

PiperOrigin-RevId: 368667796
Author: Steve Chien, 2021-04-15 10:30:30 -07:00 (committed by A. Unique TensorFlower)
parent ca347b8995
commit 41530f4426
10 changed files with 136 additions and 73 deletions


@ -60,16 +60,16 @@ class PrivacyLedger(object):
def __init__(self,
population_size,
selection_probability):
"""Initialize the PrivacyLedger.
"""Initializes the PrivacyLedger.
Args:
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch.
- selection_probability: A float (may be variable) specifying the
-   probability each record is included in a sample.
+ selection_probability: A floating point value (may be variable) specifying
+   the probability each record is included in a sample.
Raises:
- ValueError: If selection_probability is 0.
+ ValueError: If `selection_probability` is 0.
"""
self._population_size = population_size
self._selection_probability = selection_probability
@ -141,7 +141,7 @@ class PrivacyLedger(object):
sess: The tensorflow session in which the ledger was created.
Returns:
- The query ledger as a list of SampleEntries.
+ The query ledger as a list of `SampleEntry` instances.
"""
sample_array = sess.run(self._sample_buffer.values)
query_array = sess.run(self._query_buffer.values)
@ -152,7 +152,7 @@ class PrivacyLedger(object):
"""Gets the formatted query ledger.
Returns:
- The query ledger as a list of SampleEntries.
+ The query ledger as a list of `SampleEntry` instances.
"""
sample_array = self._sample_buffer.values.numpy()
query_array = self._query_buffer.values.numpy()
@ -161,21 +161,21 @@ class PrivacyLedger(object):
class QueryWithLedger(dp_query.DPQuery):
"""A class for DP queries that record events to a PrivacyLedger.
"""A class for DP queries that record events to a `PrivacyLedger`.
QueryWithLedger should be the top-level query in a structure of queries that
`QueryWithLedger` should be the top-level query in a structure of queries that
may include sum queries, nested queries, etc. It should simply wrap another
query and contain a reference to the ledger. Any contained queries (including
those contained in the leaves of a nested query) should also contain a
reference to the same ledger object.
- For example usage, see privacy_ledger_test.py.
+ For example usage, see `privacy_ledger_test.py`.
"""
def __init__(self, query,
population_size=None, selection_probability=None,
ledger=None):
"""Initializes the QueryWithLedger.
"""Initializes the `QueryWithLedger`.
Args:
query: The query whose events should be recorded to the ledger. Any
@ -183,12 +183,12 @@ class QueryWithLedger(dp_query.DPQuery):
contain a reference to the same ledger given here.
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch. May be
-   None if `ledger` is specified.
- selection_probability: A float (may be variable) specifying the
-   probability each record is included in a sample. May be None if `ledger`
-   is specified.
- ledger: A PrivacyLedger to use. Must be specified if either of
-   `population_size` or `selection_probability` is None.
+   `None` if `ledger` is specified.
+ selection_probability: A floating point value (may be variable) specifying
+   the probability each record is included in a sample. May be `None` if
+   `ledger` is specified.
+ ledger: A `PrivacyLedger` to use. Must be specified if either of
+   `population_size` or `selection_probability` is `None`.
"""
self._query = query
if population_size is not None and selection_probability is not None:
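
Taken together, the intended wiring is: build a `GaussianSumQuery`, wrap it in a `QueryWithLedger`, and pass the wrapped query to a DP optimizer as its `dp_sum_query`. A minimal sketch, not part of this change; the `privacy_ledger` import path, the `GaussianSumQuery(l2_norm_clip, stddev)` constructor, and all numeric values are assumptions made for illustration:

from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query

# Sum query that clips each record to L2 norm 1.0 and adds Gaussian noise
# with stddev = l2_norm_clip * noise_multiplier.
sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0 * 1.1)

# Wrap the query so every sample and query event is recorded to a
# PrivacyLedger; the wrapper builds the ledger from the two values below.
query = privacy_ledger.QueryWithLedger(
    sum_query,
    population_size=60000,               # training-set size
    selection_probability=256 / 60000)   # batch size / population size

# `query` is then handed to a DP optimizer (see the optimizer changes later
# in this commit) in place of a bare GaussianSumQuery.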


@ -293,7 +293,7 @@ def _compute_rdp(q, sigma, alpha):
def compute_rdp(q, noise_multiplier, steps, orders):
"""Compute RDP of the Sampled Gaussian Mechanism.
"""Computes RDP of the Sampled Gaussian Mechanism.
Args:
q: The sampling rate.
@ -303,7 +303,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
orders: An array (or a scalar) of RDP orders.
Returns:
- The RDPs at all orders, can be np.inf.
+ The RDPs at all orders. Can be `np.inf`.
"""
if np.isscalar(orders):
rdp = _compute_rdp(q, noise_multiplier, orders)
@ -316,7 +316,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
steps_list, orders):
"""Compute RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
"""Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
Args:
sampling_probabilities: A list containing the sampling rates.
@ -328,7 +328,7 @@ def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
orders: An array (or a scalar) of RDP orders.
Returns:
- The RDPs at all orders, can be np.inf.
+ The RDPs at all orders. Can be `np.inf`.
"""
assert len(sampling_probabilities) == len(noise_multipliers)
@ -341,18 +341,19 @@ def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
"""Compute delta (or eps) for given eps (or delta) from RDP values.
"""Computes delta (or eps) for given eps (or delta) from RDP values.
Args:
orders: An array (or a scalar) of RDP orders.
rdp: An array of RDP values. Must be of the same length as the orders list.
- target_eps: If not None, the epsilon for which we compute the corresponding
-   delta.
- target_delta: If not None, the delta for which we compute the corresponding
-   epsilon. Exactly one of target_eps and target_delta must be None.
+ target_eps: If not `None`, the epsilon for which we compute the
+   corresponding delta.
+ target_delta: If not `None`, the delta for which we compute the
+   corresponding epsilon. Exactly one of `target_eps` and `target_delta`
+   must be `None`.
Returns:
- eps, delta, opt_order.
+ A tuple of epsilon, delta, and the optimal order.
Raises:
ValueError: If target_eps and target_delta are messed up.
@ -374,14 +375,14 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
def compute_rdp_from_ledger(ledger, orders):
"""Compute RDP of Sampled Gaussian Mechanism from ledger.
"""Computes RDP of Sampled Gaussian Mechanism from ledger.
Args:
ledger: A formatted privacy ledger.
orders: An array (or a scalar) of RDP orders.
Returns:
- RDP at all orders, can be np.inf.
+ RDP at all orders. Can be `np.inf`.
"""
total_rdp = np.zeros_like(orders, dtype=float)
for sample in ledger:
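
For orientation, the functions above are normally used together: `compute_rdp` (or `compute_rdp_from_ledger`) accumulates RDP over training, and `get_privacy_spent` converts it into an (epsilon, delta) guarantee. A short sketch, not part of this change; the DP-SGD hyperparameters are illustrative only:

from tensorflow_privacy.privacy.analysis import rdp_accountant

q = 256 / 60000          # sampling rate: batch size / training-set size
noise_multiplier = 1.1   # noise stddev / clipping norm
steps = 10000            # number of training steps
orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))

# RDP of the Sampled Gaussian Mechanism at every order, accumulated over steps.
rdp = rdp_accountant.compute_rdp(q, noise_multiplier, steps, orders)

# Convert the RDP curve to (epsilon, delta) at a fixed delta.
eps, delta, opt_order = rdp_accountant.get_privacy_spent(
    orders, rdp, target_delta=1e-5)
print('epsilon = %.2f at delta = %g (optimal order %g)' % (eps, delta, opt_order))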


@ -159,7 +159,7 @@ class DPQuery(object):
Returns:
The updated sample state. In standard DP-SGD training, the set of
- previous mcrobatch gradients with the addition of the record argument.
+ previous microbatch gradients with the addition of the record argument.
"""
preprocessed_record = self.preprocess_record(params, record)
return self.accumulate_preprocessed_record(


@ -43,6 +43,7 @@ class DNNClassifier(tf.estimator.Estimator):
loss_reduction=tf.keras.losses.Reduction.NONE,
batch_norm=False,
):
"""See `tf.estimator.DNNClassifier`."""
head = head_utils.binary_or_multi_class_head(
n_classes,
weight_column=weight_column,


@ -41,6 +41,7 @@ class DNNClassifier(tf.estimator.Estimator):
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
batch_norm=False,
):
"""See `tf.compat.v1.estimator.DNNClassifier`."""
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction)
estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')


@ -19,8 +19,8 @@ import tensorflow as tf
def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
- class DPModelClass(cls):
-   """A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""
+ class DPModelClass(cls): # pylint: disable=empty-docstring
+   __doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
def __init__(
self,
@ -37,6 +37,9 @@ def make_dp_model_class(cls):
noise_multiplier: Ratio of the standard deviation to the clipping
norm.
use_xla: If `True`, compiles train_step to XLA.
+ *args: These will be passed on to the base class `__init__` method.
+ **kwargs: These will be passed on to the base class `__init__`
+   method.
"""
super(DPModelClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
@ -78,6 +81,7 @@ def make_dp_model_class(cls):
return tf.squeeze(y_pred, axis=0), loss, clipped_grads
def train_step(self, data):
"""DP-SGD version of base class method."""
_, y = data
y_pred, _, per_eg_grads = tf.vectorized_map(
self._compute_per_example_grads, data)
@ -87,8 +91,6 @@ def make_dp_model_class(cls):
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
- DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
return DPModelClass
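
For orientation, the factory above is applied to a concrete Keras model class and the result is compiled and fit as usual. A minimal sketch, not part of this change; it assumes the `keras_models.dp_keras_model` module path, that `l2_norm_clip` is the clipping argument preceding `noise_multiplier`, and illustrative values throughout:

import tensorflow as tf
from tensorflow_privacy.privacy.keras_models import dp_keras_model

DPSequential = dp_keras_model.make_dp_model_class(tf.keras.Sequential)

model = DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layers=[
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10),
    ])

# train_step clips and noises per-example gradients itself, so the loss is
# typically left unreduced (one loss value per example).
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE))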


@ -26,7 +26,15 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
def make_optimizer_class(cls):
"""Constructs a DP optimizer class from an existing one."""
"""Given a subclass of `tf.compat.v1.train.Optimizer`, returns a DP-SGD subclass of it.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.compat.v1.train.Optimizer`.
Returns:
A DP-SGD subclass of `cls`.
"""
parent_code = tf.train.Optimizer.compute_gradients.__code__
has_compute_gradients = hasattr(cls, 'compute_gradients')
@ -40,8 +48,8 @@ def make_optimizer_class(cls):
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
- class DPOptimizerClass(cls):
-   """Differentially private subclass of given class cls."""
+ class DPOptimizerClass(cls): # pylint: disable=empty-docstring
+   __doc__ = ('DP subclass of `tf.compat.v1.train.{}`.').format(cls.__name__)
def __init__(
self,
@ -50,7 +58,7 @@ def make_optimizer_class(cls):
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initialize the DPOptimizerClass.
"""Initializes the DPOptimizerClass.
Args:
dp_sum_query: `DPQuery` object, specifying differential privacy
@ -61,6 +69,8 @@ def make_optimizer_class(cls):
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception.
+ *args: These will be passed on to the base class `__init__` method.
+ **kwargs: These will be passed on to the base class `__init__` method.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._dp_sum_query = dp_sum_query
@ -80,6 +90,7 @@ def make_optimizer_class(cls):
colocate_gradients_with_ops=False,
grad_loss=None,
gradient_tape=None):
"""DP-SGD version of base class method."""
self._was_compute_gradients_called = True
if self._global_state is None:
self._global_state = self._dp_sum_query.initial_global_state()
@ -124,7 +135,7 @@ def make_optimizer_class(cls):
return grads_and_vars
else:
- # TF is running in graph mode, check we did not receive a gradient tape.
+ # TF is running in graph mode. Check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
@ -197,6 +208,8 @@ def make_optimizer_class(cls):
return list(zip(final_grads, var_list))
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ # pylint: disable=g-doc-args, g-doc-return-or-yield
+ """DP-SGD version of base class method."""
assert self._was_compute_gradients_called, (
'compute_gradients() on the differentially private optimizer was not'
' called. Which means that the training is not differentially '
@ -205,17 +218,24 @@ def make_optimizer_class(cls):
return super(DPOptimizerClass,
self).apply_gradients(grads_and_vars, global_step, name)
- DPOptimizerClass.__doc__ = ('DP subclass of `tf.compat.v1.train.{}`.').format(
-     cls.__name__)
return DPOptimizerClass
def make_gaussian_optimizer_class(cls):
"""Constructs a DP optimizer with Gaussian averaging of updates."""
"""Given a subclass of `tf.compat.v1.train.Optimizer`, returns a subclass using DP-SGD with Gaussian averaging.
class DPGaussianOptimizerClass(make_optimizer_class(cls)):
"""DP subclass of given class cls using Gaussian averaging."""
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.compat.v1.train.Optimizer`.
Returns:
A subclass of `cls` using DP-SGD with Gaussian averaging.
"""
class DPGaussianOptimizerClass(make_optimizer_class(cls)): # pylint: disable=empty-docstring
__doc__ = (
'DP subclass of `tf.compat.v1.train.{}` using Gaussian averaging.'
).format(cls.__name__)
def __init__(
self,
@ -226,6 +246,21 @@ def make_gaussian_optimizer_class(cls):
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg
**kwargs):
"""Initializes the `DPGaussianOptimizerClass`.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. If `None`, will default to the size of the minibatch, and
per-example gradients will be computed.
ledger: Defaults to `None`. An instance of `tf_privacy.PrivacyLedger`.
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
@ -268,10 +303,6 @@ def make_gaussian_optimizer_class(cls):
def ledger(self):
return self._dp_sum_query.ledger
- DPGaussianOptimizerClass.__doc__ = (
-     'DP subclass of `tf.train.{}` using Gaussian averaging.').format(
-         cls.__name__)
return DPGaussianOptimizerClass
AdagradOptimizer = tf.train.AdagradOptimizer
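
For orientation, `make_gaussian_optimizer_class` is applied to a concrete `tf.compat.v1.train` optimizer and driven with a vector of per-example losses, so that clipping happens before aggregation. A minimal TF1-style sketch, not part of this change; the hyperparameters and toy tensors are illustrative assumptions:

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

tf.disable_v2_behavior()

DPGradientDescent = dp_optimizer.make_gaussian_optimizer_class(
    tf.train.GradientDescentOptimizer)

optimizer = DPGradientDescent(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=256,   # must evenly divide the batch size
    learning_rate=0.15)     # forwarded to GradientDescentOptimizer.__init__

# Toy batch so the snippet is self-contained.
features = tf.random.normal([256, 20])
labels = tf.zeros([256], dtype=tf.int64)
logits = tf.keras.layers.Dense(10)(features)

# The loss passed to the DP optimizer must be per-example (a vector), not a
# reduced scalar; compute_gradients splits it into microbatches internally.
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)
train_op = optimizer.minimize(
    vector_loss, global_step=tf.train.get_or_create_global_step())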


@ -24,10 +24,18 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
def make_keras_optimizer_class(cls):
"""Constructs a DP Keras optimizer class from an existing one."""
"""Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
Returns:
A DP-SGD subclass of `cls`.
"""
class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = """Differentially private subclass of given class `tf.keras.optimizers.{}.
The class tf.keras.optimizers.Optimizer has two methods to compute
gradients, `_compute_gradients` and `get_gradients`. The first works
@ -37,7 +45,7 @@ def make_keras_optimizer_class(cls):
Internally, DPOptimizerClass stores hyperparameters both individually
and encapsulated in a `GaussianSumQuery` object for these two use cases.
However, this should be invisible to users of this class.
"""
""".format(cls.__name__)
def __init__(
self,
@ -53,6 +61,8 @@ def make_keras_optimizer_class(cls):
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch
is split.
+ *args: These will be passed on to the base class `__init__` method.
+ **kwargs: These will be passed on to the base class `__init__` method.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
@ -64,7 +74,7 @@ def make_keras_optimizer_class(cls):
self._was_dp_gradients_called = False
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP version of superclass method."""
"""DP-SGD version of base class method."""
self._was_dp_gradients_called = True
# Compute loss.
@ -120,7 +130,7 @@ def make_keras_optimizer_class(cls):
return list(zip(final_gradients, var_list))
def get_gradients(self, loss, params):
"""DP version of superclass method."""
"""DP-SGD version of base class method."""
self._was_dp_gradients_called = True
if self._global_state is None:
@ -160,6 +170,7 @@ def make_keras_optimizer_class(cls):
return final_grads
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""DP-SGD version of base class method."""
assert self._was_dp_gradients_called, (
'Neither _compute_gradients() or get_gradients() on the '
'differentially private optimizer was called. This means the '
@ -169,10 +180,6 @@ def make_keras_optimizer_class(cls):
return super(DPOptimizerClass,
self).apply_gradients(grads_and_vars, global_step, name)
- DPOptimizerClass.__doc__ = (
-     'DP subclass of `tf.keras.optimizers.{}` using Gaussian averaging.'
- ).format(cls.__name__)
return DPOptimizerClass
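
For orientation, the Keras optimizer produced by the factory above drops into the normal compile/fit workflow, with the one requirement that the loss stays unreduced so per-example values reach `_compute_gradients`. A minimal sketch, not part of this change; the hyperparameters are illustrative:

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

DPKerasSGD = dp_optimizer_keras.make_keras_optimizer_class(
    tf.keras.optimizers.SGD)

optimizer = DPKerasSGD(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=256,   # must evenly divide the batch size
    learning_rate=0.25)     # forwarded to tf.keras.optimizers.SGD.__init__

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10),
])

# Reduction.NONE keeps one loss value per example, which the DP optimizer
# needs in order to clip before averaging and adding noise.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE),
    metrics=['accuracy'])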


@ -37,10 +37,18 @@ def clip_gradients_vmap(g, l2_norm_clip):
def make_vectorized_keras_optimizer_class(cls):
"""Constructs a DP Keras optimizer class from an existing one."""
"""Given a subclass of `tf.keras.optimizers.Optimizer`, returns a vectorized DP-SGD subclass of it.
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
Returns:
A vectorized DP-SGD subclass of `cls`.
"""
class DPOptimizerClass(cls): # pylint: disable=empty-docstring
__doc__ = """Vectorized differentially private subclass of given class `tf.keras.optimizers.{}.
The class tf.keras.optimizers.Optimizer has two methods to compute
gradients, `_compute_gradients` and `get_gradients`. The first works
@ -50,7 +58,7 @@ def make_vectorized_keras_optimizer_class(cls):
Internally, DPOptimizerClass stores hyperparameters both individually
and encapsulated in a `GaussianSumQuery` object for these two use cases.
However, this should be invisible to users of this class.
"""
""".format(cls.__name__)
def __init__(
self,
@ -66,6 +74,8 @@ def make_vectorized_keras_optimizer_class(cls):
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch
is split.
+ *args: These will be passed on to the base class `__init__` method.
+ **kwargs: These will be passed on to the base class `__init__` method.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
@ -77,7 +87,7 @@ def make_vectorized_keras_optimizer_class(cls):
self._was_dp_gradients_called = False
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP version of superclass method."""
"""DP-SGD version of base class method."""
self._was_dp_gradients_called = True
# Compute loss.
@ -130,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls):
return list(zip(final_gradients, var_list))
def get_gradients(self, loss, params):
"""DP version of superclass method."""
"""DP-SGD version of base class method."""
self._was_dp_gradients_called = True
if self._global_state is None:
@ -168,6 +178,8 @@ def make_vectorized_keras_optimizer_class(cls):
return final_grads
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""DP-SGD version of base class method."""
assert self._was_dp_gradients_called, (
'Neither _compute_gradients() or get_gradients() on the '
'differentially private optimizer was called. This means the '
@ -177,9 +189,6 @@ def make_vectorized_keras_optimizer_class(cls):
return super(DPOptimizerClass,
self).apply_gradients(grads_and_vars, global_step, name)
- DPOptimizerClass.__doc__ = (
-     'Vectorized DP subclass of `tf.keras.optimizers.{}` using Gaussian '
-     'averaging.').format(cls.__name__)
return DPOptimizerClass
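
The vectorized variant is constructed and used exactly like the optimizer in the previous example; the difference is that per-example clipping is batched with `tf.vectorized_map` (see `clip_gradients_vmap` above) rather than looped over microbatches, which is often faster on accelerators. A brief construction sketch under the same assumptions as before:

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_vectorized

VectorizedDPKerasSGD = (
    dp_optimizer_keras_vectorized.make_vectorized_keras_optimizer_class(
        tf.keras.optimizers.SGD))

optimizer = VectorizedDPKerasSGD(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=256,
    learning_rate=0.25)
# Compile and fit as in the previous example, again with an unreduced loss.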


@ -29,7 +29,15 @@ GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
def make_vectorized_optimizer_class(cls):
"""Constructs a vectorized DP optimizer class from an existing one."""
"""Given a subclass of `tf.compat.v1.train.Optimizer`, returns a vectorized DP-SGD subclass of it.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.compat.v1.train.Optimizer`.
Returns:
A DP-SGD subclass of `cls`.
"""
child_code = cls.compute_gradients.__code__
if child_code is not parent_code:
logging.warning(
@ -38,8 +46,11 @@ def make_vectorized_optimizer_class(cls):
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
- class DPOptimizerClass(cls):
-   """Differentially private subclass of given class cls."""
+ class DPOptimizerClass(cls): # pylint: disable=empty-docstring
+   __doc__ = (
+       'Vectorized DP subclass of `tf.compat.v1.train.{}` using Gaussian '
+       'averaging.'
+   ).format(cls.__name__)
def __init__(
self,
@ -56,6 +67,8 @@ def make_vectorized_optimizer_class(cls):
num_microbatches: Number of microbatches into which each minibatch is
split. If `None`, will default to the size of the minibatch, and
per-example gradients will be computed.
+ *args: These will be passed on to the base class `__init__` method.
+ **kwargs: These will be passed on to the base class `__init__` method.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
@ -70,6 +83,7 @@ def make_vectorized_optimizer_class(cls):
colocate_gradients_with_ops=False,
grad_loss=None,
gradient_tape=None):
"""DP-SGD version of base class method."""
if callable(loss):
# TF is running in Eager mode
raise NotImplementedError('Vectorized optimizer unavailable for TF2.')
@ -136,9 +150,6 @@ def make_vectorized_optimizer_class(cls):
return list(zip(final_grads, var_list))
- DPOptimizerClass.__doc__ = (
-     'Vectorized DP subclass of `tf.compat.v1.train.{}` using '
-     'Gaussian averaging.').format(cls.__name__)
return DPOptimizerClass