More docstring updates in preparation for api docs generation.

PiperOrigin-RevId: 368667796
Steve Chien 2021-04-15 10:30:30 -07:00 committed by A. Unique TensorFlower
parent ca347b8995
commit 41530f4426
10 changed files with 136 additions and 73 deletions

View file

@@ -60,16 +60,16 @@ class PrivacyLedger(object):
   def __init__(self,
                population_size,
                selection_probability):
-    """Initialize the PrivacyLedger.
+    """Initializes the PrivacyLedger.

     Args:
       population_size: An integer (may be variable) specifying the size of the
         population, i.e. size of the training data used in each epoch.
-      selection_probability: A float (may be variable) specifying the
-        probability each record is included in a sample.
+      selection_probability: A floating point value (may be variable) specifying
+        the probability each record is included in a sample.

     Raises:
-      ValueError: If selection_probability is 0.
+      ValueError: If `selection_probability` is 0.
     """
     self._population_size = population_size
     self._selection_probability = selection_probability
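
The constructor in this hunk takes the sampling parameters directly. As a rough sketch (the import path is inferred from the file under edit; the population and batch sizes are made-up values), a ledger for Poisson-sampled batches of 250 records out of 50,000 could be created as:

  from tensorflow_privacy.privacy.analysis import privacy_ledger

  # Illustrative values: 50000 training examples, expected batch size 250.
  ledger = privacy_ledger.PrivacyLedger(
      population_size=50000,
      selection_probability=250 / 50000)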
@@ -141,7 +141,7 @@ class PrivacyLedger(object):
       sess: The tensorflow session in which the ledger was created.

     Returns:
-      The query ledger as a list of SampleEntries.
+      The query ledger as a list of `SampleEntry` instances.
     """
     sample_array = sess.run(self._sample_buffer.values)
     query_array = sess.run(self._query_buffer.values)
@@ -152,7 +152,7 @@ class PrivacyLedger(object):
     """Gets the formatted query ledger.

     Returns:
-      The query ledger as a list of SampleEntries.
+      The query ledger as a list of `SampleEntry` instances.
     """
     sample_array = self._sample_buffer.values.numpy()
     query_array = self._query_buffer.values.numpy()
@@ -161,21 +161,21 @@ class PrivacyLedger(object):
 class QueryWithLedger(dp_query.DPQuery):
-  """A class for DP queries that record events to a PrivacyLedger.
+  """A class for DP queries that record events to a `PrivacyLedger`.

-  QueryWithLedger should be the top-level query in a structure of queries that
+  `QueryWithLedger` should be the top-level query in a structure of queries that
   may include sum queries, nested queries, etc. It should simply wrap another
   query and contain a reference to the ledger. Any contained queries (including
   those contained in the leaves of a nested query) should also contain a
   reference to the same ledger object.

-  For example usage, see privacy_ledger_test.py.
+  For example usage, see `privacy_ledger_test.py`.
   """

   def __init__(self, query,
                population_size=None, selection_probability=None,
                ledger=None):
-    """Initializes the QueryWithLedger.
+    """Initializes the `QueryWithLedger`.

     Args:
       query: The query whose events should be recorded to the ledger. Any
@@ -183,12 +183,12 @@ class QueryWithLedger(dp_query.DPQuery):
         contain a reference to the same ledger given here.
       population_size: An integer (may be variable) specifying the size of the
         population, i.e. size of the training data used in each epoch. May be
-        None if `ledger` is specified.
-      selection_probability: A float (may be variable) specifying the
-        probability each record is included in a sample. May be None if `ledger`
-        is specified.
-      ledger: A PrivacyLedger to use. Must be specified if either of
-        `population_size` or `selection_probability` is None.
+        `None` if `ledger` is specified.
+      selection_probability: A floating point value (may be variable) specifying
+        the probability each record is included in a sample. May be `None` if
+        `ledger` is specified.
+      ledger: A `PrivacyLedger` to use. Must be specified if either of
+        `population_size` or `selection_probability` is `None`.
     """
     self._query = query
     if population_size is not None and selection_probability is not None:
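
Per the class docstring above, `QueryWithLedger` wraps the top-level query and shares one ledger with any nested queries. A minimal sketch, assuming the `GaussianSumQuery` constructor from `gaussian_query.py` (see `privacy_ledger_test.py` for canonical usage):

  from tensorflow_privacy.privacy.analysis import privacy_ledger
  from tensorflow_privacy.privacy.dp_query import gaussian_query

  inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0)
  # With population_size and selection_probability given, the wrapper builds
  # its own PrivacyLedger; alternatively an existing ledger can be passed.
  query = privacy_ledger.QueryWithLedger(
      inner_query, population_size=50000, selection_probability=0.005)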

View file

@@ -293,7 +293,7 @@ def _compute_rdp(q, sigma, alpha):
 def compute_rdp(q, noise_multiplier, steps, orders):
-  """Compute RDP of the Sampled Gaussian Mechanism.
+  """Computes RDP of the Sampled Gaussian Mechanism.

   Args:
     q: The sampling rate.
@@ -303,7 +303,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
     orders: An array (or a scalar) of RDP orders.

   Returns:
-    The RDPs at all orders, can be np.inf.
+    The RDPs at all orders. Can be `np.inf`.
   """
   if np.isscalar(orders):
     rdp = _compute_rdp(q, noise_multiplier, orders)
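
As a worked sketch of this function (the hyperparameters and order grid are illustrative only), the RDP of DP-SGD with sampling rate q = 0.01 and noise multiplier 1.1 over 1000 steps is:

  from tensorflow_privacy.privacy.analysis import rdp_accountant

  orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
  rdp = rdp_accountant.compute_rdp(
      q=0.01, noise_multiplier=1.1, steps=1000, orders=orders)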
@@ -316,7 +316,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
 def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
                              steps_list, orders):
-  """Compute RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
+  """Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.

   Args:
     sampling_probabilities: A list containing the sampling rates.
@@ -328,7 +328,7 @@ def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
     orders: An array (or a scalar) of RDP orders.

   Returns:
-    The RDPs at all orders, can be np.inf.
+    The RDPs at all orders. Can be `np.inf`.
   """
   assert len(sampling_probabilities) == len(noise_multipliers)
@@ -341,18 +341,19 @@ def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
 def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
-  """Compute delta (or eps) for given eps (or delta) from RDP values.
+  """Computes delta (or eps) for given eps (or delta) from RDP values.

   Args:
     orders: An array (or a scalar) of RDP orders.
     rdp: An array of RDP values. Must be of the same length as the orders list.
-    target_eps: If not None, the epsilon for which we compute the corresponding
-      delta.
-    target_delta: If not None, the delta for which we compute the corresponding
-      epsilon. Exactly one of target_eps and target_delta must be None.
+    target_eps: If not `None`, the epsilon for which we compute the
+      corresponding delta.
+    target_delta: If not `None`, the delta for which we compute the
+      corresponding epsilon. Exactly one of `target_eps` and `target_delta`
+      must be `None`.

   Returns:
-    eps, delta, opt_order.
+    A tuple of epsilon, delta, and the optimal order.

   Raises:
     ValueError: If target_eps and target_delta are messed up.
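
Continuing the sketch above, the privacy spent is then read off by fixing delta (exactly one of `target_eps` and `target_delta` is given, as the docstring requires):

  eps, delta, opt_order = rdp_accountant.get_privacy_spent(
      orders, rdp, target_delta=1e-5)
  print('DP-SGD guarantee: eps = {:.2f} at delta = {}'.format(eps, delta))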
@@ -374,14 +375,14 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
 def compute_rdp_from_ledger(ledger, orders):
-  """Compute RDP of Sampled Gaussian Mechanism from ledger.
+  """Computes RDP of Sampled Gaussian Mechanism from ledger.

   Args:
     ledger: A formatted privacy ledger.
     orders: An array (or a scalar) of RDP orders.

   Returns:
-    RDP at all orders, can be np.inf.
+    RDP at all orders. Can be `np.inf`.
   """
   total_rdp = np.zeros_like(orders, dtype=float)
   for sample in ledger:
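
A sketch of feeding a formatted ledger into this function. The method name `get_formatted_ledger_eager()` is an assumption based on the eager-mode ledger formatting shown earlier in this diff (in graph mode the session-based variant would be used instead):

  # Assumed method and property names; the diff above shows only docstrings.
  formatted_ledger = query.ledger.get_formatted_ledger_eager()
  rdp = rdp_accountant.compute_rdp_from_ledger(formatted_ledger, orders)
  eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-5)[0]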

View file

@@ -159,7 +159,7 @@ class DPQuery(object):
     Returns:
       The updated sample state. In standard DP-SGD training, the set of
-      previous mcrobatch gradients with the addition of the record argument.
+      previous microbatch gradients with the addition of the record argument.
     """
     preprocessed_record = self.preprocess_record(params, record)
     return self.accumulate_preprocessed_record(
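
For context, `accumulate_record` is one step of the `DPQuery` protocol. A minimal eager-mode pass through that protocol, assuming the `GaussianSumQuery` signatures from `gaussian_query.py`:

  import tensorflow as tf
  from tensorflow_privacy.privacy.dp_query import gaussian_query

  query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0)
  global_state = query.initial_global_state()
  params = query.derive_sample_params(global_state)
  sample_state = query.initial_sample_state(tf.zeros([3]))
  for record in [tf.constant([0.5, 0.0, 0.0]), tf.constant([0.0, 2.0, 0.0])]:
    # Preprocesses (clips) the record and adds it to the running sample state.
    sample_state = query.accumulate_record(params, sample_state, record)
  result, global_state = query.get_noised_result(sample_state, global_state)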

View file

@@ -43,6 +43,7 @@ class DNNClassifier(tf.estimator.Estimator):
       loss_reduction=tf.keras.losses.Reduction.NONE,
       batch_norm=False,
   ):
+    """See `tf.estimator.DNNClassifier`."""
     head = head_utils.binary_or_multi_class_head(
         n_classes,
         weight_column=weight_column,

View file

@@ -41,6 +41,7 @@ class DNNClassifier(tf.estimator.Estimator):
       loss_reduction=tf.compat.v1.losses.Reduction.SUM,
       batch_norm=False,
   ):
+    """See `tf.compat.v1.estimator.DNNClassifier`."""
     head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
         n_classes, weight_column, label_vocabulary, loss_reduction)
     estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')

View file

@@ -19,8 +19,8 @@ import tensorflow as tf
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""

-  class DPModelClass(cls):
-    """A DP version of `cls`, which should be a subclass of `tf.keras.Model`."""
+  class DPModelClass(cls):  # pylint: disable=empty-docstring
+    __doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)

     def __init__(
         self,
@@ -37,6 +37,9 @@ def make_dp_model_class(cls):
         noise_multiplier: Ratio of the standard deviation to the clipping
           norm.
         use_xla: If `True`, compiles train_step to XLA.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__`
+          method.
       """
       super(DPModelClass, self).__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
@@ -78,6 +81,7 @@ def make_dp_model_class(cls):
       return tf.squeeze(y_pred, axis=0), loss, clipped_grads

     def train_step(self, data):
+      """DP-SGD version of base class method."""
       _, y = data
       y_pred, _, per_eg_grads = tf.vectorized_map(
           self._compute_per_example_grads, data)
@@ -87,8 +91,6 @@ def make_dp_model_class(cls):
       self.compiled_metrics.update_state(y, y_pred)
       return {m.name: m.result() for m in self.metrics}

-  DPModelClass.__doc__ = ('DP subclass of `tf.keras.{}`.').format(cls.__name__)
-
   return DPModelClass
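
A usage sketch of the factory above. `DPSequential` is assumed to be the module-level alias `make_dp_model_class(tf.keras.Sequential)`; the layer stack and hyperparameters are illustrative only:

  import tensorflow as tf
  from tensorflow_privacy.privacy.keras_models import dp_keras_model

  model = dp_keras_model.DPSequential(
      l2_norm_clip=1.0,
      noise_multiplier=1.1,
      layers=[tf.keras.layers.Dense(10, activation='softmax')])
  # Per-example gradients require an un-reduced (per-example) loss.
  model.compile(
      optimizer='sgd',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(
          reduction=tf.losses.Reduction.NONE))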

View file

@@ -26,7 +26,15 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
 def make_optimizer_class(cls):
-  """Constructs a DP optimizer class from an existing one."""
+  """Given a subclass of `tf.compat.v1.train.Optimizer`, returns a DP-SGD subclass of it.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.compat.v1.train.Optimizer`.
+
+  Returns:
+    A DP-SGD subclass of `cls`.
+  """
   parent_code = tf.train.Optimizer.compute_gradients.__code__
   has_compute_gradients = hasattr(cls, 'compute_gradients')
@@ -40,8 +48,8 @@ def make_optimizer_class(cls):
         'make_optimizer_class() does not interfere with overridden version.',
         cls.__name__)

-  class DPOptimizerClass(cls):
-    """Differentially private subclass of given class cls."""
+  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
+    __doc__ = ('DP subclass of `tf.compat.v1.train.{}`.').format(cls.__name__)

     def __init__(
         self,
@@ -50,7 +58,7 @@ def make_optimizer_class(cls):
         unroll_microbatches=False,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initialize the DPOptimizerClass.
+      """Initializes the DPOptimizerClass.

       Args:
         dp_sum_query: `DPQuery` object, specifying differential privacy
@@ -61,6 +69,8 @@ def make_optimizer_class(cls):
         unroll_microbatches: If true, processes microbatches within a Python
           loop instead of a `tf.while_loop`. Can be used if using a
           `tf.while_loop` raises an exception.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__` method.
       """
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
       self._dp_sum_query = dp_sum_query
@@ -80,6 +90,7 @@ def make_optimizer_class(cls):
         colocate_gradients_with_ops=False,
         grad_loss=None,
         gradient_tape=None):
+      """DP-SGD version of base class method."""
       self._was_compute_gradients_called = True
       if self._global_state is None:
         self._global_state = self._dp_sum_query.initial_global_state()
@@ -124,7 +135,7 @@ def make_optimizer_class(cls):
         return grads_and_vars

       else:
-        # TF is running in graph mode, check we did not receive a gradient tape.
+        # TF is running in graph mode. Check we did not receive a gradient tape.
         if gradient_tape:
           raise ValueError('When in graph mode, a tape should not be passed.')
@@ -197,6 +208,8 @@ def make_optimizer_class(cls):
       return list(zip(final_grads, var_list))

     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+      # pylint: disable=g-doc-args, g-doc-return-or-yield
+      """DP-SGD version of base class method."""
       assert self._was_compute_gradients_called, (
           'compute_gradients() on the differentially private optimizer was not'
           ' called. Which means that the training is not differentially '
@@ -205,17 +218,24 @@ def make_optimizer_class(cls):
       return super(DPOptimizerClass,
                    self).apply_gradients(grads_and_vars, global_step, name)

-  DPOptimizerClass.__doc__ = ('DP subclass of `tf.compat.v1.train.{}`.').format(
-      cls.__name__)
-
   return DPOptimizerClass


 def make_gaussian_optimizer_class(cls):
-  """Constructs a DP optimizer with Gaussian averaging of updates."""
+  """Given a subclass of `tf.compat.v1.train.Optimizer`, returns a subclass using DP-SGD with Gaussian averaging.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.compat.v1.train.Optimizer`.
+
+  Returns:
+    A subclass of `cls` using DP-SGD with Gaussian averaging.
+  """

-  class DPGaussianOptimizerClass(make_optimizer_class(cls)):
-    """DP subclass of given class cls using Gaussian averaging."""
+  class DPGaussianOptimizerClass(make_optimizer_class(cls)):  # pylint: disable=empty-docstring
+    __doc__ = (
+        'DP subclass of `tf.compat.v1.train.{}` using Gaussian averaging.'
+    ).format(cls.__name__)

     def __init__(
         self,
@@ -226,6 +246,21 @@ def make_gaussian_optimizer_class(cls):
         unroll_microbatches=False,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs):
+      """Initializes the `DPGaussianOptimizerClass`.
+
+      Args:
+        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+        noise_multiplier: Ratio of the standard deviation to the clipping norm.
+        num_microbatches: Number of microbatches into which each minibatch is
+          split. If `None`, will default to the size of the minibatch, and
+          per-example gradients will be computed.
+        ledger: Defaults to `None`. An instance of `tf_privacy.PrivacyLedger`.
+        unroll_microbatches: If true, processes microbatches within a Python
+          loop instead of a `tf.while_loop`. Can be used if using a
+          `tf.while_loop` raises an exception.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__` method.
+      """
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
@@ -268,10 +303,6 @@ def make_gaussian_optimizer_class(cls):
     def ledger(self):
       return self._dp_sum_query.ledger

-  DPGaussianOptimizerClass.__doc__ = (
-      'DP subclass of `tf.train.{}` using Gaussian averaging.').format(
-          cls.__name__)
-
   return DPGaussianOptimizerClass


 AdagradOptimizer = tf.train.AdagradOptimizer
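
The module-level aliases built from these factories, such as `DPGradientDescentGaussianOptimizer` (assumed to be `make_gaussian_optimizer_class(tf.train.GradientDescentOptimizer)`), are used like the underlying `tf.compat.v1` optimizer. A hedged sketch, where `labels` and `logits` stand in for the graph's label and logit tensors and the hyperparameters are illustrative:

  import tensorflow.compat.v1 as tf
  from tensorflow_privacy.privacy.optimizers import dp_optimizer

  optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
      l2_norm_clip=1.0,
      noise_multiplier=1.1,
      num_microbatches=256,
      learning_rate=0.15)
  # DP-SGD needs the un-reduced per-example loss vector so it can be split
  # into microbatches before clipping and noising.
  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  train_op = optimizer.minimize(loss=vector_loss)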

View file

@@ -24,10 +24,18 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
 def make_keras_optimizer_class(cls):
-  """Constructs a DP Keras optimizer class from an existing one."""
+  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.keras.optimizers.Optimizer`.
+
+  Returns:
+    A DP-SGD subclass of `cls`.
+  """

-  class DPOptimizerClass(cls):
-    """Differentially private subclass of given class cls.
+  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
+    __doc__ = """Differentially private subclass of given class `tf.keras.optimizers.{}.

     The class tf.keras.optimizers.Optimizer has two methods to compute
     gradients, `_compute_gradients` and `get_gradients`. The first works
@@ -37,7 +45,7 @@ def make_keras_optimizer_class(cls):
     Internally, DPOptimizerClass stores hyperparameters both individually
     and encapsulated in a `GaussianSumQuery` object for these two use cases.
     However, this should be invisible to users of this class.
-    """
+    """.format(cls.__name__)

     def __init__(
         self,
@@ -53,6 +61,8 @@ def make_keras_optimizer_class(cls):
         noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch
           is split.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__` method.
       """
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
@@ -64,7 +74,7 @@ def make_keras_optimizer_class(cls):
       self._was_dp_gradients_called = False

     def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
-      """DP version of superclass method."""
+      """DP-SGD version of base class method."""
       self._was_dp_gradients_called = True
       # Compute loss.
@@ -120,7 +130,7 @@ def make_keras_optimizer_class(cls):
       return list(zip(final_gradients, var_list))

     def get_gradients(self, loss, params):
-      """DP version of superclass method."""
+      """DP-SGD version of base class method."""
       self._was_dp_gradients_called = True
       if self._global_state is None:
@@ -160,6 +170,7 @@ def make_keras_optimizer_class(cls):
       return final_grads

     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+      """DP-SGD version of base class method."""
       assert self._was_dp_gradients_called, (
           'Neither _compute_gradients() or get_gradients() on the '
           'differentially private optimizer was called. This means the '
@@ -169,10 +180,6 @@ def make_keras_optimizer_class(cls):
       return super(DPOptimizerClass,
                    self).apply_gradients(grads_and_vars, global_step, name)

-  DPOptimizerClass.__doc__ = (
-      'DP subclass of `tf.keras.optimizers.{}` using Gaussian averaging.'
-  ).format(cls.__name__)
-
   return DPOptimizerClass
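
A usage sketch for the Keras factory above. `DPKerasSGDOptimizer` is assumed to be the alias `make_keras_optimizer_class(tf.keras.optimizers.SGD)` defined in this module; the model and hyperparameters are illustrative:

  import tensorflow as tf
  from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

  optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
      l2_norm_clip=1.0,
      noise_multiplier=1.1,
      num_microbatches=256,
      learning_rate=0.15)
  # The loss must stay un-reduced so the optimizer sees per-example losses,
  # and gradients must flow through this optimizer (see the assert above).
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.CategoricalCrossentropy(
          from_logits=True, reduction=tf.losses.Reduction.NONE))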

View file

@@ -37,10 +37,18 @@ def clip_gradients_vmap(g, l2_norm_clip):
 def make_vectorized_keras_optimizer_class(cls):
-  """Constructs a DP Keras optimizer class from an existing one."""
+  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a vectorized DP-SGD subclass of it.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.keras.optimizers.Optimizer`.
+
+  Returns:
+    A vectorized DP-SGD subclass of `cls`.
+  """

-  class DPOptimizerClass(cls):
-    """Differentially private subclass of given class cls.
+  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
+    __doc__ = """Vectorized differentially private subclass of given class `tf.keras.optimizers.{}.

     The class tf.keras.optimizers.Optimizer has two methods to compute
     gradients, `_compute_gradients` and `get_gradients`. The first works
@@ -50,7 +58,7 @@ def make_vectorized_keras_optimizer_class(cls):
     Internally, DPOptimizerClass stores hyperparameters both individually
     and encapsulated in a `GaussianSumQuery` object for these two use cases.
     However, this should be invisible to users of this class.
-    """
+    """.format(cls.__name__)

     def __init__(
         self,
@@ -66,6 +74,8 @@ def make_vectorized_keras_optimizer_class(cls):
         noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch
           is split.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__` method.
       """
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
@@ -77,7 +87,7 @@ def make_vectorized_keras_optimizer_class(cls):
       self._was_dp_gradients_called = False

     def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
-      """DP version of superclass method."""
+      """DP-SGD version of base class method."""
       self._was_dp_gradients_called = True
       # Compute loss.
@@ -130,7 +140,7 @@ def make_vectorized_keras_optimizer_class(cls):
       return list(zip(final_gradients, var_list))

     def get_gradients(self, loss, params):
-      """DP version of superclass method."""
+      """DP-SGD version of base class method."""
       self._was_dp_gradients_called = True
       if self._global_state is None:
@@ -168,6 +178,8 @@ def make_vectorized_keras_optimizer_class(cls):
       return final_grads

     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+      """DP-SGD version of base class method."""
       assert self._was_dp_gradients_called, (
           'Neither _compute_gradients() or get_gradients() on the '
           'differentially private optimizer was called. This means the '
@@ -177,9 +189,6 @@ def make_vectorized_keras_optimizer_class(cls):
       return super(DPOptimizerClass,
                    self).apply_gradients(grads_and_vars, global_step, name)

-  DPOptimizerClass.__doc__ = (
-      'Vectorized DP subclass of `tf.keras.optimizers.{}` using Gaussian '
-      'averaging.').format(cls.__name__)
-
   return DPOptimizerClass
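
The vectorized module is assumed to expose aliases such as `VectorizedDPKerasSGDOptimizer` (i.e. `make_vectorized_keras_optimizer_class(tf.keras.optimizers.SGD)`); construction mirrors the non-vectorized Keras optimizer above:

  from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_vectorized

  optimizer = dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer(
      l2_norm_clip=1.0,
      noise_multiplier=1.1,
      num_microbatches=256,
      learning_rate=0.15)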

View file

@@ -29,7 +29,15 @@ GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
 def make_vectorized_optimizer_class(cls):
-  """Constructs a vectorized DP optimizer class from an existing one."""
+  """Given a subclass of `tf.compat.v1.train.Optimizer`, returns a vectorized DP-SGD subclass of it.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.compat.v1.train.Optimizer`.
+
+  Returns:
+    A DP-SGD subclass of `cls`.
+  """
   child_code = cls.compute_gradients.__code__
   if child_code is not parent_code:
     logging.warning(
@@ -38,8 +46,11 @@ def make_vectorized_optimizer_class(cls):
         'make_optimizer_class() does not interfere with overridden version.',
         cls.__name__)

-  class DPOptimizerClass(cls):
-    """Differentially private subclass of given class cls."""
+  class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
+    __doc__ = (
+        'Vectorized DP subclass of `tf.compat.v1.train.{}` using Gaussian '
+        'averaging.'
+    ).format(cls.__name__)

     def __init__(
         self,
@@ -56,6 +67,8 @@ def make_vectorized_optimizer_class(cls):
         num_microbatches: Number of microbatches into which each minibatch is
           split. If `None`, will default to the size of the minibatch, and
           per-example gradients will be computed.
+        *args: These will be passed on to the base class `__init__` method.
+        **kwargs: These will be passed on to the base class `__init__` method.
       """
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
@@ -70,6 +83,7 @@ def make_vectorized_optimizer_class(cls):
         colocate_gradients_with_ops=False,
         grad_loss=None,
         gradient_tape=None):
+      """DP-SGD version of base class method."""
       if callable(loss):
         # TF is running in Eager mode
         raise NotImplementedError('Vectorized optimizer unavailable for TF2.')
@@ -136,9 +150,6 @@ def make_vectorized_optimizer_class(cls):
       return list(zip(final_grads, var_list))

-  DPOptimizerClass.__doc__ = (
-      'Vectorized DP subclass of `tf.compat.v1.train.{}` using '
-      'Gaussian averaging.').format(cls.__name__)
-
   return DPOptimizerClass
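
Similarly, this module is assumed to define aliases such as `VectorizedDPSGD` (`make_vectorized_optimizer_class(tf.train.GradientDescentOptimizer)`); usage mirrors the non-vectorized v1 optimizers, with a per-example vector loss passed to `minimize()`:

  from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized

  # Assumed alias; constructed like the v1 DP optimizers shown earlier.
  optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
      l2_norm_clip=1.0,
      noise_multiplier=1.1,
      num_microbatches=256,
      learning_rate=0.15)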