Fix bug in DPModel that shows up in distributed training.

PiperOrigin-RevId: 528026372
Authored by Walid Krichene on 2023-04-28 17:30:48 -07:00, committed by A. Unique TensorFlower
parent 9710a4acc7
commit e65e14b2d6
3 changed files with 202 additions and 181 deletions

clip_grads.py

@@ -21,13 +21,14 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
-from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
+LossFn = Callable[..., tf.Tensor]
 def get_registry_generator_fn(
@@ -71,7 +72,7 @@ def compute_gradient_norms(
     x_batch: InputTensor,
     y_batch: tf.Tensor,
     layer_registry: lr.LayerRegistry,
-    per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
+    per_example_loss_fn: Optional[LossFn] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     trainable_vars: Optional[List[tf.Variable]] = None,
 ):
@@ -92,9 +93,9 @@ def compute_gradient_norms(
       compute gradient norms quickly. See
       `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
       more details.
-    per_example_loss_fn: If not None, used as the function to compute the
-      vectorized per example loss. Otherwise, we derive it from `input_model`'s
-      loss function.
+    per_example_loss_fn: takes as input predictions, labels and weights, and
+      outputs a vector of per-example losses. If None, derived from
+      `input_model.loss` by disabling its reduction.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -193,7 +194,8 @@ def compute_clipped_gradients_and_outputs(
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
     num_microbatches: Optional[lr.BatchSize] = None,
-) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
+    clipping_loss: Optional[LossFn] = None,
+) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.
   Given a batch of observations `(x_batch, y_batch)`, the main steps of this
@@ -224,14 +226,21 @@ def compute_clipped_gradients_and_outputs(
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
       of num_microbatches).
+    clipping_loss: If provided, used for the clipping computation. Defaults to
+      `input_model.compiled_loss`. Specifying a `clipping_loss` can be useful to
+      avoid calling `input_model.compiled_loss`, as this will append the value
+      of the clipped loss to the reported metrics, and this can be misleading as
+      the value of the clipped loss does not reflect the true loss.
   Returns:
-    A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
+    A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
     clipped gradient of the loss function, the second is the result of
     applying `input_model` to `x_batch`, and the third is loss value of
     `input_model`, weighted by the loss weights generated by a specific
     `compute_clip_weights()` call.
   """
+  if clipping_loss is None:
+    clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
       input_model,
       x_batch,
@@ -260,19 +269,10 @@ def compute_clipped_gradients_and_outputs(
       if num_microbatches is None
       else lr.add_microbatch_axis(y_pred, num_microbatches)
   )
-  # NOTE: We do not log the loss values here. The caller should invoke
-  # `input_model.compute_loss()` to log loss values. Specifically,
-  # calling `input_model.compute_loss()` performs the following steps:
-  #
-  # (i) sums `input_model.loss` with the regularization losses given in
-  # `input_model.losses` to obtain the total loss
-  # (ii) evaluates the total loss with sample weights (if given)
-  weighted_loss_value = input_model.loss(
-      loss_y_batch, loss_y_pred, loss_weights
-  )
+  clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
   clipped_grads = tape.gradient(
-      weighted_loss_value,
+      clipping_loss_value,
       input_model.trainable_variables,
       unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
-  return clipped_grads, y_pred, weighted_loss_value
+  return clipped_grads, y_pred, clipping_loss_value
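
The reworked `per_example_loss_fn` contract above (see the new `LossFn` alias) is easiest to see in a small example. The following is a hedged usage sketch, not part of the commit: it builds a toy compiled model and calls `compute_gradient_norms` with an explicit per-example loss, i.e. the model's own loss with its reduction disabled so that it returns one value per example. The names `model`, `x_batch`, and `y_batch` are assumptions for illustration.

# Hedged usage sketch (not part of this commit).
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# A tiny compiled linear model and a toy batch, used only for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(2,)),
    tf.keras.layers.Dense(1),
])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y_batch = np.array([[1.0], [0.0]], dtype=np.float32)

# A LossFn in the sense of the new alias: the model's loss with reduction
# disabled, so it returns one loss value per example instead of a scalar.
per_example_mse = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE
)

per_example_norms = clip_grads.compute_gradient_norms(
    model,
    x_batch,
    y_batch,
    layer_registry.make_default_layer_registry(),
    per_example_loss_fn=per_example_mse,
)
print(per_example_norms)  # One gradient norm per example in the batch.

Leaving `per_example_loss_fn=None` should give the same norms, since the function then derives the per-example loss from `input_model.loss` by disabling its reduction, as the updated docstring states.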

dp_keras_model.py

@@ -15,10 +15,13 @@
 from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+_PRIVATIZED_LOSS_NAME = 'privatized_loss'
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@@ -122,6 +125,7 @@ def make_dp_model_class(cls):
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
       self._layer_registry = layer_registry
+      self._clipping_loss = None
       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
@@ -176,15 +180,34 @@ def make_dp_model_class(cls):
       )
     def _compute_per_example_grads(self, data):
+      if self._clipping_loss is None:
+        self._make_clipping_loss()
       microbatched_x, microbatched_y = data
       with tf.GradientTape() as tape:
         microbatched_y_pred = self(microbatched_x, training=True)
-        # NOTE: Calling `self.loss()` neither logs the total loss nor does it
-        # include any regularization terms.
-        microbatched_loss = self.loss(microbatched_y, microbatched_y_pred)
+        # NOTE: `self._clipping_loss` does not include any regularization terms.
+        microbatched_loss = self._clipping_loss(
+            microbatched_y, microbatched_y_pred
+        )
       grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
       clipped_grads = self._process_per_example_grads(grads_list)
-      return microbatched_loss, clipped_grads
+      return clipped_grads
+    def _make_clipping_loss(self):
+      """Creates a LossesContainer to be used for clipping.
+      To compute the privatized loss, we wrap the model's compiled_loss inside a
+      new LossesContainer. This lets us avoid calling model.compiled_loss, which
+      appends the loss value to the returned metrics (we want to avoid this as
+      the privatized loss does not reflect the true loss and can be misleading).
+      """
+      losses_container_cls = self.compiled_loss.__class__
+      self._clipping_loss = losses_container_cls(
+          self.compiled_loss._user_losses,  # pylint:disable=protected-access
+          loss_weights=self.compiled_loss._user_loss_weights,  # pylint:disable=protected-access
+          output_names=self.output_names,
+          total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
+      )
     def train_step(self, data):
       """DP-SGD version of base class method.
@@ -205,11 +228,16 @@ def make_dp_model_class(cls):
       Returns:
         See the base class.
       """
+      if self._clipping_loss is None:
+        self._make_clipping_loss()
       output_metrics = {}
-      x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
+      x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
+      if weights is not None:
+        raise NotImplementedError(
+            'DPModel does not currently support weighted losses.'
+        )
       batch_size = tf.shape(y)[0]
       eff_num_microbatches = self._num_microbatches or batch_size
-      privatized_loss_name = 'privatized_loss'
       # Branch based on gradient clipping algorithm.
       if self._enable_fast_peg_computation:
@@ -221,7 +249,7 @@ def make_dp_model_class(cls):
         # trick, and uses these norms to clip the per-example gradients.
         # NOTE: Reshaping of the input according to the effective number of
         # microbatches is done here.
-        clipped_grads, y_pred, weighted_loss = (
+        clipped_grads, y_pred, clipping_loss = (
             clip_grads.compute_clipped_gradients_and_outputs(
                 self,
                 x,
@@ -229,8 +257,10 @@ def make_dp_model_class(cls):
                 self._l2_norm_clip,
                 self._layer_registry,
                 self._num_microbatches,
+                self._clipping_loss,
             )
         )
+        output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
         if self._noise_multiplier > 0:
           grads = gradient_clipping_utils.add_aggregate_noise(
               self,
@@ -241,7 +271,6 @@ def make_dp_model_class(cls):
           )
         else:
           grads = clipped_grads
-        output_metrics[privatized_loss_name] = weighted_loss
       else:
         logging.info('Computing gradients using original clipping algorithm.')
         # Computes per-example clipped gradients directly. This is called
@@ -249,7 +278,7 @@ def make_dp_model_class(cls):
         # algorithm.
         reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
         microbatched_data = tf.nest.map_structure(reshape_fn, data)
-        microbatched_losses, clipped_grads = tf.vectorized_map(
+        clipped_grads = tf.vectorized_map(
            self._compute_per_example_grads,
            microbatched_data,
        )
@@ -257,11 +286,6 @@ def make_dp_model_class(cls):
        grads = tf.nest.map_structure(
            self._reduce_per_example_grads, clipped_grads
        )
-        if self.loss.reduction == tf.keras.losses.Reduction.SUM:
-          microbatched_loss = tf.reduce_sum(microbatched_losses)
-        else:
-          microbatched_loss = tf.reduce_mean(microbatched_losses)
-        output_metrics[privatized_loss_name] = microbatched_loss
       # Add the values and gradients contributed by regularization losses.
       if self.losses:
@@ -277,9 +301,10 @@ def make_dp_model_class(cls):
             unconnected_gradients=tf.UnconnectedGradients.ZERO,
         )
         grads = [a + b for (a, b) in zip(grads, regularization_grads)]
-        output_metrics[privatized_loss_name] += summed_regularization_loss
+        if self._enable_fast_peg_computation:
+          output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss
-      # Log the true loss.
+      # Log the true loss, including regularization losses.
       self.compiled_loss(y, y_pred, regularization_losses=self.losses)
       # Forward the private gradients to the optimizer and return the results.
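
The net effect of `_make_clipping_loss` and the new `clipping_loss` plumbing is visible from the training metrics: the clipped surrogate that gradients are taken against is reported under its own name instead of being folded into the compiled loss. Below is a hedged end-to-end sketch, not from this commit; the import paths, data, and printed keys are assumptions for illustration.

# Hedged end-to-end sketch (not part of this commit).
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.0,
    # Passing a registry enables the fast per-example-gradient path.
    layer_registry=layer_registry.make_default_layer_registry(),
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),
)

x = np.random.normal(size=(8, 2)).astype(np.float32)
y = np.random.normal(size=(8, 1)).astype(np.float32)
history = model.fit(x, y, epochs=1, batch_size=4)

# 'loss' is the true loss logged via compiled_loss; with fast clipping enabled,
# the clipped surrogate should also be reported under 'privatized_loss'.
print(sorted(history.history.keys()))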

dp_keras_model_test.py

@@ -28,16 +28,6 @@ def get_data():
   return data, labels
-def get_layer_registries():
-  # Outputs a list of testable layer registries.
-  # The empty registry {} tests the behavior of the standard approach,
-  # while the other one tests the fast gradient clipping algorithm.
-  return [
-      layer_registry.LayerRegistry(),
-      layer_registry.make_default_layer_registry(),
-  ]
 class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   def testBaseline(self):
@@ -65,44 +55,49 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(model_weights[0], [[0.90], [1.20]])
     self.assertAllClose(model_weights[1], [0.30])
-  @parameterized.named_parameters(
-      ('l2_norm_clip 10.0', 10.0),
-      ('l2_norm_clip 40.0', 40.0),
-      ('l2_norm_clip 200.0', 200.0),
+  @parameterized.product(
+      l2_norm_clip=(10.0, 40.0, 200.0),
+      fast_clipping=(True, False),
   )
-  def testClippingNorm(self, l2_norm_clip):
+  def testClippingNorm(self, l2_norm_clip, fast_clipping):
     """Tests that clipping norm works."""
     train_data, train_labels = get_data()
-    for test_reg in get_layer_registries():
-      # Simple linear model returns w * x + b.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=0.0,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  1, kernel_initializer='zeros', bias_initializer='zeros'
-              ),
-          ],
-      )
-      learning_rate = 0.01
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=1)
-      model_weights = model.get_weights()
-      unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
-      scale = min(1.0, l2_norm_clip / unclipped_gradient)
-      expected_weights = np.array([[90], [120]]) * scale * learning_rate
-      expected_bias = np.array([30]) * scale * learning_rate
-      # Check parameters are as expected, taking into account the learning rate.
-      self.assertAllClose(model_weights[0], expected_weights)
-      self.assertAllClose(model_weights[1], expected_bias)
+    # Simple linear model returns w * x + b.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=0.0,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                1, kernel_initializer='zeros', bias_initializer='zeros'
+            ),
+        ],
+    )
+    learning_rate = 0.01
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
+    model.compile(optimizer=optimizer, loss=loss)
+    expected_loss = loss(train_labels, model(train_data))
+    results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
+    model_weights = model.get_weights()
+    unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
+    scale = min(1.0, l2_norm_clip / unclipped_gradient)
+    expected_weights = np.array([[90], [120]]) * scale * learning_rate
+    expected_bias = np.array([30]) * scale * learning_rate
+    # Check parameters are as expected, taking into account the learning rate.
+    self.assertAllClose(model_weights[0], expected_weights)
+    self.assertAllClose(model_weights[1], expected_bias)
+    # Check the value of the loss.
+    actual_loss = results.history['loss'][0]
+    self.assertAllClose(expected_loss, actual_loss)
@@ -127,64 +122,61 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     final_grads = np.mean(mb_grads, axis=0)
     return final_grads
-  @parameterized.named_parameters(
-      ('mb_test 0', 1.0, None),
-      ('mb_test 1', 1.0, 1),
-      ('mb_test 2', 1.0, 2),
-      ('mb_test 4', 1.0, 4),
+  @parameterized.product(
+      num_microbatches=(None, 1, 2, 4),
+      fast_clipping=(False, True),
   )
-  def testMicrobatches(self, l2_norm_clip, num_microbatches):
+  def testMicrobatches(self, num_microbatches, fast_clipping):
+    l2_norm_clip = 1.0
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     w = np.zeros((2))
     train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
     learning_rate = 1.0
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
-      # Simple linear model returns w * x.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=0.0,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  1, use_bias=False, kernel_initializer='zeros'
-              ),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
-      model_weights = np.squeeze(model.get_weights())
-      effective_num_microbatches = (
-          train_data.shape[0]
-          if model._num_microbatches is None
-          else num_microbatches
-      )
-      expected_grads = self._compute_expected_gradients(
-          train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
-      )
-      expected_weights = np.squeeze(-learning_rate * expected_grads)
-      self.assertAllClose(model_weights, expected_weights)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
+    # Simple linear model returns w * x.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                1, use_bias=False, kernel_initializer='zeros'
+            ),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+    model_weights = np.squeeze(model.get_weights())
+    effective_num_microbatches = (
+        train_data.shape[0]
+        if model._num_microbatches is None
+        else num_microbatches
+    )
+    expected_grads = self._compute_expected_gradients(
+        train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
+    )
+    expected_weights = np.squeeze(-learning_rate * expected_grads)
+    self.assertAllClose(model_weights, expected_weights)
-  @parameterized.named_parameters(
-      ('noise_multiplier 3 2 None', 3.0, 2.0, None),
-      ('noise_multiplier 5 4 None', 5.0, 4.0, None),
-      ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
-      ('noise_multiplier 5 4 1', 5.0, 4.0, 1),
-      ('noise_multiplier 3 2 2', 3.0, 2.0, 2),
-      ('noise_multiplier 5 4 2', 5.0, 4.0, 2),
-      ('noise_multiplier 3 2 4', 3.0, 2.0, 4),
-      ('noise_multiplier 5 4 4', 5.0, 4.0, 4),
+  @parameterized.product(
+      l2_norm_clip=(3.0, 5.0),
+      noise_multiplier=(2.0, 4.0),
+      num_microbatches=(None, 1, 2, 4),
+      fast_clipping=(False, True),
   )
-  def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
-                          num_microbatches):
+  def testNoiseMultiplier(
+      self, l2_norm_clip, noise_multiplier, num_microbatches, fast_clipping
+  ):
     # The idea behind this test is to start with a model whose parameters
     # are set to zero. We then run one step of a model that produces
     # an un-noised gradient of zero, and then compute the standard deviation
@@ -197,69 +189,69 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     learning_rate = 1.0
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss = tf.keras.losses.MeanSquaredError()
-      # Simple linear model returns w * x + b.
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=l2_norm_clip,
-          noise_multiplier=noise_multiplier,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(1000,)),
-              tf.keras.layers.Dense(
-                  1, kernel_initializer='zeros', bias_initializer='zeros'
-              ),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4)
-      effective_num_microbatches = num_microbatches or train_data.shape[0]
-      model_weights = model.get_weights()
-      measured_std = np.std(model_weights[0])
-      expected_std = (
-          l2_norm_clip * noise_multiplier / effective_num_microbatches
-      )
-      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
-      self.assertNear(measured_std, expected_std, 0.1 * expected_std)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
+    # Simple linear model returns w * x + b.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(1000,)),
+            tf.keras.layers.Dense(
+                1, kernel_initializer='zeros', bias_initializer='zeros'
+            ),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4)
+    effective_num_microbatches = num_microbatches or train_data.shape[0]
+    model_weights = model.get_weights()
+    measured_std = np.std(model_weights[0])
+    expected_std = l2_norm_clip * noise_multiplier / effective_num_microbatches
+    # Test standard deviation is close to l2_norm_clip * noise_multiplier.
+    self.assertNear(measured_std, expected_std, 0.1 * expected_std)
   # Simple check to make sure dimensions are correct when output has
   # dimension > 1.
-  @parameterized.named_parameters(
-      ('mb_test None 2', None, 2),
-      ('mb_test 1 2', 1, 2),
-      ('mb_test 2 2', 2, 2),
-      ('mb_test 4 4', 4, 4),
+  @parameterized.product(
+      num_microbatches=(None, 1, 2),
+      output_dimension=(2, 4),
+      fast_clipping=(False, True),
   )
-  def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
+  def testMultiDimensionalOutput(
+      self, num_microbatches, output_dimension, fast_clipping
+  ):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     train_labels = np.array([[0], [1], [1], [0]])
     learning_rate = 1.0
-    for test_reg in get_layer_registries():
-      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-      model = dp_keras_model.DPSequential(
-          l2_norm_clip=1.0e9,
-          noise_multiplier=0.0,
-          num_microbatches=num_microbatches,
-          layer_registry=test_reg,
-          layers=[
-              tf.keras.layers.InputLayer(input_shape=(2,)),
-              tf.keras.layers.Dense(
-                  output_dimension, use_bias=False, kernel_initializer='zeros'
-              ),
-              tf.keras.layers.Dense(1),
-          ],
-      )
-      model.compile(optimizer=optimizer, loss=loss_fn)
-      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=1.0e9,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                output_dimension, use_bias=False, kernel_initializer='zeros'
+            ),
+            tf.keras.layers.Dense(1),
+        ],
+    )
+    model.compile(optimizer=optimizer, loss=loss_fn)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
   # Checks that calls to earlier API using `use_xla` as a positional argument
   # raise an exception.
@@ -285,10 +277,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   # Simple test to check that regularizer gradients are contributing to the
   # final gradient.
   @parameterized.named_parameters(
-      ('no_registry', None),
-      ('default_registry', layer_registry.make_default_layer_registry()),
+      ('fast_clipping', True),
+      ('no_fast_clipping', False),
   )
-  def testRegularizationGradient(self, registry):
+  def testRegularizationGradient(self, fast_clipping):
     input_dim = 10
     batch_size = 2
     regularizer_multiplier = 0.025
@@ -306,7 +298,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
         outputs=outputs,
         l2_norm_clip=1e9,
         noise_multiplier=0.0,
-        layer_registry=registry,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
     )
     model.compile(
         loss=tf.keras.losses.MeanSquaredError(),
@@ -331,10 +325,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   # Simple test to check that custom input regularization does NOT contribute
   # to the gradient.
   @parameterized.named_parameters(
-      ('no_registry', None),
-      ('default_registry', layer_registry.make_default_layer_registry()),
+      ('fast_clipping', True),
+      ('no_fast_clipping', False),
   )
-  def testCustomRegularizationZeroGradient(self, registry):
+  def testCustomRegularizationZeroGradient(self, fast_clipping):
     input_dim = 10
     batch_size = 2
     inputs = tf.keras.layers.Input((input_dim,))
@@ -350,7 +344,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
        outputs=outputs,
        l2_norm_clip=1e9,
        noise_multiplier=0.0,
-        layer_registry=registry,
+        layer_registry=layer_registry.make_default_layer_registry()
+        if fast_clipping
+        else None,
    )
    model.add_loss(tf.reduce_sum(inputs))
    model.compile(
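
A note on the test churn above: the hand-enumerated `@parameterized.named_parameters` lists were replaced with `@parameterized.product`, which generates one test case per element of the cross-product of its keyword options. Below is a minimal self-contained sketch of that decorator; the test class, method name, and assertions are illustrative, not taken from the diff.

# Self-contained sketch of the @parameterized.product decorator switch.
from absl.testing import parameterized
import tensorflow as tf


class ProductDecoratorExampleTest(tf.test.TestCase, parameterized.TestCase):

  # Expands to the full cross-product: 3 clip values x 2 clipping modes = 6
  # generated test cases, replacing a hand-enumerated named_parameters list.
  @parameterized.product(
      l2_norm_clip=(10.0, 40.0, 200.0),
      fast_clipping=(True, False),
  )
  def test_runs_for_every_combination(self, l2_norm_clip, fast_clipping):
    self.assertGreater(l2_norm_clip, 0.0)
    self.assertIsInstance(fast_clipping, bool)


if __name__ == '__main__':
  tf.test.main()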