Integrate the fast gradient clipping algorithm with the DP Keras Model class.

PiperOrigin-RevId: 504931452
A. Unique TensorFlower 2023-01-26 13:45:23 -08:00
parent bc84ed7bfb
commit 9ed34da715
4 changed files with 264 additions and 119 deletions

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -93,8 +93,6 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs)
if tf.rank(tf.squeeze(losses)) > 1:
raise NotImplementedError('Vector losses are not supported.')
summed_loss = tf.reduce_sum(losses)
# Second loop computes the norm of the gradient of the loss with respect to
# the pre-activation tensors, and multiplies these norms with the results of

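For context, the per-example loss construction in the hunk above can be reproduced in isolation. A minimal sketch, not part of this commit, assuming a `MeanSquaredError` loss stands in for `input_model.loss`:

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf

# Clone the loss with Reduction.NONE so it returns one loss value per example
# instead of a single reduced scalar (mirrors the hunk above).
loss_obj = tf.keras.losses.MeanSquaredError()
loss_config = loss_obj.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = loss_obj.from_config(loss_config)

y_batch = tf.constant([[1.0], [3.0]])
model_outputs = tf.constant([[0.5], [2.0]])
losses = per_example_loss_fn(y_batch, model_outputs)  # shape (2,): one loss per example
summed_loss = tf.reduce_sum(losses)                   # scalar used under the gradient tape
# --- end sketch ---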
tensorflow_privacy/privacy/keras_models/BUILD

@@ -15,6 +15,11 @@ py_library(
"dp_keras_model.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
],
)
py_test(
@@ -22,5 +27,8 @@ py_test(
srcs = ["dp_keras_model_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/keras_models:dp_keras_model"],
deps = [
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
],
)

tensorflow_privacy/privacy/keras_models/dp_keras_model.py

@@ -13,19 +13,38 @@
# limitations under the License.
"""Keras Model for vectorized dpsgd with XLA acceleration."""
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
class DPModelClass(cls): # pylint: disable=empty-docstring
__doc__ = ("""DP subclass of `{base_model}`.
class DPModelClass(cls): # pylint: disable=missing-class-docstring
__doc__ = (
"""DP subclass of `{base_model}`.
This can be used as a differentially private replacement for
{base_model}. This class implements DP-SGD using the standard
Gaussian mechanism.
This class also utilizes a faster gradient clipping algorithm if the
following two conditions hold:
(i) the trainable layers of the model are keys in the `dict` input
`layer_registry`,
(ii) the loss `tf.Tensor` for a given batch of examples is either a
scalar or a 2D `tf.Tensor` that has only one column
(i.e., `tf.shape(loss)[1] == 1`) and whose i-th row corresponds to
the loss of the i-th example.
This clipping algorithm specifically computes clipped gradients at the
per-example level using the layer registry functions in `layer_registry`
(see clip_grads.py for more information about the algorithm). In this
setting, microbatching is not used (it is equivalent to
`num_microbatches == batch_size`), and the input `num_microbatches`
must be left as `None`.
When instantiating this class, you need to supply several
DP-related arguments followed by the standard arguments for
`{short_base_model}`.
@@ -53,10 +72,12 @@ def make_dp_model_class(cls):
model.fit(train_data, train_labels, epochs=1, batch_size=32)
```
""").format(
base_model='tf.keras.' + cls.__name__,
short_base_model=cls.__name__,
dp_model_class='DP' + cls.__name__)
"""
).format(
base_model='tf.keras.' + cls.__name__,
short_base_model=cls.__name__,
dp_model_class='DP' + cls.__name__,
)
def __init__(
self,
@@ -64,24 +85,31 @@ def make_dp_model_class(cls):
noise_multiplier,
num_microbatches=None,
use_xla=True,
layer_registry=None,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
**kwargs,
):
"""Initializes the DPModelClass.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch
gradients).
noise_multiplier: Ratio of the standard deviation to the clipping
norm.
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches.
use_xla: If `True`, compiles train_step to XLA.
layer_registry: A `dict` of layers that support "fast" gradient norm
computations. The key is the class of the layer and the value is a
function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
where `output` is the pre-activation tensor, `sqr_grad_norms` is
related to the squared norms of a layer's pre-activation tensor, and
`vars` are relevant trainable weights (see
`layer_registry_factories.py` for examples).
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__`
method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
super().__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._layer_registry = layer_registry
# Given that `num_microbatches` was added as an argument after the fact,
# this check helps detect unintended calls to the earlier API.
@@ -91,7 +119,27 @@ def make_dp_model_class(cls):
raise ValueError('Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?')
self._num_microbatches = num_microbatches
# If all the trainable layers are in the input layer registry, we
# don't need to use microbatching and can instead use the "fast"
# chain rule trick for computing per-example gradients (peg).
if (
layer_registry is not None
and gradient_clipping_utils.all_trainable_layers_are_registered(
self, layer_registry
)
and gradient_clipping_utils.has_internal_compute_graph(self)
):
if num_microbatches is not None:
raise ValueError(
'Cannot initialize a model where num_microbatches '
'is not `None` and all trainable layers are '
'registered in layer_registry.'
)
self._num_microbatches = None
self._enable_fast_peg_computation = True
else:
self._num_microbatches = num_microbatches
self._enable_fast_peg_computation = False
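The same eligibility check can be run standalone against a plain Keras model. A hedged sketch, not part of this commit, assuming `make_default_layer_registry()` (used in the tests below) covers `Dense` layers:

# --- illustrative sketch, not part of this commit ---
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories

registry = layer_registry_factories.make_default_layer_registry()
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(2,)),
    tf.keras.layers.Dense(1),
])
# The fast per-example gradient path is taken only when every trainable layer
# is a key in the registry and the model has an internal compute graph.
uses_fast_path = (
    gradient_clipping_utils.all_trainable_layers_are_registered(model, registry)
    and gradient_clipping_utils.has_internal_compute_graph(model)
)
print(uses_fast_path)  # expected True here, assuming Dense is in the default registry
# --- end sketch ---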
if use_xla:
self.train_step = tf.function(
@@ -126,29 +174,72 @@ def make_dp_model_class(cls):
return y_pred, loss, clipped_grads
def train_step(self, data):
"""DP-SGD version of base class method."""
_, y = data
batch_size = y.shape[0]
"""DP-SGD version of base class method.
if self._num_microbatches is None:
self._num_microbatches = batch_size
if batch_size % self._num_microbatches != 0:
raise ValueError('Number of microbatches must divide batch size.')
Uses the "fast" gradient clipping algorithm to generate per-example
clipped gradients if (i) all the trainable layers of the model are
registered in the layer_registry input of the model constructor and
(ii) if the model contains an internal compute graph (e.g., this
condition is satisfied if the model subclasses the keras.Sequential or
keras.engine.functional.Functional class).
def reshape_fn(x):
new_shape = (self._num_microbatches,
batch_size // self._num_microbatches) + x.shape[1:]
return tf.reshape(x, new_shape)
If (i) and (ii) above do not hold, then clips and aggregates
gradients at the microbatch level.
data = tf.nest.map_structure(reshape_fn, data)
Args:
data: see the base class.
y_pred, _, per_eg_grads = tf.vectorized_map(
self._compute_per_example_grads, data)
Returns:
See the base class.
"""
if self._enable_fast_peg_computation:
logging.info(
'Computing gradients using the fast per-example gradient '
'norm algorithm.'
)
# Computes the per-example gradient norms using a "fast" clipping
# trick, and uses these norms to clip the per-example gradients.
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
self, x, y, self._l2_norm_clip, self._layer_registry
)
grads = gradient_clipping_utils.add_aggregate_noise(
self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier
)
else:
logging.info('Computing gradients using microbatching.')
# Computes per-example clipped gradients directly. This is called
# if at least one of the layers cannot use the "fast" gradient clipping
# algorithm.
# TODO(wkong): check if the following is valid with sample weights.
_, y = data
batch_size = y.shape[0]
y_pred = tf.reshape(y_pred, (batch_size,) + y_pred.shape[2:])
if self._num_microbatches is None:
self._num_microbatches = batch_size
if batch_size % self._num_microbatches != 0:
raise ValueError('Number of microbatches must divide batch size.')
grads = tf.nest.map_structure(self._reduce_per_example_grads,
per_eg_grads)
def reshape_fn(x):
new_shape = (
self._num_microbatches,
batch_size // self._num_microbatches,
) + x.shape[1:]
return tf.reshape(x, new_shape)
data = tf.nest.map_structure(reshape_fn, data)
y_pred, _, per_eg_grads = tf.vectorized_map(
self._compute_per_example_grads, data
)
y_pred = tf.reshape(y_pred, (batch_size,) + y_pred.shape[2:])
grads = tf.nest.map_structure(
self._reduce_per_example_grads, per_eg_grads
)
# Forward the private gradients to the optimizer and return the results.
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
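End to end, the new constructor argument is exercised as in the tests below. A minimal usage sketch, not part of this commit, with illustrative hyperparameters:

# --- illustrative sketch, not part of this commit ---
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
from tensorflow_privacy.privacy.keras_models import dp_keras_model

train_data = np.random.normal(size=(32, 2)).astype(np.float32)
train_labels = np.random.normal(size=(32, 1)).astype(np.float32)

# Supplying a registry that covers every trainable layer enables the fast
# per-example clipping path; num_microbatches is then left as None.
model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    layer_registry=layer_registry_factories.make_default_layer_registry(),
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),
)
model.fit(train_data, train_labels, epochs=1, batch_size=8)
# --- end sketch ---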

tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py

@@ -13,10 +13,9 @@
# limitations under the License.
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
from tensorflow_privacy.privacy.keras_models import dp_keras_model
@@ -29,6 +28,13 @@ def get_data():
return data, labels
def get_layer_registries():
# Outputs a list of testable layer registries.
# The empty registry {} tests the behavior of the standard approach,
# while the other one tests the fast gradient clipping algorithm.
return [{}, layer_registry_factories.make_default_layer_registry()]
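A hedged way to confirm which path each registry selects is to inspect the private `_enable_fast_peg_computation` flag set in the model constructor; illustrative only, not part of this test file:

# --- illustrative sketch, not part of this test file ---
for registry in get_layer_registries():
  model = dp_keras_model.DPSequential(
      l2_norm_clip=1.0,
      noise_multiplier=0.0,
      layer_registry=registry,
      layers=[
          tf.keras.layers.InputLayer(input_shape=(2,)),
          tf.keras.layers.Dense(1),
      ],
  )
  # Expected: False for the empty registry (microbatching path), True for the
  # default registry (fast clipping path), assuming Dense layers are registered.
  print(bool(registry), model._enable_fast_peg_computation)
# --- end sketch ---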
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
def testBaseline(self):
@@ -65,32 +71,35 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
"""Tests that clipping norm works."""
train_data, train_labels = get_data()
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
for test_reg in get_layer_registries():
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=1)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=1)
model_weights = model.get_weights()
model_weights = model.get_weights()
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
scale = min(1.0, l2_norm_clip / unclipped_gradient)
expected_weights = np.array([[90], [120]]) * scale * learning_rate
expected_bias = np.array([30]) * scale * learning_rate
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
scale = min(1.0, l2_norm_clip / unclipped_gradient)
expected_weights = np.array([[90], [120]]) * scale * learning_rate
expected_bias = np.array([30]) * scale * learning_rate
# Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
# Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
num_microbatches):
@@ -98,9 +107,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
if num_microbatches is None:
num_microbatches = batch_size
preds = np.matmul(data, w)
preds = np.matmul(data, np.expand_dims(w, axis=1))
grads = 2 * data * (preds - labels)
grads = 2 * data * (labels - preds)[:, np.newaxis]
grads = np.reshape(grads,
[num_microbatches, batch_size // num_microbatches, -1])
@@ -123,32 +133,45 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
def testMicrobatches(self, l2_norm_clip, num_microbatches):
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
w = np.zeros((2))
train_labels = np.array([1.0, 3.0, -2.0, -4.0])
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
learning_rate = 1.0
expected_grads = self._compute_expected_gradients(train_data, train_labels,
w, l2_norm_clip,
num_microbatches)
expected_weights = np.squeeze(learning_rate * expected_grads)
for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=test_nm,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model_weights = np.squeeze(model.get_weights())
model_weights = np.squeeze(model.get_weights())
self.assertAllClose(model_weights, expected_weights)
effective_num_microbatches = (
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
expected_grads = self._compute_expected_gradients(
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
)
expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights)
@parameterized.named_parameters(
('noise_multiplier 3 2 1', 3.0, 2.0, 1),
@@ -168,59 +191,81 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
# Data is four examples of length 1000, all set to zero, with labels zero.
train_data = np.zeros((4, 1000))
train_labels = np.array([0.0, 0.0, 0.0, 0.0])
train_labels = np.array([[0.0], [0.0], [0.0], [0.0]])
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4)
for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
model_weights = model.get_weights()
measured_std = np.std(model_weights[0])
expected_std = l2_norm_clip * noise_multiplier / num_microbatches
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=test_nm,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4)
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(measured_std, expected_std, 0.1 * expected_std)
effective_num_microbatches = (
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
model_weights = model.get_weights()
measured_std = np.std(model_weights[0])
expected_std = (
l2_norm_clip * noise_multiplier / effective_num_microbatches
)
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(measured_std, expected_std, 0.1 * expected_std)
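The expected_std value checked above follows from Gaussian noise with standard deviation l2_norm_clip * noise_multiplier being divided by the effective number of microbatches. A quick numeric sanity check, not part of this test file, using l2_norm_clip=3.0 and noise_multiplier=2.0 from the parameterization above and an assumed 4 effective microbatches:

# --- illustrative sketch, not part of this test file ---
import numpy as np

l2_norm_clip, noise_multiplier, effective_num_microbatches = 3.0, 2.0, 4
# Averaging Gaussian noise of std l2_norm_clip * noise_multiplier over the
# effective number of microbatches divides its std by that number.
noised_sum = np.random.normal(
    scale=l2_norm_clip * noise_multiplier, size=1_000_000)
averaged = noised_sum / effective_num_microbatches
print(np.std(averaged))  # ~1.5 == 3.0 * 2.0 / 4
# --- end sketch ---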
# Simple check to make sure dimensions are correct when output has
# dimension > 1.
@parameterized.named_parameters(
('mb_test None 1', None, 1),
('mb_test None 2', None, 2),
('mb_test 1 2', 1, 2),
('mb_test 2 2', 2, 2),
('mb_test 4 4', 4, 4),
)
def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
train_labels = np.array([0, 1, 1, 0])
train_labels = np.array([[0], [1], [1], [0]])
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
for test_reg, test_nm in zip(
get_layer_registries(), [num_microbatches, None]
):
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
output_dimension, use_bias=False, kernel_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=test_nm,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
output_dimension, use_bias=False, kernel_initializer='zeros'
),
tf.keras.layers.Dense(1),
],
)
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
# Checks that calls to earlier API using `use_xla` as a positional argument
# raise an exception.
@@ -237,8 +282,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
2, use_bias=False, kernel_initializer='zeros')
])
2, use_bias=False, kernel_initializer='zeros'
),
tf.keras.layers.Dense(1),
],
)
if __name__ == '__main__':
tf.test.main()