Integrate the fast gradient clipping algorithm with the DP Keras Model class.
PiperOrigin-RevId: 504931452
parent bc84ed7bfb
commit 9ed34da715

4 changed files with 264 additions and 119 deletions
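Below is a minimal usage sketch assembled from the test changes in this commit (the import paths follow the BUILD targets touched here, and the hyperparameter values are illustrative): passing a layer registry to the DP model constructor is what opts the model into the new fast clipping path.

    import tensorflow as tf
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
    from tensorflow_privacy.privacy.keras_models import dp_keras_model

    # Registry mapping supported layer classes to their "fast" gradient-norm functions.
    registry = layer_registry_factories.make_default_layer_registry()

    model = dp_keras_model.DPSequential(
        l2_norm_clip=1.0,
        noise_multiplier=0.5,
        layer_registry=registry,  # enables the fast per-example gradient path
        layers=[
            tf.keras.layers.InputLayer(input_shape=(2,)),
            tf.keras.layers.Dense(1),
        ],
    )
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        loss=tf.keras.losses.MeanSquaredError(),
    )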
tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -93,8 +93,6 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   loss_config['reduction'] = tf.keras.losses.Reduction.NONE
   per_example_loss_fn = input_model.loss.from_config(loss_config)
   losses = per_example_loss_fn(y_batch, model_outputs)
-  if tf.rank(tf.squeeze(losses)) > 1:
-    raise NotImplementedError('Vector losses are not supported.')
   summed_loss = tf.reduce_sum(losses)
   # Second loop computes the norm of the gradient of the loss with respect to
   # the pre-activation tensors, and multiplies these norms with the results of
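For context on the hunk above: rebuilding the model's loss with `Reduction.NONE` is what yields one loss value per example instead of a single batch mean, which the rest of the fast clipping computation relies on. A standalone sketch (not code from this change):

    import tensorflow as tf

    loss_config = tf.keras.losses.MeanSquaredError().get_config()
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    per_example_loss_fn = tf.keras.losses.MeanSquaredError.from_config(loss_config)

    y_true = tf.constant([[1.0], [2.0]])
    y_pred = tf.constant([[0.0], [0.0]])
    print(per_example_loss_fn(y_true, y_pred))  # tf.Tensor([1. 4.], ...): one loss per example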
tensorflow_privacy/privacy/keras_models/BUILD

@@ -15,6 +15,11 @@ py_library(
         "dp_keras_model.py",
     ],
     srcs_version = "PY3",
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+    ],
 )

 py_test(
@@ -22,5 +27,8 @@ py_test(
     srcs = ["dp_keras_model_test.py"],
     python_version = "PY3",
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/keras_models:dp_keras_model"],
+    deps = [
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
+        "//tensorflow_privacy/privacy/keras_models:dp_keras_model",
+    ],
 )
tensorflow_privacy/privacy/keras_models/dp_keras_model.py

@@ -13,19 +13,38 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""

+from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils


 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""

-  class DPModelClass(cls):  # pylint: disable=empty-docstring
-    __doc__ = ("""DP subclass of `{base_model}`.
+  class DPModelClass(cls):  # pylint: disable=missing-class-docstring
+    __doc__ = (
+        """DP subclass of `{base_model}`.

     This can be used as a differentially private replacement for
     {base_model}. This class implements DP-SGD using the standard
     Gaussian mechanism.

+    This class also utilizes a faster gradient clipping algorithm if the
+    following two conditions hold:
+    (i)  the trainable layers of the model are keys in the `dict` input
+         `layer_registry`,
+    (ii) the loss `tf.Tensor` for a given batch of examples is either a
+         scalar or a 2D `tf.Tensor` that has only one column
+         `(i.e., tf.shape(loss)[1] == 1)` and whose i-th row corresponds to
+         the loss of the i-th example.
+    This clipping algorithm specifically computes clipped gradients at the
+    per-example level using the layer registry functions in `layer_registry`
+    (see clip_grads.py for more information about the algorithm). In this
+    setting, microbatching is not used (it is equivalent to
+    `num_microbatches == batch_size`), and the input `num_microbatches`
+    is ignored.
+
     When instantiating this class, you need to supply several
     DP-related arguments followed by the standard arguments for
     `{short_base_model}`.
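Condition (ii) in the docstring above describes a loss tensor that carries exactly one value per example. A tiny illustrative sketch of a tensor satisfying it (not code from this change):

    import tensorflow as tf

    # One column, one row per example, so tf.shape(loss)[1] == 1.
    loss = tf.constant([[0.3], [1.2], [0.7]])
    print(int(tf.shape(loss)[1]) == 1)  # True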
@@ -53,10 +72,12 @@ def make_dp_model_class(cls):
       model.fit(train_data, train_labels, epochs=1, batch_size=32)
       ```

-    """).format(
-        base_model='tf.keras.' + cls.__name__,
-        short_base_model=cls.__name__,
-        dp_model_class='DP' + cls.__name__)
+    """
+    ).format(
+        base_model='tf.keras.' + cls.__name__,
+        short_base_model=cls.__name__,
+        dp_model_class='DP' + cls.__name__,
+    )

     def __init__(
         self,
@@ -64,24 +85,31 @@ def make_dp_model_class(cls):
         noise_multiplier,
         num_microbatches=None,
         use_xla=True,
+        layer_registry=None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
-        **kwargs):
+        **kwargs,
+    ):
       """Initializes the DPModelClass.

       Args:
-        l2_norm_clip: Clipping norm (max L2 norm of per microbatch
-          gradients).
-        noise_multiplier: Ratio of the standard deviation to the clipping
-          norm.
+        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+        noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches.
         use_xla: If `True`, compiles train_step to XLA.
+        layer_registry: A `dict` of layers that support "fast" gradient norm
+          computations. The key is the class of the layer and the value is a
+          function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
+          where `output` is the pre-activator tensor, `sqr_grad_norms` is
+          related to the squared norms of a layer's pre-activation tensor, and
+          `vars` are relevant trainable weights (see
+          `layer_registry_factories.py` for examples).
         *args: These will be passed on to the base class `__init__` method.
-        **kwargs: These will be passed on to the base class `__init__`
-          method.
+        **kwargs: These will be passed on to the base class `__init__` method.
       """
       super().__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
+      self._layer_registry = layer_registry

       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
@@ -91,7 +119,27 @@ def make_dp_model_class(cls):
         raise ValueError('Boolean value supplied for `num_microbatches`. '
                          'Did you intend it for `use_xla`?')

-      self._num_microbatches = num_microbatches
+      # If all the trainable layers are in the input layer registry, we
+      # don't need to use microbatching and can instead use the "fast"
+      # chain rule trick for computing per-example gradients (peg).
+      if (
+          layer_registry is not None
+          and gradient_clipping_utils.all_trainable_layers_are_registered(
+              self, layer_registry
+          )
+          and gradient_clipping_utils.has_internal_compute_graph(self)
+      ):
+        if num_microbatches is not None:
+          raise ValueError(
+              'Cannot initialize a model where num_microbatches '
+              'is not `None` and all trainable layers are '
+              'registered in layer_registry.'
+          )
+        self._num_microbatches = None
+        self._enable_fast_peg_computation = True
+      else:
+        self._num_microbatches = num_microbatches
+        self._enable_fast_peg_computation = False

       if use_xla:
         self.train_step = tf.function(
@@ -126,29 +174,72 @@ def make_dp_model_class(cls):
       return y_pred, loss, clipped_grads

     def train_step(self, data):
-      """DP-SGD version of base class method."""
-      _, y = data
-      batch_size = y.shape[0]
-
-      if self._num_microbatches is None:
-        self._num_microbatches = batch_size
-      if batch_size % self._num_microbatches != 0:
-        raise ValueError('Number of_microbatches must divide batch size.')
-
-      def reshape_fn(x):
-        new_shape = (self._num_microbatches,
-                     batch_size // self._num_microbatches) + x.shape[1:]
-        return tf.reshape(x, new_shape)
-
-      data = tf.nest.map_structure(reshape_fn, data)
-
-      y_pred, _, per_eg_grads = tf.vectorized_map(
-          self._compute_per_example_grads, data)
-
-      y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
-
-      grads = tf.nest.map_structure(self._reduce_per_example_grads,
-                                    per_eg_grads)
+      """DP-SGD version of base class method.
+
+      Uses the "fast" gradient clipping algorithm to generate per-example
+      clipped gradients if (i) all the trainable layers of the model are
+      registered in the layer_registry input of the model constructor and
+      (ii) if the model contains an internal compute graph (e.g., this
+      condition is satisfied if the model subclasses the keras.Sequential or
+      keras.engine.functional.Functional class).
+
+      If (i) and (ii) above do not hold, then clips and aggregates
+      gradients at the microbatch level.
+
+      Args:
+        data: see the base class.
+
+      Returns:
+        See the base class.
+      """
+      if self._enable_fast_peg_computation:
+        logging.info(
+            'Computing gradients using the fast per-example gradient '
+            'norm algorithm.'
+        )
+        # Computes the per-example gradient norms using a "fast" clipping
+        # trick, and uses these norms to clip the per-example gradients.
+        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
+        y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
+            self, x, y, self._l2_norm_clip, self._layer_registry
+        )
+        grads = gradient_clipping_utils.add_aggregate_noise(
+            self, x, clipped_grads, self._l2_norm_clip, self._noise_multiplier
+        )
+      else:
+        logging.info('Computing gradients using microbatching.')
+        # Computes per-example clipped gradients directly. This is called
+        # if at least one of the layers cannot use the "fast" gradient clipping
+        # algorithm.
+        # TODO(wkong): check if the following is valid with sample weights.
+        _, y = data
+        batch_size = y.shape[0]
+
+        if self._num_microbatches is None:
+          self._num_microbatches = batch_size
+        if batch_size % self._num_microbatches != 0:
+          raise ValueError('Number of_microbatches must divide batch size.')
+
+        def reshape_fn(x):
+          new_shape = (
+              self._num_microbatches,
+              batch_size // self._num_microbatches,
+          ) + x.shape[1:]
+          return tf.reshape(x, new_shape)
+
+        data = tf.nest.map_structure(reshape_fn, data)
+
+        y_pred, _, per_eg_grads = tf.vectorized_map(
+            self._compute_per_example_grads, data
+        )
+
+        y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
+
+        grads = tf.nest.map_structure(
+            self._reduce_per_example_grads, per_eg_grads
+        )
+
+      # Forward the private gradients to the optimizer and return the results.
       self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
       self.compiled_metrics.update_state(y, y_pred)
       return {m.name: m.result() for m in self.metrics}
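To make the microbatching branch above concrete, here is a tiny standalone sketch (illustrative values only) of what `reshape_fn` does to a batch before `tf.vectorized_map` iterates over the leading microbatch dimension:

    import tensorflow as tf

    batch_size, num_microbatches = 8, 4
    x = tf.reshape(tf.range(16, dtype=tf.float32), (batch_size, 2))

    new_shape = (num_microbatches, batch_size // num_microbatches) + x.shape[1:]
    microbatched = tf.reshape(x, new_shape)
    print(microbatched.shape)  # (4, 2, 2): 4 microbatches of 2 examples each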
tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py

@@ -13,10 +13,9 @@
 # limitations under the License.

 from absl.testing import parameterized
-
 import numpy as np
 import tensorflow as tf
-
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
 from tensorflow_privacy.privacy.keras_models import dp_keras_model

@@ -29,6 +28,13 @@ def get_data():
   return data, labels


+def get_layer_registries():
+  # Outputs a list of testable layer registries.
+  # The empty registry {} tests the behavior of the standard approach,
+  # while the other one tests the fast gradient clipping algorithm.
+  return [{}, layer_registry_factories.make_default_layer_registry()]
+
+
 class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):

   def testBaseline(self):
@@ -65,32 +71,35 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     """Tests that clipping norm works."""
     train_data, train_labels = get_data()

-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=0.0,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                1, kernel_initializer='zeros', bias_initializer='zeros')
-        ])
-    learning_rate = 0.01
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
-
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=1)
-
-    model_weights = model.get_weights()
-
-    unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
-    scale = min(1.0, l2_norm_clip / unclipped_gradient)
-    expected_weights = np.array([[90], [120]]) * scale * learning_rate
-    expected_bias = np.array([30]) * scale * learning_rate
-
-    # Check parameters are as expected, taking into account the learning rate.
-    self.assertAllClose(model_weights[0], expected_weights)
-    self.assertAllClose(model_weights[1], expected_bias)
+    for test_reg in get_layer_registries():
+      # Simple linear model returns w * x + b.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=0.0,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  1, kernel_initializer='zeros', bias_initializer='zeros'
+              ),
+          ],
+      )
+      learning_rate = 0.01
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=1)
+
+      model_weights = model.get_weights()
+
+      unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
+      scale = min(1.0, l2_norm_clip / unclipped_gradient)
+      expected_weights = np.array([[90], [120]]) * scale * learning_rate
+      expected_bias = np.array([30]) * scale * learning_rate
+
+      # Check parameters are as expected, taking into account the learning rate.
+      self.assertAllClose(model_weights[0], expected_weights)
+      self.assertAllClose(model_weights[1], expected_bias)

   def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
                                   num_microbatches):
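The assertions above verify the standard DP-SGD clipping rule: a per-example gradient is rescaled by min(1, C / ||g||) so its norm never exceeds the clipping norm C. A small numpy sketch of that rule, using the same gradient values as the test and an illustrative clipping norm of 1.0:

    import numpy as np

    l2_norm_clip = 1.0                      # illustrative value; the test is parameterized
    g = np.array([90.0, 120.0, 30.0])       # flattened gradient (kernel entries and bias)
    scale = min(1.0, l2_norm_clip / np.linalg.norm(g))
    clipped = scale * g
    print(np.linalg.norm(clipped))          # 1.0, i.e. min(||g||, l2_norm_clip)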
@@ -98,9 +107,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     if num_microbatches is None:
       num_microbatches = batch_size

-    preds = np.matmul(data, w)
+    preds = np.matmul(data, np.expand_dims(w, axis=1))

-    grads = 2 * data * (labels - preds)[:, np.newaxis]
+    grads = 2 * data * (preds - labels)
+
     grads = np.reshape(grads,
                        [num_microbatches, batch_size // num_microbatches, -1])
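The rewritten helper computes the per-example gradient of the squared error for a linear model: with prediction p = x.w and label y, d/dw (p - y)^2 = 2 * x * (p - y), which is what `2 * data * (preds - labels)` evaluates. A quick numpy check of that formula on a single example (illustrative numbers):

    import numpy as np

    x = np.array([2.0, 3.0])
    w = np.zeros(2)
    y = 1.0
    pred = x @ w                   # 0.0
    grad = 2 * x * (pred - y)      # [-4., -6.]
    print(grad)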
@@ -123,32 +133,45 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
   def testMicrobatches(self, l2_norm_clip, num_microbatches):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
     w = np.zeros((2))
-    train_labels = np.array([1.0, 3.0, -2.0, -4.0])
+    train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
     learning_rate = 1.0

-    expected_grads = self._compute_expected_gradients(train_data, train_labels,
-                                                      w, l2_norm_clip,
-                                                      num_microbatches)
-    expected_weights = np.squeeze(learning_rate * expected_grads)
-
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
-
-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=0.0,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                1, use_bias=False, kernel_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
-
-    model_weights = np.squeeze(model.get_weights())
-    self.assertAllClose(model_weights, expected_weights)
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
+
+      # Simple linear model returns w * x.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=0.0,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  1, use_bias=False, kernel_initializer='zeros'
+              ),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+
+      model_weights = np.squeeze(model.get_weights())
+
+      effective_num_microbatches = (
+          train_data.shape[0]
+          if model._num_microbatches is None
+          else num_microbatches
+      )
+
+      expected_grads = self._compute_expected_gradients(
+          train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
+      )
+      expected_weights = np.squeeze(-learning_rate * expected_grads)
+
+      self.assertAllClose(model_weights, expected_weights)

   @parameterized.named_parameters(
       ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
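The `effective_num_microbatches` computed above mirrors the constructor change earlier in this commit: when the fast clipping path is selected, `self._num_microbatches` is forced to `None`, which for these assertions behaves like one microbatch per example. A minimal sketch of the same selection logic:

    def effective_num_microbatches(batch_size, num_microbatches):
      # None means the fast path (per-example clipping), i.e. one microbatch per example.
      return batch_size if num_microbatches is None else num_microbatches

    print(effective_num_microbatches(4, None))  # 4
    print(effective_num_microbatches(4, 2))     # 2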
@@ -168,59 +191,81 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):

     # Data is one example of length 1000, set to zero, with label zero.
     train_data = np.zeros((4, 1000))
-    train_labels = np.array([0.0, 0.0, 0.0, 0.0])
+    train_labels = np.array([[0.0], [0.0], [0.0], [0.0]])

     learning_rate = 1.0
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss = tf.keras.losses.MeanSquaredError()
-
-    # Simple linear model returns w * x + b.
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=l2_norm_clip,
-        noise_multiplier=noise_multiplier,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(1000,)),
-            tf.keras.layers.Dense(
-                1, kernel_initializer='zeros', bias_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4)
-
-    model_weights = model.get_weights()
-    measured_std = np.std(model_weights[0])
-    expected_std = l2_norm_clip * noise_multiplier / num_microbatches
-
-    # Test standard deviation is close to l2_norm_clip * noise_multiplier.
-    self.assertNear(measured_std, expected_std, 0.1 * expected_std)
+
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss = tf.keras.losses.MeanSquaredError()
+
+      # Simple linear model returns w * x + b.
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=l2_norm_clip,
+          noise_multiplier=noise_multiplier,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(1000,)),
+              tf.keras.layers.Dense(
+                  1, kernel_initializer='zeros', bias_initializer='zeros'
+              ),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4)
+
+      effective_num_microbatches = (
+          train_data.shape[0]
+          if model._num_microbatches is None
+          else num_microbatches
+      )
+
+      model_weights = model.get_weights()
+      measured_std = np.std(model_weights[0])
+      expected_std = (
+          l2_norm_clip * noise_multiplier / effective_num_microbatches
+      )
+
+      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
+      self.assertNear(measured_std, expected_std, 0.1 * expected_std)

   # Simple check to make sure dimensions are correct when output has
   # dimension > 1.
   @parameterized.named_parameters(
-      ('mb_test None 1', None, 1),
+      ('mb_test None 2', None, 2),
       ('mb_test 1 2', 1, 2),
       ('mb_test 2 2', 2, 2),
       ('mb_test 4 4', 4, 4),
   )
   def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
     train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
-    train_labels = np.array([0, 1, 1, 0])
+    train_labels = np.array([[0], [1], [1], [0]])
     learning_rate = 1.0

-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
-    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-
-    model = dp_keras_model.DPSequential(
-        l2_norm_clip=1.0e9,
-        noise_multiplier=0.0,
-        num_microbatches=num_microbatches,
-        layers=[
-            tf.keras.layers.InputLayer(input_shape=(2,)),
-            tf.keras.layers.Dense(
-                output_dimension, use_bias=False, kernel_initializer='zeros')
-        ])
-    model.compile(optimizer=optimizer, loss=loss_fn)
-    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+    for test_reg, test_nm in zip(
+        get_layer_registries(), [num_microbatches, None]
+    ):
+      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+      model = dp_keras_model.DPSequential(
+          l2_norm_clip=1.0e9,
+          noise_multiplier=0.0,
+          num_microbatches=test_nm,
+          layer_registry=test_reg,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  output_dimension, use_bias=False, kernel_initializer='zeros'
+              ),
+              tf.keras.layers.Dense(1),
+          ],
+      )
+      model.compile(optimizer=optimizer, loss=loss_fn)
+      model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)

   # Checks that calls to earlier API using `use_xla` as a positional argument
   # raise an exception.
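The standard-deviation check above works because the training data and labels are all zeros, so every clipped gradient is zero and the weight update consists purely of the added Gaussian noise: noise with standard deviation l2_norm_clip * noise_multiplier is applied to the aggregated clipped gradients and the result is divided by the number of microbatches (the batch size on the fast path). A small numpy sketch of the quantity being measured, under that reading and using one of the test's parameter tuples (3.0, 2.0, 1):

    import numpy as np

    l2_norm_clip, noise_multiplier, effective_num_microbatches = 3.0, 2.0, 1
    noisy_sum = np.random.normal(0.0, l2_norm_clip * noise_multiplier, size=100000)
    update = noisy_sum / effective_num_microbatches
    print(np.std(update))  # close to l2_norm_clip * noise_multiplier / effective_num_microbatches = 6.0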
@@ -237,8 +282,11 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
           layers=[
               tf.keras.layers.InputLayer(input_shape=(2,)),
               tf.keras.layers.Dense(
-                  2, use_bias=False, kernel_initializer='zeros')
-          ])
+                  2, use_bias=False, kernel_initializer='zeros'
+              ),
+              tf.keras.layers.Dense(1),
+          ],
+      )

 if __name__ == '__main__':
   tf.test.main()