Report the true loss in DPModel instead of the norm-adjusted loss.
PiperOrigin-RevId: 517112812
This commit is contained in: parent 8f4ab1a8bb, commit 043e8b5272
3 changed files with 215 additions and 73 deletions
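For context on the commit message: the fast-clipping path weights each example's loss by a clip factor before differentiating, so the value previously reported to the user was this norm-adjusted loss rather than the compiled loss itself. A minimal sketch of the two quantities (made-up numbers, not part of this commit):

```
import tensorflow as tf

# Hypothetical per-example losses and gradient norms for a batch of four.
per_example_loss = tf.constant([0.5, 2.0, 1.0, 4.0])
grad_norms = tf.constant([0.5, 3.0, 1.0, 8.0])
l2_norm_clip = 1.0

# Clip weights as in compute_clip_weights(): C / max(C, ||g_i||).
clip_weights = l2_norm_clip / tf.maximum(l2_norm_clip, grad_norms)

# The norm-adjusted (weighted) loss drives the clipped gradients ...
norm_adjusted_loss = tf.reduce_mean(per_example_loss * clip_weights)
# ... while the true loss is what this commit reports to the user.
true_loss = tf.reduce_mean(per_example_loss)

print(float(norm_adjusted_loss), float(true_loss))  # ~0.667 vs 1.875
```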
clip_grads.py
@@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
 
-from typing import Any, Callable, Dict, Iterable, Optional, Text, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -129,11 +129,14 @@ def compute_gradient_norms(
   vars_list = [a for (a, b) in filtered_outputs]
   sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
   # Second loop evaluates the squared L2 norm functions and appends the results.
-  grads_list = tape.gradient(summed_loss, vars_list)
+  grads_list = tape.gradient(
+      summed_loss,
+      vars_list,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
   sqr_norm_list = []
   for grads, f in zip(grads_list, sqr_norm_fns_list):
     sqr_norm_list.append(f(grads))
   del tape
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
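The `unconnected_gradients=tf.UnconnectedGradients.ZERO` argument added above makes `tape.gradient` return zero tensors instead of `None` for variables that do not influence the loss, which keeps the per-layer squared-norm functions well defined. A standalone illustration (not from this commit):

```
import tensorflow as tf

w_used = tf.Variable(2.0)
w_unused = tf.Variable(3.0)  # plays no role in the loss below
with tf.GradientTape(persistent=True) as tape:
  loss = w_used ** 2

# Default behaviour: the unconnected gradient comes back as None.
print(tape.gradient(loss, [w_used, w_unused]))  # [4.0, None]

# With UnconnectedGradients.ZERO it comes back as a zero tensor instead.
print(
    tape.gradient(
        loss,
        [w_used, w_unused],
        unconnected_gradients=tf.UnconnectedGradients.ZERO,
    )
)  # [4.0, 0.0]
del tape
```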
@@ -163,24 +166,23 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
 
 
-def compute_pred_and_clipped_gradients(
+def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
     x_batch: InputTensor,
     y_batch: tf.Tensor,
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
     num_microbatches: Optional[lr.BatchSize] = None,
-):
-  """Computes the per-example predictions and per-example clipped loss gradient.
+) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
+  """Computes the per-example clipped loss gradient and other useful outputs.
 
   Given a batch of observations `(x_batch, y_batch)`, the main steps of this
   function are: (i) compute the l2-norm of the gradients of the trainable
   variables of `input_model` for each example in the batch; (ii) use the norms
   computed in (i) to obtain "clip_weights" that are used to generate a weighted
   loss function whose gradient for each example has l2-norm at most
-  `l2_norm_clip`; (iii) output the clipped gradients in (ii) and the
-  `tf.Tensor` generated by `input_model` when it is given `x_batch` as its
-  input.
+  `l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful
+  outputs to the caller.
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
@@ -204,9 +206,11 @@ def compute_pred_and_clipped_gradients(
       of num_microbatches).
 
   Returns:
-    A `tuple` `(y_pred, grad)`. The first element is the prediction generated by
-    the model on the input `x_batch`. The second element is the clipped
-    gradient of the loss function.
+    A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
+    clipped gradient of the loss function, the second is the result of
+    applying `input_model` to `x_batch`, and the third is loss value of
+    `input_model`, weighted by the loss weights generated by a specific
+    `compute_clip_weights()` call.
   """
   gradient_norms = compute_gradient_norms(
       input_model,
@@ -217,18 +221,37 @@ def compute_pred_and_clipped_gradients(
   )
   loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
   with tf.GradientTape() as tape:
-    y_pred = input_model(x_batch, training=True)
-    if num_microbatches is not None:
-      y_batch = lr.add_microbatch_axis(y_batch, num_microbatches)
-      y_pred = lr.add_microbatch_axis(y_pred, num_microbatches)
-    # Warning: When num_microbatches is not None, we need to be sure that
+    # WARNING: When num_microbatches is not None, we need to be sure that
     # `compute_loss` always computes the mean over the microbatches
     # as it is the assumption made when computing the gradient norm.
     # It is indeed the case for multiple keras loss functions
     # (e.g. mean_squared_error and binary_crossentropy). However it
     # is not defined in the contract so may not hold, especially for
     # custom losses.
-    loss_value = input_model.compute_loss(
-        x_batch, y_batch, y_pred, loss_weights
+    y_pred = input_model(x_batch, training=True)
+    loss_y_batch = (
+        y_batch
+        if num_microbatches is None
+        else lr.add_microbatch_axis(y_batch, num_microbatches)
     )
-  return y_pred, tape.gradient(loss_value, input_model.trainable_variables)
+    loss_y_pred = (
+        y_pred
+        if num_microbatches is None
+        else lr.add_microbatch_axis(y_pred, num_microbatches)
+    )
+    # NOTE: We do not log the loss values here. The caller should invoke
+    # `input_model.compute_loss()` to log loss values. Specifically,
+    # calling `input_model.compute_loss()` performs the following steps:
+    #
+    # (i) sums `input_model.loss` with the regularization losses given in
+    #   `input_model.losses` to obtain the total loss
+    # (ii) evaluates the total loss with sample weights (if given)
+    weighted_loss_value = input_model.loss(
+        loss_y_batch, loss_y_pred, loss_weights
+    )
+  clipped_grads = tape.gradient(
+      weighted_loss_value,
+      input_model.trainable_variables,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
+  return clipped_grads, y_pred, weighted_loss_value
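A sketch of how a caller might use the renamed function and its reordered return value. The toy model, data, and compile settings here are assumptions for illustration, not part of the commit:

```
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Small Sequential model whose Dense layer is covered by the default registry.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x = tf.random.normal([4, 3])
y = tf.zeros([4, 1])

# New return order: clipped gradients first, then predictions, then the
# clip-weighted loss value.
clipped_grads, y_pred, weighted_loss = (
    clip_grads.compute_clipped_gradients_and_outputs(
        model,
        x,
        y,
        l2_norm_clip=1.0,
        layer_registry=layer_registry.make_default_layer_registry(),
    )
)
```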
dp_keras_model.py
@@ -17,14 +17,14 @@ from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
 
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
 
   class DPModelClass(cls):  # pylint: disable=missing-class-docstring
-    __doc__ = (
-        """DP subclass of `{base_model}`.
+    __doc__ = ("""DP subclass of `{base_model}`.
 
     This can be used as a differentially private replacement for
     {base_model}. This class implements DP-SGD using the standard
@@ -32,7 +32,7 @@ def make_dp_model_class(cls):
 
     This class also utilizes a faster gradient clipping algorithm if the
     following two conditions hold:
-      (i) the trainable layers of the model are keys in the `dict` input
+      (i) the trainable layers of the model are keys in the input
          `layer_registry`,
       (ii) the loss `tf.Tensor` for a given batch of examples is either a
          scalar or a 2D `tf.Tensor` that has only one column
@@ -56,6 +56,11 @@ def make_dp_model_class(cls):
     It is the caller's responsibility to make sure that the loss function
     does behave this way.
 
+    WARNING: This API does not have privacy guarantees for custom
+    layer-level losses created by the `layer.add_loss()` API. It does,
+    however, support layer regularization losses. All of these layer-level
+    losses are found in `model.losses`.
+
     When instantiating this class, you need to supply several
     DP-related arguments followed by the standard arguments for
     `{short_base_model}`.
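To illustrate the new WARNING: regularizer losses and `add_loss()` terms both surface in `model.losses`, but only the former are supported with privacy guarantees. A toy example (not from this commit):

```
import tensorflow as tf

inputs = tf.keras.layers.Input((3,))
dense = tf.keras.layers.Dense(
    1, kernel_regularizer=tf.keras.regularizers.L2(0.01)
)
model = tf.keras.Model(inputs=inputs, outputs=dense(inputs))

# An input-dependent loss added via `add_loss()`; per the warning above, such
# losses are not covered by the privacy guarantee.
model.add_loss(tf.reduce_sum(inputs))

_ = model(tf.zeros([2, 3]))
# Both the L2 regularization term and the `add_loss()` term appear here.
print(model.losses)
```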
@@ -83,8 +88,7 @@ def make_dp_model_class(cls):
     model.fit(train_data, train_labels, epochs=1, batch_size=32)
     ```
 
-    """
-    ).format(
+    """).format(
         base_model='tf.keras.' + cls.__name__,
         short_base_model=cls.__name__,
         dp_model_class='DP' + cls.__name__,
@@ -124,8 +128,10 @@ def make_dp_model_class(cls):
       # In particular, boolean values supplied to `use_xla` in the earlier API
      # will raise an error.
       if isinstance(num_microbatches, bool):
-        raise ValueError('Boolean value supplied for `num_microbatches`. '
-                         'Did you intend it for `use_xla`?')
+        raise ValueError(
+            'Boolean value supplied for `num_microbatches`. '
+            'Did you intend it for `use_xla`?'
+        )
       self._num_microbatches = num_microbatches
 
       # If all the trainable layers are in the input layer registry, we
@@ -144,7 +150,8 @@ def make_dp_model_class(cls):
 
       if use_xla:
         self.train_step = tf.function(
-            self.train_step, experimental_compile=True)
+            self.train_step, experimental_compile=True
+        )
 
     def _process_per_example_grads(self, grads):
       grads_flat = tf.nest.flatten(grads)
@@ -152,7 +159,7 @@ def make_dp_model_class(cls):
           tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
       ]
       global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
-      div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
+      div = tf.maximum(global_norm / self._l2_norm_clip, 1.0)
       clipped_flat = [g / div for g in grads_flat]
       return tf.nest.pack_sequence_as(grads, clipped_flat)
@@ -160,19 +167,23 @@ def make_dp_model_class(cls):
       summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
       noise_stddev = self._l2_norm_clip * self._noise_multiplier
       noise = tf.random.normal(
-          tf.shape(input=summed_grads), stddev=noise_stddev)
+          tf.shape(input=summed_grads), stddev=noise_stddev
+      )
       noised_grads = summed_grads + noise
-      return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype)
+      return noised_grads / tf.cast(
+          tf.shape(stacked_grads)[0], noised_grads.dtype
+      )
 
     def _compute_per_example_grads(self, data):
-      x, y = data
+      microbatched_x, microbatched_y = data
       with tf.GradientTape() as tape:
-        y_pred = self(x, training=True)
-        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
-
-      grads_list = tape.gradient(loss, self.trainable_variables)
+        microbatched_y_pred = self(microbatched_x, training=True)
+        # NOTE: Calling `self.loss()` neither logs the total loss nor does it
+        # include any regularization terms.
+        microbatched_loss = self.loss(microbatched_y, microbatched_y_pred)
+      grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
       clipped_grads = self._process_per_example_grads(grads_list)
-      return y_pred, loss, clipped_grads
+      return microbatched_loss, clipped_grads
 
     def train_step(self, data):
       """DP-SGD version of base class method.
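The NOTE added in this hunk hinges on the difference between `self.loss` (the bare loss object set at compile time, which neither logs nor adds regularization) and `compiled_loss` (which folds in `model.losses` and updates the loss metric, as the old code did). A small contrast, assuming a compiled Keras model (not from this commit):

```
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(
        1,
        input_shape=(2,),
        kernel_regularizer=tf.keras.regularizers.L2(0.1),
    )
])
model.compile(loss=tf.keras.losses.MeanSquaredError(), optimizer='sgd')

x = tf.ones([4, 2])
y = tf.zeros([4, 1])
y_pred = model(x, training=True)

# Bare loss object: mean loss over the batch, no regularization, no logging.
raw_loss = model.loss(y, y_pred)

# Compiled-loss wrapper: adds regularization losses and updates the metric.
total_loss = model.compiled_loss(y, y_pred, regularization_losses=model.losses)

print(float(raw_loss), float(total_loss))  # total_loss >= raw_loss
```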
@@ -184,7 +195,7 @@ def make_dp_model_class(cls):
       condition is satisfied if the model subclasses the keras.Sequential or
       keras.engine.functional.Functional class).
 
-      If (i) and (ii) above do not hold, then clips and aggregates
+      If (i) and (ii) above do not hold, then this function clips and aggregates
       gradients at the microbatch level.
 
       Args:
@@ -193,6 +204,13 @@ def make_dp_model_class(cls):
       Returns:
         See the base class.
       """
+      output_metrics = {}
+      x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
+      batch_size = tf.shape(y)[0]
+      eff_microbatch_size = self._num_microbatches or batch_size
+      privatized_loss_name = 'privatized_loss'
+
       # Branch based on gradient clipping algorithm.
       if self._enable_fast_peg_computation:
         logging.info(
             'Computing gradients using the fast per-example gradient '
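The bookkeeping added above relies on `tf.keras.utils.unpack_x_y_sample_weight` and an `or` fallback for the effective microbatch size. A toy illustration (not from this commit):

```
import tensorflow as tf

data = (tf.ones([6, 3]), tf.zeros([6, 1]))  # no sample weights supplied
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
print(sample_weight)  # None

# When the user did not request microbatching, every example is its own
# microbatch, i.e. the effective microbatch count equals the batch size.
num_microbatches = None
batch_size = tf.shape(y)[0]
eff_microbatch_size = num_microbatches or batch_size
print(int(eff_microbatch_size))  # 6
```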
@@ -200,60 +218,73 @@ def make_dp_model_class(cls):
         )
         # Computes the per-example gradient norms using a "fast" clipping
         # trick, and uses these norms to clip the per-example gradients.
-        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
-        y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
-            self,
-            x,
-            y,
-            self._l2_norm_clip,
-            self._layer_registry,
-            self._num_microbatches,
+        # NOTE: Reshaping of the input according to the effective number of
+        # microbatches is done here.
+        clipped_grads, y_pred, weighted_loss = (
+            clip_grads.compute_clipped_gradients_and_outputs(
+                self,
+                x,
+                y,
+                self._l2_norm_clip,
+                self._layer_registry,
+                self._num_microbatches,
+            )
         )
-        batch_size = self._num_microbatches or tf.shape(y)[0]
         grads = gradient_clipping_utils.add_aggregate_noise(
             self,
             clipped_grads,
-            batch_size,
+            eff_microbatch_size,
             self._l2_norm_clip,
             self._noise_multiplier,
         )
+        output_metrics[privatized_loss_name] = weighted_loss
       else:
         logging.info('Computing gradients using microbatching.')
         # Computes per-example clipped gradients directly. This is called
         # if at least one of the layers cannot use the "fast" gradient clipping
         # algorithm.
-        # TODO(wkong): check if the following is valid with sample weights.
-        _, y = data
-        batch_size = y.shape[0]
-
-        if self._num_microbatches is None:
-          self._num_microbatches = batch_size
-        if batch_size % self._num_microbatches != 0:
-          raise ValueError('Number of_microbatches must divide batch size.')
-
-        def reshape_fn(x):
-          new_shape = (
-              self._num_microbatches,
-              batch_size // self._num_microbatches,
-          ) + x.shape[1:]
-          return tf.reshape(x, new_shape)
-
-        data = tf.nest.map_structure(reshape_fn, data)
-
-        y_pred, _, per_eg_grads = tf.vectorized_map(
-            self._compute_per_example_grads, data
+        reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size)
+        microbatched_data = tf.nest.map_structure(reshape_fn, data)
+        microbatched_losses, clipped_grads = tf.vectorized_map(
+            self._compute_per_example_grads,
+            microbatched_data,
         )
-
-        y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
-
+        y_pred = self(x, training=True)
         grads = tf.nest.map_structure(
-            self._reduce_per_example_grads, per_eg_grads
+            self._reduce_per_example_grads, clipped_grads
         )
+        if self.loss.reduction == tf.keras.losses.Reduction.SUM:
+          microbatched_loss = tf.reduce_sum(microbatched_losses)
+        else:
+          microbatched_loss = tf.reduce_mean(microbatched_losses)
+        output_metrics[privatized_loss_name] = microbatched_loss
 
+      # Add the values and gradients contributed by regularization losses.
+      if self.losses:
+        logging.warning(
+            'Losses in `model.losses` must be input (batch) independent in '
+            'order to obtain the desired differential privacy guarantees.'
+        )
+        with tf.GradientTape() as tape:
+          summed_regularization_loss = tf.add_n(self.losses)
+        regularization_grads = tape.gradient(
+            summed_regularization_loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        grads = [a + b for (a, b) in zip(grads, regularization_grads)]
+        output_metrics[privatized_loss_name] += summed_regularization_loss
+
+      # Log the true loss.
+      self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+
       # Forward the private gradients to the optimizer and return the results.
       self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
       self.compiled_metrics.update_state(y, y_pred)
-      return {m.name: m.result() for m in self.metrics}
+      for m in self.metrics:
+        output_metrics[m.name] = m.result()
+
+      return output_metrics
 
   return DPModelClass
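In the microbatching branch above, the vector of per-microbatch losses is reduced according to the compiled loss's reduction before being reported as the `privatized_loss` metric. A standalone sketch of that reduction (not from this commit):

```
import tensorflow as tf

microbatched_losses = tf.constant([0.2, 0.4, 0.6, 0.8])
loss_obj = tf.keras.losses.MeanSquaredError()  # default reduction is AUTO

if loss_obj.reduction == tf.keras.losses.Reduction.SUM:
  privatized_loss = tf.reduce_sum(microbatched_losses)
else:
  # AUTO / SUM_OVER_BATCH_SIZE behave like a mean here.
  privatized_loss = tf.reduce_mean(microbatched_losses)

print(float(privatized_loss))  # 0.5
```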
dp_keras_model_test.py
@@ -282,5 +282,93 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
       ],
   )
 
+  # Simple test to check that regularizer gradients are contributing to the
+  # final gradient.
+  @parameterized.named_parameters(
+      ('no_registry', None),
+      ('default_registry', layer_registry.make_default_layer_registry()),
+  )
+  def testRegularizationGradient(self, registry):
+    input_dim = 10
+    batch_size = 2
+    regularizer_multiplier = 0.025
+    inputs = tf.keras.layers.Input((input_dim,))
+    dense_lyr = tf.keras.layers.Dense(
+        1,
+        kernel_initializer='ones',
+        use_bias=False,
+        kernel_regularizer=tf.keras.regularizers.L2(regularizer_multiplier),
+    )
+    # Zero-out outputs to avoid contributions from the main loss function.
+    outputs = tf.multiply(dense_lyr(inputs), 0.0)
+    model = dp_keras_model.DPModel(
+        inputs=inputs,
+        outputs=outputs,
+        l2_norm_clip=1e9,
+        noise_multiplier=0.0,
+        layer_registry=registry,
+    )
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(),
+        optimizer=tf.keras.optimizers.SGD(1.0),
+        run_eagerly=True,
+    )
+    x_batch = tf.reshape(
+        tf.range(input_dim * batch_size, dtype=tf.float32),
+        [batch_size, input_dim],
+    )
+    y_batch = tf.zeros([batch_size, 1])
+    model.fit(x=x_batch, y=y_batch)
+
+    self.assertAllClose(
+        model.trainable_variables,
+        tf.multiply(
+            tf.ones_like(model.trainable_variables),
+            1.0 - 2.0 * regularizer_multiplier,
+        ),
+    )
+
+  # Simple test to check that custom input regularization does NOT contribute
+  # to the gradient.
+  @parameterized.named_parameters(
+      ('no_registry', None),
+      ('default_registry', layer_registry.make_default_layer_registry()),
+  )
+  def testCustomRegularizationZeroGradient(self, registry):
+    input_dim = 10
+    batch_size = 2
+    inputs = tf.keras.layers.Input((input_dim,))
+    dense_lyr = tf.keras.layers.Dense(
+        1,
+        kernel_initializer='ones',
+        use_bias=False,
+    )
+    # Zero-out outputs to avoid contributions from the main loss function.
+    outputs = tf.multiply(dense_lyr(inputs), 0.0)
+    model = dp_keras_model.DPModel(
+        inputs=inputs,
+        outputs=outputs,
+        l2_norm_clip=1e9,
+        noise_multiplier=0.0,
+        layer_registry=registry,
+    )
+    model.add_loss(tf.reduce_sum(inputs))
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(),
+        optimizer=tf.keras.optimizers.SGD(1.0),
+        run_eagerly=True,
+    )
+    x_batch = tf.reshape(
+        tf.range(input_dim * batch_size, dtype=tf.float32),
+        [batch_size, input_dim],
+    )
+    y_batch = tf.zeros([batch_size, 1])
+    model.fit(x=x_batch, y=y_batch)
+
+    self.assertAllClose(
+        model.trainable_variables, tf.ones_like(model.trainable_variables)
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()
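A worked check (not part of the commit) of why `testRegularizationGradient` expects final weights of `1.0 - 2.0 * regularizer_multiplier`: the outputs are multiplied by zero, so only the L2 term `lambda * w**2` contributes; with `noise_multiplier=0.0` and an effectively infinite clip norm, a single SGD step with learning rate 1.0 from `w0 = 1.0` lands at 0.95:

```
lam, learning_rate, w0 = 0.025, 1.0, 1.0  # values from the test above
grad = 2.0 * lam * w0                     # d/dw of lam * w**2
w1 = w0 - learning_rate * grad
assert abs(w1 - (1.0 - 2.0 * lam)) < 1e-12  # w1 == 0.95
```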