Report the true loss in DPModel instead of the norm-adjusted loss.

PiperOrigin-RevId: 517112812
This commit is contained in:
A. Unique TensorFlower 2023-03-16 07:14:36 -07:00
parent 8f4ab1a8bb
commit 043e8b5272
3 changed files with 215 additions and 73 deletions

View file

@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Any, Callable, Dict, Iterable, Optional, Text, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@ -129,11 +129,14 @@ def compute_gradient_norms(
vars_list = [a for (a, b) in filtered_outputs]
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
# Second loop evaluates the squared L2 norm functions and appends the results.
grads_list = tape.gradient(summed_loss, vars_list)
grads_list = tape.gradient(
summed_loss,
vars_list,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
sqr_norm_list = []
for grads, f in zip(grads_list, sqr_norm_fns_list):
sqr_norm_list.append(f(grads))
del tape
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
@ -163,24 +166,23 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
def compute_pred_and_clipped_gradients(
def compute_clipped_gradients_and_outputs(
input_model: tf.keras.Model,
x_batch: InputTensor,
y_batch: tf.Tensor,
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
):
"""Computes the per-example predictions and per-example clipped loss gradient.
) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
"""Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch)`, the main steps of this
function are: (i) compute the l2-norm of the gradients of the trainable
variables of `input_model` for each example in the batch; (ii) use the norms
computed in (i) to obtain "clip_weights" that are used to generate a weighted
loss function whose gradient for each example has l2-norm at most
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and the
`tf.Tensor` generated by `input_model` when it is given `x_batch` as its
input.
`l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful
outputs to the caller.
Args:
input_model: The `tf.keras.Model` from which to obtain the layers.
@ -204,9 +206,11 @@ def compute_pred_and_clipped_gradients(
of num_microbatches).
Returns:
A `tuple` `(y_pred, grad)`. The first element is the prediction generated by
the model on the input `x_batch`. The second element is the clipped
gradient of the loss function.
A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
clipped gradient of the loss function, the second is the result of
applying `input_model` to `x_batch`, and the third is the loss value of
`input_model`, weighted by the loss weights generated by a specific
`compute_clip_weights()` call.
"""
gradient_norms = compute_gradient_norms(
input_model,
@ -217,18 +221,37 @@ def compute_pred_and_clipped_gradients(
)
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
with tf.GradientTape() as tape:
y_pred = input_model(x_batch, training=True)
if num_microbatches is not None:
y_batch = lr.add_microbatch_axis(y_batch, num_microbatches)
y_pred = lr.add_microbatch_axis(y_pred, num_microbatches)
# Warning: When num_microbatches is not None, we need to be sure that
# WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
# as it is the assumption made when computing the gradient norm.
# It is indeed the case for multiple keras loss functions
# (e.g. mean_squared_error and binary_crossentropy). However it
# is not defined in the contract so may not hold, especially for
# custom losses.
loss_value = input_model.compute_loss(
x_batch, y_batch, y_pred, loss_weights
y_pred = input_model(x_batch, training=True)
loss_y_batch = (
y_batch
if num_microbatches is None
else lr.add_microbatch_axis(y_batch, num_microbatches)
)
return y_pred, tape.gradient(loss_value, input_model.trainable_variables)
loss_y_pred = (
y_pred
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
# NOTE: We do not log the loss values here. The caller should invoke
# `input_model.compute_loss()` to log loss values. Specifically,
# calling `input_model.compute_loss()` performs the following steps:
#
# (i) sums `input_model.loss` with the regularization losses given in
# `input_model.losses` to obtain the total loss
# (ii) evaluates the total loss with sample weights (if given)
weighted_loss_value = input_model.loss(
loss_y_batch, loss_y_pred, loss_weights
)
clipped_grads = tape.gradient(
weighted_loss_value,
input_model.trainable_variables,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
return clipped_grads, y_pred, weighted_loss_value

View file

@ -17,14 +17,14 @@ from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
class DPModelClass(cls): # pylint: disable=missing-class-docstring
__doc__ = (
"""DP subclass of `{base_model}`.
__doc__ = ("""DP subclass of `{base_model}`.
This can be used as a differentially private replacement for
{base_model}. This class implements DP-SGD using the standard
@ -32,7 +32,7 @@ def make_dp_model_class(cls):
This class also utilizes a faster gradient clipping algorithm if the
following two conditions hold:
(i) the trainable layers of the model are keys in the `dict` input
(i) the trainable layers of the model are keys in the input
`layer_registry`,
(ii) the loss `tf.Tensor` for a given batch of examples is either a
scalar or a 2D `tf.Tensor` that has only one column
@ -56,6 +56,11 @@ def make_dp_model_class(cls):
It is the caller's responsibility to make sure that the loss function
does behave this way.
WARNING: This API does not have privacy guarantees for custom
layer-level losses created by the `layer.add_loss()` API. It does,
however, support layer regularization losses. All of these layer-level
losses are found in `model.losses`.
When instantiating this class, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_model}`.
@ -83,8 +88,7 @@ def make_dp_model_class(cls):
model.fit(train_data, train_labels, epochs=1, batch_size=32)
```
"""
).format(
""").format(
base_model='tf.keras.' + cls.__name__,
short_base_model=cls.__name__,
dp_model_class='DP' + cls.__name__,
@ -124,8 +128,10 @@ def make_dp_model_class(cls):
# In particular, boolean values supplied to `use_xla` in the earlier API
# will raise an error.
if isinstance(num_microbatches, bool):
raise ValueError('Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?')
raise ValueError(
'Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?'
)
self._num_microbatches = num_microbatches
# If all the trainable layers are in the input layer registry, we
@ -144,7 +150,8 @@ def make_dp_model_class(cls):
if use_xla:
self.train_step = tf.function(
self.train_step, experimental_compile=True)
self.train_step, experimental_compile=True
)
def _process_per_example_grads(self, grads):
grads_flat = tf.nest.flatten(grads)
@ -152,7 +159,7 @@ def make_dp_model_class(cls):
tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
]
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
div = tf.maximum(global_norm / self._l2_norm_clip, 1.0)
clipped_flat = [g / div for g in grads_flat]
return tf.nest.pack_sequence_as(grads, clipped_flat)
@ -160,19 +167,23 @@ def make_dp_model_class(cls):
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
noise_stddev = self._l2_norm_clip * self._noise_multiplier
noise = tf.random.normal(
tf.shape(input=summed_grads), stddev=noise_stddev)
tf.shape(input=summed_grads), stddev=noise_stddev
)
noised_grads = summed_grads + noise
return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype)
return noised_grads / tf.cast(
tf.shape(stacked_grads)[0], noised_grads.dtype
)
def _compute_per_example_grads(self, data):
x, y = data
microbatched_x, microbatched_y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
grads_list = tape.gradient(loss, self.trainable_variables)
microbatched_y_pred = self(microbatched_x, training=True)
# NOTE: Calling `self.loss()` neither logs the total loss nor does it
# include any regularization terms.
microbatched_loss = self.loss(microbatched_y, microbatched_y_pred)
grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
clipped_grads = self._process_per_example_grads(grads_list)
return y_pred, loss, clipped_grads
return microbatched_loss, clipped_grads
def train_step(self, data):
"""DP-SGD version of base class method.
@ -184,7 +195,7 @@ def make_dp_model_class(cls):
condition is satisfied if the model subclasses the keras.Sequential or
keras.engine.functional.Functional class).
If (i) and (ii) above do not hold, then clips and aggregates
If (i) and (ii) above do not hold, then this function clips and aggregates
gradients at the microbatch level.
Args:
@ -193,6 +204,13 @@ def make_dp_model_class(cls):
Returns:
See the base class.
"""
output_metrics = {}
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
batch_size = tf.shape(y)[0]
eff_microbatch_size = self._num_microbatches or batch_size
privatized_loss_name = 'privatized_loss'
# Branch based on gradient clipping algorithm.
if self._enable_fast_peg_computation:
logging.info(
'Computing gradients using the fast per-example gradient '
@ -200,60 +218,73 @@ def make_dp_model_class(cls):
)
# Computes the per-example gradient norms using a "fast" clipping
# trick, and uses these norms to clip the per-example gradients.
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
self,
x,
y,
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
# NOTE: Reshaping of the input according to the effective number of
# microbatches is done here.
clipped_grads, y_pred, weighted_loss = (
clip_grads.compute_clipped_gradients_and_outputs(
self,
x,
y,
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
)
)
batch_size = self._num_microbatches or tf.shape(y)[0]
grads = gradient_clipping_utils.add_aggregate_noise(
self,
clipped_grads,
batch_size,
eff_microbatch_size,
self._l2_norm_clip,
self._noise_multiplier,
)
output_metrics[privatized_loss_name] = weighted_loss
else:
logging.info('Computing gradients using microbatching.')
# Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping
# algorithm.
# TODO(wkong): check if the following is valid with sample weights.
_, y = data
batch_size = y.shape[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
if batch_size % self._num_microbatches != 0:
raise ValueError('Number of_microbatches must divide batch size.')
def reshape_fn(x):
new_shape = (
self._num_microbatches,
batch_size // self._num_microbatches,
) + x.shape[1:]
return tf.reshape(x, new_shape)
data = tf.nest.map_structure(reshape_fn, data)
y_pred, _, per_eg_grads = tf.vectorized_map(
self._compute_per_example_grads, data
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size)
microbatched_data = tf.nest.map_structure(reshape_fn, data)
microbatched_losses, clipped_grads = tf.vectorized_map(
self._compute_per_example_grads,
microbatched_data,
)
y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
y_pred = self(x, training=True)
grads = tf.nest.map_structure(
self._reduce_per_example_grads, per_eg_grads
self._reduce_per_example_grads, clipped_grads
)
if self.loss.reduction == tf.keras.losses.Reduction.SUM:
microbatched_loss = tf.reduce_sum(microbatched_losses)
else:
microbatched_loss = tf.reduce_mean(microbatched_losses)
output_metrics[privatized_loss_name] = microbatched_loss
# Add the values and gradients contributed by regularization losses.
if self.losses:
logging.warning(
'Losses in `model.losses` must be input (batch) independent in '
'order to obtain the desired differential privacy guarantees.'
)
with tf.GradientTape() as tape:
summed_regularization_loss = tf.add_n(self.losses)
regularization_grads = tape.gradient(
summed_regularization_loss,
self.trainable_variables,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
grads = [a + b for (a, b) in zip(grads, regularization_grads)]
output_metrics[privatized_loss_name] += summed_regularization_loss
# Log the true loss.
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Forward the private gradients to the optimizer and return the results.
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
for m in self.metrics:
output_metrics[m.name] = m.result()
return output_metrics
return DPModelClass

View file

@ -282,5 +282,93 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
],
)
# Simple test to check that regularizer gradients are contributing to the
# final gradient.
@parameterized.named_parameters(
    ('no_registry', None),
    ('default_registry', layer_registry.make_default_layer_registry()),
)
def testRegularizationGradient(self, registry):
  """Checks that a layer's kernel regularizer moves the trained weights.

  The model output is multiplied by zero so that the main loss does not
  contribute to the update. After one `fit()` call with an SGD learning rate
  of 1.0, the test asserts that every weight (initialized to one) equals
  `1 - 2 * l2_coefficient`.

  Args:
    registry: The layer registry handed to `DPModel`, or `None`.
  """
  num_features = 10
  num_examples = 2
  l2_coefficient = 0.025

  # A single bias-free unit whose kernel starts at all ones and carries an
  # L2 penalty.
  model_inputs = tf.keras.layers.Input(shape=(num_features,))
  regularized_dense = tf.keras.layers.Dense(
      units=1,
      kernel_initializer='ones',
      use_bias=False,
      kernel_regularizer=tf.keras.regularizers.L2(l2=l2_coefficient),
  )
  # Multiplying by zero removes the contribution of the main loss function.
  zeroed_outputs = tf.multiply(regularized_dense(model_inputs), 0.0)

  # A huge clipping norm and zero noise are passed to the DP model.
  model = dp_keras_model.DPModel(
      inputs=model_inputs,
      outputs=zeroed_outputs,
      l2_norm_clip=1e9,
      noise_multiplier=0.0,
      layer_registry=registry,
  )
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
      run_eagerly=True,
  )

  flat_features = tf.range(num_features * num_examples, dtype=tf.float32)
  features = tf.reshape(flat_features, [num_examples, num_features])
  labels = tf.zeros([num_examples, 1])
  model.fit(x=features, y=labels)

  trained_weights = model.trainable_variables
  expected_weights = tf.multiply(
      tf.ones_like(model.trainable_variables),
      1.0 - 2.0 * l2_coefficient,
  )
  self.assertAllClose(trained_weights, expected_weights)
# Simple test to check that custom input regularization does NOT contribute
# to the gradient.
@parameterized.named_parameters(
    ('no_registry', None),
    ('default_registry', layer_registry.make_default_layer_registry()),
)
def testCustomRegularizationZeroGradient(self, registry):
  """Checks that an input-only `add_loss()` term leaves the weights unchanged.

  The model output is multiplied by zero so that the main loss does not
  contribute to the update, and a custom loss equal to the sum of the inputs
  is attached to the model. After one `fit()` call the test asserts that the
  weights still equal their initial value of one.

  Args:
    registry: The layer registry handed to `DPModel`, or `None`.
  """
  num_features = 10
  num_examples = 2

  # A single bias-free unit whose kernel starts at all ones (no regularizer).
  model_inputs = tf.keras.layers.Input(shape=(num_features,))
  plain_dense = tf.keras.layers.Dense(
      units=1,
      kernel_initializer='ones',
      use_bias=False,
  )
  # Multiplying by zero removes the contribution of the main loss function.
  zeroed_outputs = tf.multiply(plain_dense(model_inputs), 0.0)

  model = dp_keras_model.DPModel(
      inputs=model_inputs,
      outputs=zeroed_outputs,
      l2_norm_clip=1e9,
      noise_multiplier=0.0,
      layer_registry=registry,
  )
  # Custom loss that is a function of the inputs only.
  model.add_loss(tf.reduce_sum(model_inputs))
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
      run_eagerly=True,
  )

  flat_features = tf.range(num_features * num_examples, dtype=tf.float32)
  features = tf.reshape(flat_features, [num_examples, num_features])
  labels = tf.zeros([num_examples, 1])
  model.fit(x=features, y=labels)

  trained_weights = model.trainable_variables
  unchanged_weights = tf.ones_like(model.trainable_variables)
  self.assertAllClose(trained_weights, unchanged_weights)
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == '__main__':
tf.test.main()