diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index 3b88819..dba90f5 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -2,11 +2,38 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) +py_library( + name = "type_aliases", + srcs = ["type_aliases.py"], + srcs_version = "PY3", +) + +py_library( + name = "common_manip_utils", + srcs = ["common_manip_utils.py"], + srcs_version = "PY3", + deps = [":type_aliases"], +) + +py_library( + name = "common_test_utils", + srcs = ["common_test_utils.py"], + srcs_version = "PY3", + deps = [ + ":clip_grads", + ":layer_registry", + ":type_aliases", + ], +) + py_library( name = "gradient_clipping_utils", srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", - deps = [":layer_registry"], + deps = [ + ":layer_registry", + ":type_aliases", + ], ) py_test( @@ -14,13 +41,21 @@ py_test( srcs = ["gradient_clipping_utils_test.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":gradient_clipping_utils"], + deps = [ + ":gradient_clipping_utils", + ":type_aliases", + ], ) py_library( name = "layer_registry", srcs = ["layer_registry.py"], srcs_version = "PY3", + deps = [ + ":type_aliases", + "//tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions:dense", + "//tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions:embedding", + ], ) py_library( @@ -28,20 +63,22 @@ py_library( srcs = ["clip_grads.py"], srcs_version = "PY3", deps = [ + ":common_manip_utils", ":gradient_clipping_utils", ":layer_registry", + ":type_aliases", ], ) py_test( name = "clip_grads_test", - size = "large", srcs = ["clip_grads_test.py"], python_version = "PY3", - shard_count = 8, srcs_version = "PY3", deps = [ ":clip_grads", + ":common_test_utils", ":layer_registry", + ":type_aliases", ], ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index d1aa42d..5bcf99a 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -21,20 +21,19 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the `compute_gradient_norms()` function). 
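
The new `type_aliases` target that the BUILD rules above depend on is not itself shown in this diff; judging from the aliases deleted from `clip_grads.py`, `gradient_clipping_utils.py`, and `layer_registry.py` further down, it presumably collects definitions along the following lines (a sketch for orientation only, not the actual file contents):

from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union

import tensorflow as tf

# Containers for model inputs/outputs passed through the clipping code.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
InputTensors = PackedTensors
OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
BatchSize = Union[int, tf.Tensor]
LossFn = Callable[..., tf.Tensor]

# Contract of a layer registry function:
#   (layer, args, kwargs, tape) -> (base_vars, outputs, sqr_norm_fn).
SquareNormFunction = Callable[[OutputTensors], tf.Tensor]
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]
RegistryFunction = Callable[
    [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
    RegistryFunctionOutput,
]
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]

# Test-only generators, previously local to clip_grads_test.py.
LayerGenerator = Callable[[int, int], tf.keras.layers.Layer]
ModelGenerator = Callable[
    [LayerGenerator, Union[int, List[int]], int], tf.keras.Model
]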
""" -from typing import Callable, Dict, Iterable, List, Optional, Text, Tuple, Union +from typing import List, Optional, Tuple import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr - -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] -LossFn = Callable[..., tf.Tensor] +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases def get_registry_generator_fn( tape: tf.GradientTape, layer_registry: lr.LayerRegistry, - num_microbatches: Optional[lr.BatchSize] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, ): """Creates the generator function for `compute_gradient_norms()`.""" if layer_registry is None: @@ -70,11 +69,11 @@ def get_registry_generator_fn( def compute_gradient_norms( input_model: tf.keras.Model, layer_registry: lr.LayerRegistry, - x_batch: InputTensor, + x_batch: type_aliases.InputTensors, y_batch: tf.Tensor, weight_batch: Optional[tf.Tensor] = None, - per_example_loss_fn: Optional[LossFn] = None, - num_microbatches: Optional[lr.BatchSize] = None, + per_example_loss_fn: Optional[type_aliases.LossFn] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, trainable_vars: Optional[List[tf.Variable]] = None, ): """Computes the per-example loss gradient norms for given data. @@ -147,7 +146,10 @@ def compute_gradient_norms( ) if num_microbatches is not None: losses = tf.reduce_mean( - lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1 + common_manip_utils.maybe_add_microbatch_axis( + losses, num_microbatches + ), + axis=1, ) summed_loss = tf.reduce_sum(losses) # Unwrap the generator outputs so that the next loop avoids duplicating @@ -212,11 +214,11 @@ def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, l2_norm_clip: float, layer_registry: lr.LayerRegistry, - x_batch: InputTensor, + x_batch: type_aliases.InputTensors, y_batch: tf.Tensor, weight_batch: Optional[tf.Tensor] = None, - num_microbatches: Optional[lr.BatchSize] = None, - clipping_loss: Optional[LossFn] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, + clipping_loss: Optional[type_aliases.LossFn] = None, ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]: """Computes the per-example clipped loss gradient and other useful outputs. @@ -287,7 +289,9 @@ def compute_clipped_gradients_and_outputs( # c is computed based on the gradient of w*l, so that if we scale w*l by c, # the result has bounded per-example gradients. So the loss to optimize is # c*w*l. Here we compute c*w before passing it to the loss. - weight_batch = lr.maybe_add_microbatch_axis(weight_batch, num_microbatches) + weight_batch = common_manip_utils.maybe_add_microbatch_axis( + weight_batch, num_microbatches + ) if num_microbatches is None: clip_weights = clip_weights * weight_batch # shape [num_microbatches] else: @@ -303,8 +307,12 @@ def compute_clipped_gradients_and_outputs( # is not defined in the contract so may not hold, especially for # custom losses. 
y_pred = input_model(x_batch, training=True) - mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches) - mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches) + mb_y_batch = common_manip_utils.maybe_add_microbatch_axis( + y_batch, num_microbatches + ) + mb_y_pred = common_manip_utils.maybe_add_microbatch_axis( + y_pred, num_microbatches + ) clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights) clipped_grads = tape.gradient( clipping_loss_value, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 1483448..2e12a1c 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -12,22 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union +from typing import Any, Dict, Optional, Text, Tuple from absl.testing import parameterized import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry - - -# ============================================================================== -# Type aliases -# ============================================================================== -LayerGenerator = Callable[[int, int], tf.keras.layers.Layer] - -ModelGenerator = Callable[ - [LayerGenerator, Union[int, List[int]], int], tf.keras.Model -] +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases # ============================================================================== @@ -52,7 +44,7 @@ def double_dense_layer_computation( input_kwargs: Dict[Text, Any], tape: tf.GradientTape, num_microbatches: Optional[int], -) -> layer_registry.RegistryFunctionOutput: +) -> type_aliases.RegistryFunctionOutput: """Layer registry function for the custom `DoubleDense` layer class.""" vars1, outputs, sqr_norm_fn1 = layer_registry.dense_layer_computation( layer_instance.dense1, input_args, input_kwargs, tape, num_microbatches @@ -69,304 +61,17 @@ def double_dense_layer_computation( return [vars1, vars2], outputs, sqr_norm_fn -def test_loss_fn( - x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None -) -> tf.Tensor: - # Define a loss function which is unlikely to be coincidently defined. 
- if weights is None: - weights = 1.0 - loss = 3.14 * tf.reduce_sum( - tf.cast(weights, tf.float32) * tf.square(x - y), axis=1 - ) - return loss - - -def compute_true_gradient_norms( - input_model: tf.keras.Model, - x_batch: tf.Tensor, - y_batch: tf.Tensor, - weight_batch: Optional[tf.Tensor], - per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], - num_microbatches: Optional[int], - trainable_vars: Optional[tf.Variable] = None, -) -> layer_registry.OutputTensor: - """Computes the real gradient norms for an input `(model, x, y)`.""" - if per_example_loss_fn is None: - loss_config = input_model.loss.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - per_example_loss_fn = input_model.loss.from_config(loss_config) - with tf.GradientTape(persistent=True) as tape: - y_pred = input_model(x_batch) - loss = per_example_loss_fn(y_batch, y_pred, weight_batch) - if num_microbatches is not None: - loss = tf.reduce_mean( - tf.reshape( - loss, - tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0), - ), - axis=1, - ) - if isinstance(loss, tf.RaggedTensor): - loss = loss.to_tensor() - sqr_norms = [] - trainable_vars = trainable_vars or input_model.trainable_variables - for var in trainable_vars: - jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) - reduction_axes = tf.range(1, len(jacobian.shape)) - sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) - sqr_norms.append(sqr_norm) - sqr_norm_tsr = tf.stack(sqr_norms, axis=1) - return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) - - -def get_computed_and_true_norms( - model_generator: ModelGenerator, - layer_generator: LayerGenerator, - input_dims: Union[int, List[int]], - output_dim: int, - per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], - num_microbatches: Optional[int], - is_eager: bool, - x_batch: tf.Tensor, - weight_batch: Optional[tf.Tensor] = None, - rng_seed: int = 777, - registry: layer_registry.LayerRegistry = None, - partial: bool = False, -) -> Tuple[tf.Tensor, tf.Tensor]: - """Obtains the true and computed gradient norms for a model and batch input. - - Helpful testing wrapper function used to avoid code duplication. - - Args: - model_generator: A function which takes in three arguments: - `layer_generator`, `idim`, and `odim`. Returns a `tf.keras.Model` that - accepts input tensors of dimension `idim` and returns output tensors of - dimension `odim`. Layers of the model are based on the `layer_generator` - (see below for its description). - layer_generator: A function which takes in two arguments: `idim` and `odim`. - Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension - `idim` and returns output tensors of dimension `odim`. - input_dims: The input dimension(s) of the test `tf.keras.Model` instance. - output_dim: The output dimension of the test `tf.keras.Model` instance. - per_example_loss_fn: If not None, used as vectorized per example loss - function. - num_microbatches: The number of microbatches. None or an integer. - is_eager: whether the model should be run eagerly. - x_batch: inputs to be tested. - weight_batch: optional weights passed to the loss. - rng_seed: used as a seed for random initialization. - registry: required for fast clipping. - partial: Whether to compute the gradient norm with respect to a partial set - of varibles. If True, only consider the variables in the first layer. - - Returns: - A `tuple` `(computed_norm, true_norms)`. 
The first element contains the - clipped gradient norms that are generated by - `clip_grads.compute_gradient_norms()` under the setting given by the given - model and layer generators. The second element contains the true clipped - gradient norms under the aforementioned setting. - """ - model = model_generator(layer_generator, input_dims, output_dim) - model.compile( - optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), - loss=tf.keras.losses.MeanSquaredError( - reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE - ), - run_eagerly=is_eager, - ) - trainable_vars = None - if partial: - # Gets the first layer with variables. - for l in model.layers: - trainable_vars = l.trainable_variables - if trainable_vars: - break - y_pred = model(x_batch) - y_batch = tf.ones_like(y_pred) - tf.keras.utils.set_random_seed(rng_seed) - computed_norms = clip_grads.compute_gradient_norms( - input_model=model, - x_batch=x_batch, - y_batch=y_batch, - weight_batch=weight_batch, - layer_registry=registry, - per_example_loss_fn=per_example_loss_fn, - num_microbatches=num_microbatches, - trainable_vars=trainable_vars, - ) - tf.keras.utils.set_random_seed(rng_seed) - true_norms = compute_true_gradient_norms( - model, - x_batch, - y_batch, - weight_batch, - per_example_loss_fn, - num_microbatches, - trainable_vars=trainable_vars, - ) - return (computed_norms, true_norms) - - -# ============================================================================== -# Model generators. -# ============================================================================== -def make_two_layer_sequential_model(layer_generator, input_dim, output_dim): - """Creates a 2-layer sequential model.""" - model = tf.keras.Sequential() - model.add(tf.keras.Input(shape=(input_dim,))) - model.add(layer_generator(input_dim, output_dim)) - model.add(tf.keras.layers.Dense(1)) - return model - - -def make_three_layer_sequential_model(layer_generator, input_dim, output_dim): - """Creates a 3-layer sequential model.""" - model = tf.keras.Sequential() - model.add(tf.keras.Input(shape=(input_dim,))) - layer1 = layer_generator(input_dim, output_dim) - model.add(layer1) - if isinstance(layer1, tf.keras.layers.Embedding): - # Having multiple consecutive embedding layers does not make sense since - # embedding layers only map integers to real-valued vectors. 
- model.add(tf.keras.layers.Dense(output_dim)) - else: - model.add(layer_generator(output_dim, output_dim)) - model.add(tf.keras.layers.Dense(1)) - return model - - -def make_two_layer_functional_model(layer_generator, input_dim, output_dim): - """Creates a 2-layer 1-input functional model with a pre-output square op.""" - inputs = tf.keras.Input(shape=(input_dim,)) - layer1 = layer_generator(input_dim, output_dim) - temp1 = layer1(inputs) - temp2 = tf.square(temp1) - outputs = tf.keras.layers.Dense(1)(temp2) - return tf.keras.Model(inputs=inputs, outputs=outputs) - - -def make_two_tower_model(layer_generator, input_dim, output_dim): - """Creates a 2-layer 2-input functional model.""" - inputs1 = tf.keras.Input(shape=(input_dim,)) - layer1 = layer_generator(input_dim, output_dim) - temp1 = layer1(inputs1) - inputs2 = tf.keras.Input(shape=(input_dim,)) - layer2 = layer_generator(input_dim, output_dim) - temp2 = layer2(inputs2) - temp3 = tf.add(temp1, temp2) - outputs = tf.keras.layers.Dense(1)(temp3) - return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs) - - -def make_bow_model(layer_generator, input_dims, output_dim): - del layer_generator - inputs = tf.keras.Input(shape=input_dims) - # For the Embedding layer, input_dim is the vocabulary size. This should - # be distinguished from the input_dim argument, which is the number of ids - # in eache example. - emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim) - feature_embs = emb_layer(inputs) - reduction_axes = tf.range(1, len(feature_embs.shape)) - example_embs = tf.expand_dims( - tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 - ) - return tf.keras.Model(inputs=inputs, outputs=example_embs) - - -def make_dense_bow_model(layer_generator, input_dims, output_dim): - del layer_generator - inputs = tf.keras.Input(shape=input_dims) - # For the Embedding layer, input_dim is the vocabulary size. This should - # be distinguished from the input_dim argument, which is the number of ids - # in eache example. - cardinality = 10 - emb_layer = tf.keras.layers.Embedding( - input_dim=cardinality, output_dim=output_dim - ) - feature_embs = emb_layer(inputs) - reduction_axes = tf.range(1, len(feature_embs.shape)) - example_embs = tf.expand_dims( - tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 - ) - outputs = tf.keras.layers.Dense(1)(example_embs) - return tf.keras.Model(inputs=inputs, outputs=outputs) - - -def make_weighted_bow_model(layer_generator, input_dims, output_dim): - # NOTE: This model only accepts dense input tensors. - del layer_generator - inputs = tf.keras.Input(shape=input_dims) - # For the Embedding layer, input_dim is the vocabulary size. This should - # be distinguished from the input_dim argument, which is the number of ids - # in eache example. - cardinality = 10 - emb_layer = tf.keras.layers.Embedding( - input_dim=cardinality, output_dim=output_dim - ) - feature_embs = emb_layer(inputs) - feature_weights = tf.random.uniform(tf.shape(feature_embs)) - weighted_embs = feature_embs * feature_weights - reduction_axes = tf.range(1, len(weighted_embs.shape)) - example_embs = tf.expand_dims( - tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1 - ) - outputs = tf.keras.layers.Dense(1)(example_embs) - return tf.keras.Model(inputs=inputs, outputs=outputs) - - -# ============================================================================== -# Factory functions. 
-# ============================================================================== -def get_nd_test_batches(n: int): - """Returns a list of input batches of dimension n.""" - # The first two batches have a single element, the last batch has 2 elements. - x0 = tf.zeros([1, n], dtype=tf.float64) - x1 = tf.constant([range(n)], dtype=tf.float64) - x2 = tf.concat([x0, x1], axis=0) - w0 = tf.constant([1], dtype=tf.float64) - w1 = tf.constant([2], dtype=tf.float64) - w2 = tf.constant([0.5, 0.5], dtype=tf.float64) - return [x0, x1, x2], [w0, w1, w2] - - -def get_dense_layer_generators(): - def sigmoid_dense_layer(b): - return tf.keras.layers.Dense(b, activation='sigmoid') - - return { - 'pure_dense': lambda a, b: tf.keras.layers.Dense(b), - 'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b), - } - - -def get_dense_model_generators(): - return { - 'seq1': make_two_layer_sequential_model, - 'seq2': make_three_layer_sequential_model, - 'func1': make_two_layer_functional_model, - 'tower1': make_two_tower_model, - } - - -def get_embedding_model_generators(): - return { - 'bow1': make_bow_model, - 'bow2': make_dense_bow_model, - 'weighted_bow1': make_weighted_bow_model, - } - - # ============================================================================== # Main tests. # ============================================================================== -class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase): +class DirectWeightsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( input_dim=[1, 2], clip_value=[1e-6, 0.5, 1.0, 2.0, 10.0, 1e6] ) def test_clip_weights(self, input_dim, clip_value): tol = 1e-6 - ts, _ = get_nd_test_batches(input_dim) + ts, _ = common_test_utils.get_nd_test_batches(input_dim) for t in ts: weights = clip_grads.compute_clip_weights(clip_value, t) self.assertAllLessEqual(t * weights, clip_value + tol) @@ -375,136 +80,12 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase): self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3))) -class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - model_name=list(get_dense_model_generators().keys()), - layer_name=list(get_dense_layer_generators().keys()), - input_dim=[4], - output_dim=[2], - per_example_loss_fn=[None, test_loss_fn], - num_microbatches=[None, 1, 2], - is_eager=[True, False], - partial=[True, False], - weighted=[True, False], - ) - def test_gradient_norms_on_various_models( - self, - model_name, - layer_name, - input_dim, - output_dim, - per_example_loss_fn, - num_microbatches, - is_eager, - partial, - weighted, - ): - model_generator = get_dense_model_generators()[model_name] - layer_generator = get_dense_layer_generators()[layer_name] - x_batches, weight_batches = get_nd_test_batches(input_dim) - default_registry = layer_registry.make_default_layer_registry() - for x_batch, weight_batch in zip(x_batches, weight_batches): - batch_size = x_batch.shape[0] - if num_microbatches is not None and batch_size % num_microbatches != 0: - continue - (computed_norms, true_norms) = get_computed_and_true_norms( - model_generator, - layer_generator, - input_dim, - output_dim, - per_example_loss_fn, - num_microbatches, - is_eager, - x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch, - weight_batch=weight_batch if weighted else None, - registry=default_registry, - partial=partial, - ) - expected_size = num_microbatches or batch_size - self.assertEqual(computed_norms.shape[0], expected_size) - 
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) - - -class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase): - - # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment - # supports them for embeddings. - @parameterized.product( - x_batch=[ - # 2D inputs. - tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32), - tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32), - tf.ragged.constant( - [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32 - ), - tf.ragged.constant( - [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]], - dtype=tf.int32, - ), - # 3D inputs. - tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32), - tf.convert_to_tensor( - [[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32 - ), - tf.ragged.constant( - [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]], - dtype=tf.int32, - ), - tf.ragged.constant( - [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]], - dtype=tf.int32, - ), - ], - model_name=list(get_embedding_model_generators().keys()), - output_dim=[2], - per_example_loss_fn=[None, test_loss_fn], - num_microbatches=[None, 2], - is_eager=[True, False], - partial=[True, False], - ) - def test_gradient_norms_on_various_models( - self, - x_batch, - model_name, - output_dim, - per_example_loss_fn, - num_microbatches, - is_eager, - partial, - ): - batch_size = x_batch.shape[0] - # The following are invalid test combinations, and are skipped. - if ( - num_microbatches is not None and batch_size % num_microbatches != 0 - ) or ( - model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor) - ): - return - default_registry = layer_registry.make_default_layer_registry() - model_generator = get_embedding_model_generators()[model_name] - (computed_norms, true_norms) = get_computed_and_true_norms( - model_generator=model_generator, - layer_generator=None, - input_dims=x_batch.shape[1:], - output_dim=output_dim, - per_example_loss_fn=per_example_loss_fn, - num_microbatches=num_microbatches, - is_eager=is_eager, - x_batch=x_batch, - registry=default_registry, - partial=partial, - ) - self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) - self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) - - -class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): +class CustomLayerTest(tf.test.TestCase, parameterized.TestCase): @parameterized.product( input_dim=[3], output_dim=[2], - per_example_loss_fn=[None, test_loss_fn], + per_example_loss_fn=[None, common_test_utils.test_loss_fn], num_microbatches=[None, 2], is_eager=[True, False], partial=[True, False], @@ -522,29 +103,31 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): ): registry = layer_registry.make_default_layer_registry() registry.insert(DoubleDense, double_dense_layer_computation) - x_batches, weight_batches = get_nd_test_batches(input_dim) + x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim) for x_batch, weight_batch in zip(x_batches, weight_batches): batch_size = x_batch.shape[0] if num_microbatches is not None and batch_size % num_microbatches != 0: continue - (computed_norms, true_norms) = get_computed_and_true_norms( - model_generator=make_two_layer_sequential_model, - layer_generator=lambda a, b: DoubleDense(b), - input_dims=input_dim, - output_dim=output_dim, - per_example_loss_fn=per_example_loss_fn, - num_microbatches=num_microbatches, - is_eager=is_eager, - x_batch=x_batch, - 
weight_batch=weight_batch if weighted else None, - registry=registry, - partial=partial, + (computed_norms, true_norms) = ( + common_test_utils.get_computed_and_true_norms( + model_generator=common_test_utils.make_two_layer_sequential_model, + layer_generator=lambda a, b: DoubleDense(b), + input_dims=input_dim, + output_dim=output_dim, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + is_eager=is_eager, + x_batch=x_batch, + weight_batch=weight_batch if weighted else None, + registry=registry, + partial=partial, + ) ) self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) -class ClipGradsComputeClippedGradsAndOutputsTest( +class ComputeClippedGradsAndOutputsTest( tf.test.TestCase, parameterized.TestCase ): @@ -553,7 +136,7 @@ class ClipGradsComputeClippedGradsAndOutputsTest( dense_generator = lambda a, b: tf.keras.layers.Dense(b) self._input_dim = 2 self._output_dim = 3 - self._model = make_two_layer_sequential_model( + self._model = common_test_utils.make_two_layer_sequential_model( dense_generator, self._input_dim, self._output_dim ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py new file mode 100644 index 0000000..db05f0a --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/common_manip_utils.py @@ -0,0 +1,44 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A collection of common utility functions for tensor/data manipulation.""" + +from typing import Optional + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases + + +def maybe_add_microbatch_axis( + x: tf.Tensor, + num_microbatches: Optional[type_aliases.BatchSize], +) -> tf.Tensor: + """Adds the microbatch axis. + + Args: + x: the input tensor. + num_microbatches: If None, x is returned unchanged. Otherwise, must divide + the batch size. + + Returns: + The input tensor x, reshaped from [batch_size, ...] to + [num_microbatches, batch_size / num_microbatches, ...]. + """ + if num_microbatches is None: + return x + with tf.control_dependencies( + [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] + ): + return tf.reshape( + x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0) + ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py new file mode 100644 index 0000000..c680407 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/common_test_utils.py @@ -0,0 +1,284 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A collection of common utility functions for unit testing.""" + +from typing import Callable, List, Optional, Tuple, Union + +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases + + +# ============================================================================== +# Helper functions +# ============================================================================== +def get_nd_test_batches(n: int): + """Returns a list of input batches of dimension n.""" + # The first two batches have a single element, the last batch has 2 elements. + x0 = tf.zeros([1, n], dtype=tf.float64) + x1 = tf.constant([range(n)], dtype=tf.float64) + x2 = tf.concat([x0, x1], axis=0) + w0 = tf.constant([1], dtype=tf.float64) + w1 = tf.constant([2], dtype=tf.float64) + w2 = tf.constant([0.5, 0.5], dtype=tf.float64) + return [x0, x1, x2], [w0, w1, w2] + + +def test_loss_fn( + x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None +) -> tf.Tensor: + # Define a loss function which is unlikely to be coincidently defined. + if weights is None: + weights = 1.0 + loss = 3.14 * tf.reduce_sum( + tf.cast(weights, tf.float32) * tf.square(x - y), axis=1 + ) + return loss + + +def compute_true_gradient_norms( + input_model: tf.keras.Model, + x_batch: tf.Tensor, + y_batch: tf.Tensor, + weight_batch: Optional[tf.Tensor], + per_example_loss_fn: Optional[type_aliases.LossFn], + num_microbatches: Optional[int], + trainable_vars: Optional[tf.Variable] = None, +) -> type_aliases.OutputTensors: + """Computes the real gradient norms for an input `(model, x, y)`.""" + if per_example_loss_fn is None: + loss_config = input_model.loss.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + per_example_loss_fn = input_model.loss.from_config(loss_config) + with tf.GradientTape(persistent=True) as tape: + y_pred = input_model(x_batch) + loss = per_example_loss_fn(y_batch, y_pred, weight_batch) + if num_microbatches is not None: + loss = tf.reduce_mean( + tf.reshape( + loss, + tf.concat([[num_microbatches, -1], tf.shape(loss)[1:]], axis=0), + ), + axis=1, + ) + if isinstance(loss, tf.RaggedTensor): + loss = loss.to_tensor() + sqr_norms = [] + trainable_vars = trainable_vars or input_model.trainable_variables + for var in trainable_vars: + jacobian = tape.jacobian(loss, var, experimental_use_pfor=False) + reduction_axes = tf.range(1, len(jacobian.shape)) + sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes) + sqr_norms.append(sqr_norm) + sqr_norm_tsr = tf.stack(sqr_norms, axis=1) + return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + + +def get_computed_and_true_norms( + model_generator: type_aliases.ModelGenerator, + layer_generator: type_aliases.LayerGenerator, + input_dims: Union[int, List[int]], + output_dim: int, + per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], + num_microbatches: Optional[int], + is_eager: bool, + x_batch: tf.Tensor, + weight_batch: 
Optional[tf.Tensor] = None, + rng_seed: int = 777, + registry: layer_registry.LayerRegistry = None, + partial: bool = False, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Obtains the true and computed gradient norms for a model and batch input. + + Helpful testing wrapper function used to avoid code duplication. + + Args: + model_generator: A function which takes in three arguments: + `layer_generator`, `idim`, and `odim`. Returns a `tf.keras.Model` that + accepts input tensors of dimension `idim` and returns output tensors of + dimension `odim`. Layers of the model are based on the `layer_generator` + (see below for its description). + layer_generator: A function which takes in two arguments: `idim` and `odim`. + Returns a `tf.keras.layers.Layer` that accepts input tensors of dimension + `idim` and returns output tensors of dimension `odim`. + input_dims: The input dimension(s) of the test `tf.keras.Model` instance. + output_dim: The output dimension of the test `tf.keras.Model` instance. + per_example_loss_fn: If not None, used as vectorized per example loss + function. + num_microbatches: The number of microbatches. None or an integer. + is_eager: whether the model should be run eagerly. + x_batch: inputs to be tested. + weight_batch: optional weights passed to the loss. + rng_seed: used as a seed for random initialization. + registry: required for fast clipping. + partial: Whether to compute the gradient norm with respect to a partial set + of varibles. If True, only consider the variables in the first layer. + + Returns: + A `tuple` `(computed_norm, true_norms)`. The first element contains the + clipped gradient norms that are generated by + `clip_grads.compute_gradient_norms()` under the setting given by the given + model and layer generators. The second element contains the true clipped + gradient norms under the aforementioned setting. + """ + model = model_generator(layer_generator, input_dims, output_dim) + model.compile( + optimizer=tf.keras.optimizers.SGD(learning_rate=1.0), + loss=tf.keras.losses.MeanSquaredError( + reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE + ), + run_eagerly=is_eager, + ) + trainable_vars = None + if partial: + # Gets the first layer with variables. + for l in model.layers: + trainable_vars = l.trainable_variables + if trainable_vars: + break + y_pred = model(x_batch) + y_batch = tf.ones_like(y_pred) + tf.keras.utils.set_random_seed(rng_seed) + computed_norms = clip_grads.compute_gradient_norms( + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + weight_batch=weight_batch, + layer_registry=registry, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + trainable_vars=trainable_vars, + ) + tf.keras.utils.set_random_seed(rng_seed) + true_norms = compute_true_gradient_norms( + model, + x_batch, + y_batch, + weight_batch, + per_example_loss_fn, + num_microbatches, + trainable_vars=trainable_vars, + ) + return (computed_norms, true_norms) + + +# ============================================================================== +# Model generators. 
+# ============================================================================== +def make_two_layer_sequential_model(layer_generator, input_dim, output_dim): + """Creates a 2-layer sequential model.""" + model = tf.keras.Sequential() + model.add(tf.keras.Input(shape=(input_dim,))) + model.add(layer_generator(input_dim, output_dim)) + model.add(tf.keras.layers.Dense(1)) + return model + + +def make_three_layer_sequential_model(layer_generator, input_dim, output_dim): + """Creates a 3-layer sequential model.""" + model = tf.keras.Sequential() + model.add(tf.keras.Input(shape=(input_dim,))) + layer1 = layer_generator(input_dim, output_dim) + model.add(layer1) + if isinstance(layer1, tf.keras.layers.Embedding): + # Having multiple consecutive embedding layers does not make sense since + # embedding layers only map integers to real-valued vectors. + model.add(tf.keras.layers.Dense(output_dim)) + else: + model.add(layer_generator(output_dim, output_dim)) + model.add(tf.keras.layers.Dense(1)) + return model + + +def make_two_layer_functional_model(layer_generator, input_dim, output_dim): + """Creates a 2-layer 1-input functional model with a pre-output square op.""" + inputs = tf.keras.Input(shape=(input_dim,)) + layer1 = layer_generator(input_dim, output_dim) + temp1 = layer1(inputs) + temp2 = tf.square(temp1) + outputs = tf.keras.layers.Dense(1)(temp2) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + +def make_two_tower_model(layer_generator, input_dim, output_dim): + """Creates a 2-layer 2-input functional model.""" + inputs1 = tf.keras.Input(shape=(input_dim,)) + layer1 = layer_generator(input_dim, output_dim) + temp1 = layer1(inputs1) + inputs2 = tf.keras.Input(shape=(input_dim,)) + layer2 = layer_generator(input_dim, output_dim) + temp2 = layer2(inputs2) + temp3 = tf.add(temp1, temp2) + outputs = tf.keras.layers.Dense(1)(temp3) + return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs) + + +def make_bow_model(layer_generator, input_dims, output_dim): + """Creates a simple embedding bow model.""" + del layer_generator + inputs = tf.keras.Input(shape=input_dims) + # For the Embedding layer, input_dim is the vocabulary size. This should + # be distinguished from the input_dim argument, which is the number of ids + # in eache example. + emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim) + feature_embs = emb_layer(inputs) + reduction_axes = tf.range(1, len(feature_embs.shape)) + example_embs = tf.expand_dims( + tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 + ) + return tf.keras.Model(inputs=inputs, outputs=example_embs) + + +def make_dense_bow_model(layer_generator, input_dims, output_dim): + """Creates an embedding bow model with a `Dense` layer.""" + del layer_generator + inputs = tf.keras.Input(shape=input_dims) + # For the Embedding layer, input_dim is the vocabulary size. This should + # be distinguished from the input_dim argument, which is the number of ids + # in eache example. 
+ cardinality = 10 + emb_layer = tf.keras.layers.Embedding( + input_dim=cardinality, output_dim=output_dim + ) + feature_embs = emb_layer(inputs) + reduction_axes = tf.range(1, len(feature_embs.shape)) + example_embs = tf.expand_dims( + tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1 + ) + outputs = tf.keras.layers.Dense(1)(example_embs) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + +def make_weighted_bow_model(layer_generator, input_dims, output_dim): + """Creates a weighted embedding bow model.""" + # NOTE: This model only accepts dense input tensors. + del layer_generator + inputs = tf.keras.Input(shape=input_dims) + # For the Embedding layer, input_dim is the vocabulary size. This should + # be distinguished from the input_dim argument, which is the number of ids + # in eache example. + cardinality = 10 + emb_layer = tf.keras.layers.Embedding( + input_dim=cardinality, output_dim=output_dim + ) + feature_embs = emb_layer(inputs) + feature_weights = tf.random.uniform(tf.shape(feature_embs)) + weighted_embs = feature_embs * feature_weights + reduction_axes = tf.range(1, len(weighted_embs.shape)) + example_embs = tf.expand_dims( + tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1 + ) + outputs = tf.keras.layers.Dense(1)(example_embs) + return tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 64046e7..6564f9d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,16 +13,12 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union +from typing import Any, List, Optional, Set, Tuple from absl import logging import tensorflow as tf - from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr - -PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] - -GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]] +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases def has_internal_compute_graph(input_object: Any): @@ -38,9 +34,9 @@ def has_internal_compute_graph(input_object: Any): def model_forward_pass( input_model: tf.keras.Model, - inputs: PackedTensors, - generator_fn: GeneratorFunction = None, -) -> Tuple[PackedTensors, List[Any]]: + inputs: type_aliases.PackedTensors, + generator_fn: type_aliases.GeneratorFunction = None, +) -> Tuple[type_aliases.PackedTensors, List[Any]]: """Does a forward pass of a model and returns useful intermediates. NOTE: the graph traversal algorithm is an adaptation of the logic in the @@ -192,7 +188,7 @@ def add_aggregate_noise( def generate_model_outputs_using_core_keras_layers( input_model: tf.keras.Model, custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic -) -> PackedTensors: +) -> type_aliases.PackedTensors: """Returns the model outputs generated by only core Keras layers. 
Args: diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index 838b4f7..47ab328 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -49,27 +49,12 @@ aggregation at the microbatch level. # The detailed algorithm can be found in go/fast-dpsgd-mb. # copybara.strip_end -from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Type, Union +from typing import Type + import tensorflow as tf - - -# ============================================================================== -# Type aliases -# ============================================================================== -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] - -OutputTensor = Union[tf.Tensor, Iterable[tf.Tensor]] - -BatchSize = Union[int, tf.Tensor] - -SquareNormFunction = Callable[[OutputTensor], tf.Tensor] - -RegistryFunctionOutput = Tuple[Any, OutputTensor, SquareNormFunction] - -RegistryFunction = Callable[ - [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape], - RegistryFunctionOutput, -] +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions.dense import dense_layer_computation +from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions.embedding import embedding_layer_computation # ============================================================================== @@ -87,14 +72,16 @@ class LayerRegistry: """Checks if a layer instance's class is in the registry.""" return hash(layer_instance.__class__) in self._registry - def lookup(self, layer_instance: tf.keras.layers.Layer) -> RegistryFunction: + def lookup( + self, layer_instance: tf.keras.layers.Layer + ) -> type_aliases.RegistryFunction: """Returns the layer registry function for a given layer instance.""" return self._registry[hash(layer_instance.__class__)] def insert( self, layer_class: Type[tf.keras.layers.Layer], - layer_registry_function: RegistryFunction, + layer_registry_function: type_aliases.RegistryFunction, ): """Inserts a layer registry function into the internal dictionaries.""" layer_key = hash(layer_class) @@ -102,222 +89,6 @@ class LayerRegistry: self._registry[layer_key] = layer_registry_function -# ============================================================================== -# Utilities -# ============================================================================== -def maybe_add_microbatch_axis( - x: tf.Tensor, - num_microbatches: Optional[BatchSize], -) -> tf.Tensor: - """Adds the microbatch axis. - - Args: - x: the input tensor. - num_microbatches: If None, x is returned unchanged. Otherwise, must divide - the batch size. - - Returns: - The input tensor x, reshaped from [batch_size, ...] to - [num_microbatches, batch_size / num_microbatches, ...]. 
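
Concretely, the reshape described in this docstring (now provided by `common_manip_utils.maybe_add_microbatch_axis`) behaves as in the following sketch, with hypothetical shapes:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils

x = tf.ones([6, 3])                                    # batch_size = 6
common_manip_utils.maybe_add_microbatch_axis(x, 2)     # -> shape [2, 3, 3]
common_manip_utils.maybe_add_microbatch_axis(x, None)  # -> x, unchanged
# A num_microbatches that does not divide the batch size (e.g. 4 here) raises,
# via the assert_equal guard in the function body.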
- """ - if num_microbatches is None: - return x - with tf.control_dependencies( - [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] - ): - return tf.reshape( - x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0) - ) - - -# ============================================================================== -# Supported Keras layers -# ============================================================================== -def dense_layer_computation( - layer_instance: tf.keras.layers.Dense, - input_args: Tuple[Any, ...], - input_kwargs: Dict[Text, Any], - tape: tf.GradientTape, - num_microbatches: Optional[tf.Tensor] = None, -) -> RegistryFunctionOutput: - """Registry function for `tf.keras.layers.Dense`. - - The logic for this computation is based on the following paper: - https://arxiv.org/abs/1510.01799 - - For the sake of efficiency, we fuse the variables and square grad norms - for the kernel weights and bias vector together. - - Args: - layer_instance: A `tf.keras.layers.Dense` instance. - input_args: A `tuple` containing the first part of `layer_instance` input. - Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return - a valid output. - input_kwargs: A `tuple` containing the second part of `layer_instance` - input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should - return a valid output. - tape: A `tf.GradientTape` instance that will be used to watch the output - `base_vars`. - num_microbatches: An optional numeric value or scalar `tf.Tensor` for - indicating whether and how the losses are grouped into microbatches. If - not None, num_microbatches must divide the batch size. - - Returns: - A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the - intermediate Tensor used in the chain-rule / "fast" clipping trick, - `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is - a function that takes one input, a `tf.Tensor` that represents the output - of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a - `tf.GradientTape` instance that records the dense layer computation and - `summed_loss` is the sum of the per-example losses of the underlying model. - This function then returns the per-example squared L2 gradient norms of the - trainable variables in `layer_instance`. These squared norms should be a 1D - `tf.Tensor` of length `batch_size`. - """ - if input_kwargs: - raise ValueError("Dense layer calls should not receive kwargs.") - del input_kwargs # Unused in dense layer calls. 
- if len(input_args) != 1: - raise ValueError("Only layer inputs of length 1 are permitted.") - orig_activation = layer_instance.activation - layer_instance.activation = None - base_vars = layer_instance(*input_args) - tape.watch(base_vars) - layer_instance.activation = orig_activation - outputs = orig_activation(base_vars) if orig_activation else base_vars - - def sqr_norm_fn(base_vars_grads): - - def _compute_gramian(x): - if num_microbatches is not None: - x_microbatched = maybe_add_microbatch_axis(x, num_microbatches) - return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) - else: - # Special handling for better efficiency - return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x))) - - inputs_gram = _compute_gramian(*input_args) - base_vars_grads_gram = _compute_gramian(base_vars_grads) - if layer_instance.use_bias: - # Adding a bias term is equivalent to a layer with no bias term and which - # adds an additional variable to the layer input that only takes a - # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum - # of the squared values of the inputs. - inputs_gram += 1.0 - return tf.reduce_sum( - inputs_gram * base_vars_grads_gram, - axis=tf.range(1, tf.rank(inputs_gram)), - ) - - return base_vars, outputs, sqr_norm_fn - - -def embedding_layer_computation( - layer_instance: tf.keras.layers.Embedding, - input_args: Tuple[Any, ...], - input_kwargs: Dict[Text, Any], - tape: tf.GradientTape, - num_microbatches: Optional[tf.Tensor] = None, -) -> RegistryFunctionOutput: - """Registry function for `tf.keras.layers.Embedding`. - - The logic of this computation is based on the `tf.keras.layers.Dense` - computation and the fact that an embedding layer is just a dense layer - with no activation function and an output vector of the form X*W for input - X, where the i-th row of W is the i-th embedding vector and the j-th row of - X is a one-hot vector representing the input of example j. - - Args: - layer_instance: A `tf.keras.layers.Embedding` instance. - input_args: See `dense_layer_computation()`. - input_kwargs: See `dense_layer_computation()`. - tape: See `dense_layer_computation()`. - num_microbatches: See `dense_layer_computation()`. - - Returns: - See `dense_layer_computation()`. - """ - if input_kwargs: - raise ValueError("Embedding layer calls should not receive kwargs.") - del input_kwargs # Unused in embedding layer calls. - if len(input_args) != 1: - raise ValueError("Only layer inputs of length 1 are permitted.") - if hasattr(layer_instance, "sparse"): # for backwards compatibility - if layer_instance.sparse: - raise NotImplementedError("Sparse output tensors are not supported.") - if isinstance(input_args[0], tf.SparseTensor): - raise NotImplementedError("Sparse input tensors are not supported.") - - # Disable experimental features. - if hasattr(layer_instance, "_use_one_hot_matmul"): - if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access - raise NotImplementedError( - "The experimental embedding feature" - "'_use_one_hot_matmul' is not supported." - ) - input_ids = tf.cast(*input_args, tf.int32) - base_vars = layer_instance.trainable_variables[0] - tape.watch(base_vars) - outputs = tf.nn.embedding_lookup(base_vars, input_ids) - - def sqr_norm_fn(base_vars_grads): - # Get a 1D tensor of the row indices. 
- nrows = tf.shape(input_ids)[0] - if isinstance(input_ids, tf.RaggedTensor): - row_indices = tf.expand_dims( - input_ids.merge_dims(1, -1).value_rowids(), axis=-1 - ) - elif isinstance(input_ids, tf.Tensor): - ncols = tf.reduce_prod(tf.shape(input_ids)[1:]) - repeats = tf.repeat(ncols, nrows) - row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1]) - else: - raise NotImplementedError( - "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ - ) - row_indices = tf.cast(row_indices, tf.int32) - if num_microbatches is not None: - microbatch_size = tf.cast(nrows / num_microbatches, tf.int32) - nrows = num_microbatches - row_indices = tf.cast( - tf.math.floordiv(row_indices, microbatch_size), tf.int32 - ) - # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` - # call. The sum is reduced by the repeated embedding indices and batch - # index. It is adapted from the logic in: - # tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices - if not isinstance(base_vars_grads, tf.IndexedSlices): - raise NotImplementedError( - "Cannot parse embedding gradients of type: %s" - % base_vars_grads.__class__.__name__ - ) - slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1) - paired_indices = tf.concat( - [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)], - axis=1, - ) - (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2( - x=paired_indices, axis=[0] - ) - unique_batch_ids = unique_paired_indices[:, 0] - summed_gradients = tf.math.unsorted_segment_sum( - base_vars_grads.values, - new_index_positions, - tf.shape(unique_paired_indices)[0], - ) - # Compute the squared gradient norms at the per-example level. - sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1) - summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0]) - return tf.sparse.segment_sum( - sqr_gradient_sum, - summed_data_range, - tf.sort(unique_batch_ids), - num_segments=nrows, - ) # fill in empty inputs - - return base_vars, outputs, sqr_norm_fn - - # ============================================================================== # Main factory methods # ============================================================================== diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD new file mode 100644 index 0000000..a377287 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD @@ -0,0 +1,50 @@ +load("@rules_python//python:defs.bzl", "py_library", "py_test") + +package( + default_visibility = ["//visibility:public"], +) + +py_library( + name = "dense", + srcs = ["dense.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases", + ], +) + +py_test( + name = "dense_test", + size = "large", + srcs = ["dense_test.py"], + python_version = "PY3", + shard_count = 8, + srcs_version = "PY3", + deps = [ + "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + ], +) + +py_library( + name = "embedding", + srcs = ["embedding.py"], + srcs_version = "PY3", + deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"], +) + +py_test( + name = "embedding_test", + size = "large", + srcs = 
["embedding_test.py"], + python_version = "PY3", + shard_count = 8, + srcs_version = "PY3", + deps = [ + "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", + ], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py new file mode 100644 index 0000000..aa61e3b --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense.py @@ -0,0 +1,100 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast clipping function for `tf.keras.layers.Dense`.""" + +from typing import Any, Dict, Optional, Text, Tuple +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases + + +def dense_layer_computation( + layer_instance: tf.keras.layers.Dense, + input_args: Tuple[Any, ...], + input_kwargs: Dict[Text, Any], + tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, +) -> type_aliases.RegistryFunctionOutput: + """Registry function for `tf.keras.layers.Dense`. + + The logic for this computation is based on the following paper: + https://arxiv.org/abs/1510.01799 + + For the sake of efficiency, we fuse the variables and square grad norms + for the kernel weights and bias vector together. + + Args: + layer_instance: A `tf.keras.layers.Dense` instance. + input_args: A `tuple` containing the first part of `layer_instance` input. + Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return + a valid output. + input_kwargs: A `tuple` containing the second part of `layer_instance` + input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should + return a valid output. + tape: A `tf.GradientTape` instance that will be used to watch the output + `base_vars`. + num_microbatches: An optional numeric value or scalar `tf.Tensor` for + indicating whether and how the losses are grouped into microbatches. If + not None, num_microbatches must divide the batch size. + + Returns: + A `tuple` `(base_vars, outputs, sqr_norm_fn)`. `base_vars` is the + intermediate Tensor used in the chain-rule / "fast" clipping trick, + `outputs` is the result of `layer_instance(*inputs)`, and `sqr_norm_fn` is + a function that takes one input, a `tf.Tensor` that represents the output + of the call `tape.gradient(summed_loss, base_vars)` where `tape` is a + `tf.GradientTape` instance that records the dense layer computation and + `summed_loss` is the sum of the per-example losses of the underlying model. + This function then returns the per-example squared L2 gradient norms of the + trainable variables in `layer_instance`. These squared norms should be a 1D + `tf.Tensor` of length `batch_size`. 
+ """ + if input_kwargs: + raise ValueError("Dense layer calls should not receive kwargs.") + del input_kwargs # Unused in dense layer calls. + if len(input_args) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") + orig_activation = layer_instance.activation + layer_instance.activation = None + base_vars = layer_instance(*input_args) + tape.watch(base_vars) + layer_instance.activation = orig_activation + outputs = orig_activation(base_vars) if orig_activation else base_vars + + def sqr_norm_fn(base_vars_grads): + def _compute_gramian(x): + if num_microbatches is not None: + x_microbatched = common_manip_utils.maybe_add_microbatch_axis( + x, + num_microbatches, + ) + return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) + else: + # Special handling for better efficiency + return tf.reduce_sum(tf.square(x), axis=tf.range(1, tf.rank(x))) + + inputs_gram = _compute_gramian(*input_args) + base_vars_grads_gram = _compute_gramian(base_vars_grads) + if layer_instance.use_bias: + # Adding a bias term is equivalent to a layer with no bias term and which + # adds an additional variable to the layer input that only takes a + # constant value of 1.0. This is thus equivalent to adding 1.0 to the sum + # of the squared values of the inputs. + inputs_gram += 1.0 + return tf.reduce_sum( + inputs_gram * base_vars_grads_gram, + axis=tf.range(1, tf.rank(inputs_gram)), + ) + + return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py new file mode 100644 index 0000000..0f4451c --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py @@ -0,0 +1,100 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry + + +# ============================================================================== +# Helper functions. +# ============================================================================== +def get_dense_layer_generators(): + def sigmoid_dense_layer(b): + return tf.keras.layers.Dense(b, activation='sigmoid') + + return { + 'pure_dense': lambda a, b: tf.keras.layers.Dense(b), + 'sigmoid_dense': lambda a, b: sigmoid_dense_layer(b), + } + + +def get_dense_model_generators(): + return { + 'seq1': common_test_utils.make_two_layer_sequential_model, + 'seq2': common_test_utils.make_three_layer_sequential_model, + 'func1': common_test_utils.make_two_layer_functional_model, + 'tower1': common_test_utils.make_two_tower_model, + } + + +# ============================================================================== +# Main tests. 
+# ============================================================================== +class GradNormTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + model_name=list(get_dense_model_generators().keys()), + layer_name=list(get_dense_layer_generators().keys()), + input_dim=[4], + output_dim=[2], + per_example_loss_fn=[None, common_test_utils.test_loss_fn], + num_microbatches=[None, 1, 2], + is_eager=[True, False], + partial=[True, False], + weighted=[True, False], + ) + def test_gradient_norms_on_various_models( + self, + model_name, + layer_name, + input_dim, + output_dim, + per_example_loss_fn, + num_microbatches, + is_eager, + partial, + weighted, + ): + model_generator = get_dense_model_generators()[model_name] + layer_generator = get_dense_layer_generators()[layer_name] + x_batches, weight_batches = common_test_utils.get_nd_test_batches(input_dim) + default_registry = layer_registry.make_default_layer_registry() + for x_batch, weight_batch in zip(x_batches, weight_batches): + batch_size = x_batch.shape[0] + if num_microbatches is not None and batch_size % num_microbatches != 0: + continue + computed_norms, true_norms = ( + common_test_utils.get_computed_and_true_norms( + model_generator, + layer_generator, + input_dim, + output_dim, + per_example_loss_fn, + num_microbatches, + is_eager, + x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch, + weight_batch=weight_batch if weighted else None, + registry=default_registry, + partial=partial, + ) + ) + expected_size = num_microbatches or batch_size + self.assertEqual(computed_norms.shape[0], expected_size) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py new file mode 100644 index 0000000..2b0887b --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding.py @@ -0,0 +1,124 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast clipping function for `tf.keras.layers.Embedding`.""" + +from typing import Any, Dict, Optional, Text, Tuple +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases + + +def embedding_layer_computation( + layer_instance: tf.keras.layers.Embedding, + input_args: Tuple[Any, ...], + input_kwargs: Dict[Text, Any], + tape: tf.GradientTape, + num_microbatches: Optional[tf.Tensor] = None, +) -> type_aliases.RegistryFunctionOutput: + """Registry function for `tf.keras.layers.Embedding`. 
+ + The logic of this computation is based on the `tf.keras.layers.Dense` + computation and the fact that an embedding layer is just a dense layer + with no activation function and an output vector of the form X*W for input + X, where the i-th row of W is the i-th embedding vector and the j-th row of + X is a one-hot vector representing the input of example j. + + Args: + layer_instance: A `tf.keras.layers.Embedding` instance. + input_args: See `dense_layer_computation()` in `dense.py`. + input_kwargs: See `dense_layer_computation()` in `dense.py`. + tape: See `dense_layer_computation()` in `dense.py`. + num_microbatches: See `dense_layer_computation()` in `dense.py`. + + Returns: + See `dense_layer_computation()` in `dense.py`. + """ + if input_kwargs: + raise ValueError("Embedding layer calls should not receive kwargs.") + del input_kwargs # Unused in embedding layer calls. + if len(input_args) != 1: + raise ValueError("Only layer inputs of length 1 are permitted.") + if hasattr(layer_instance, "sparse"): # for backwards compatibility + if layer_instance.sparse: + raise NotImplementedError("Sparse output tensors are not supported.") + if isinstance(input_args[0], tf.SparseTensor): + raise NotImplementedError("Sparse input tensors are not supported.") + + # Disable experimental features. + if hasattr(layer_instance, "_use_one_hot_matmul"): + if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access + raise NotImplementedError( + "The experimental embedding feature" + "'_use_one_hot_matmul' is not supported." + ) + input_ids = tf.cast(*input_args, tf.int32) + base_vars = layer_instance.trainable_variables[0] + tape.watch(base_vars) + outputs = tf.nn.embedding_lookup(base_vars, input_ids) + + def sqr_norm_fn(base_vars_grads): + # Get a 1D tensor of the row indices. + nrows = tf.shape(input_ids)[0] + if isinstance(input_ids, tf.RaggedTensor): + row_indices = tf.expand_dims( + input_ids.merge_dims(1, -1).value_rowids(), axis=-1 + ) + elif isinstance(input_ids, tf.Tensor): + ncols = tf.reduce_prod(tf.shape(input_ids)[1:]) + repeats = tf.repeat(ncols, nrows) + row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1]) + else: + raise NotImplementedError( + "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ + ) + row_indices = tf.cast(row_indices, tf.int32) + if num_microbatches is not None: + microbatch_size = tf.cast(nrows / num_microbatches, tf.int32) + nrows = num_microbatches + row_indices = tf.cast( + tf.math.floordiv(row_indices, microbatch_size), tf.int32 + ) + # Sum-reduce the `IndexSlices` that is the result of a `tape.gradient()` + # call. The sum is reduced by the repeated embedding indices and batch + # index. It is adapted from the logic in: + # tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices + if not isinstance(base_vars_grads, tf.IndexedSlices): + raise NotImplementedError( + "Cannot parse embedding gradients of type: %s" + % base_vars_grads.__class__.__name__ + ) + slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1) + paired_indices = tf.concat( + [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)], + axis=1, + ) + (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2( + x=paired_indices, axis=[0] + ) + unique_batch_ids = unique_paired_indices[:, 0] + summed_gradients = tf.math.unsorted_segment_sum( + base_vars_grads.values, + new_index_positions, + tf.shape(unique_paired_indices)[0], + ) + # Compute the squared gradient norms at the per-example level. 
+ sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1) + summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0]) + return tf.sparse.segment_sum( + sqr_gradient_sum, + summed_data_range, + tf.sort(unique_batch_ids), + num_segments=nrows, + ) # fill in empty inputs + + return base_vars, outputs, sqr_norm_fn diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py new file mode 100644 index 0000000..c818afa --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py @@ -0,0 +1,111 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry + + +# ============================================================================== +# Helper functions. +# ============================================================================== +def get_embedding_model_generators(): + return { + 'bow1': common_test_utils.make_bow_model, + 'bow2': common_test_utils.make_dense_bow_model, + 'weighted_bow1': common_test_utils.make_weighted_bow_model, + } + + +# ============================================================================== +# Main tests. +# ============================================================================== +class GradNormTest(tf.test.TestCase, parameterized.TestCase): + + # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment + # supports them for embeddings. + @parameterized.product( + x_batch=[ + # 2D inputs. + tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32), + tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32), + tf.ragged.constant( + [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32 + ), + tf.ragged.constant( + [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1], [0, 1]], + dtype=tf.int32, + ), + # 3D inputs. + tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32), + tf.convert_to_tensor( + [[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32 + ), + tf.ragged.constant( + [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]], + dtype=tf.int32, + ), + tf.ragged.constant( + [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]], [[0]]], + dtype=tf.int32, + ), + ], + model_name=list(get_embedding_model_generators().keys()), + output_dim=[2], + per_example_loss_fn=[None, common_test_utils.test_loss_fn], + num_microbatches=[None, 2], + is_eager=[True, False], + partial=[True, False], + ) + def test_gradient_norms_on_various_models( + self, + x_batch, + model_name, + output_dim, + per_example_loss_fn, + num_microbatches, + is_eager, + partial, + ): + batch_size = x_batch.shape[0] + # The following are invalid test combinations, and are skipped. 
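The index-deduplication logic inside `sqr_norm_fn` can be illustrated in isolation. The sketch below (not part of the diff; all values are illustrative, and it uses `tf.math.unsorted_segment_sum` for the final per-example accumulation instead of the `tf.sparse.segment_sum` call above) shows how gradients arriving as (example row, embedding id) pairs are summed per unique pair before their squared norms are accumulated per example:

import tensorflow as tf

# Suppose example 0 uses ids [3, 3, 5] and example 1 uses ids [5, 7].
row_indices = tf.constant([[0], [0], [0], [1], [1]], dtype=tf.int64)
slice_indices = tf.constant([[3], [3], [5], [5], [7]], dtype=tf.int64)
grads = tf.constant(
    [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [1.0, 1.0], [3.0, 4.0]]
)

paired_indices = tf.concat([row_indices, slice_indices], axis=1)
unique_pairs, positions = tf.raw_ops.UniqueV2(x=paired_indices, axis=[0])
# Gradients for the repeated (example 0, id 3) pair are summed into [3.0, 0.0].
summed_gradients = tf.math.unsorted_segment_sum(
    grads, positions, tf.shape(unique_pairs)[0]
)
sqr_norms = tf.math.unsorted_segment_sum(
    tf.reduce_sum(tf.square(summed_gradients), axis=1),
    tf.cast(unique_pairs[:, 0], tf.int32),
    num_segments=2,
)
print(sqr_norms)  # [3^2 + 1^2, (1^2 + 1^2) + (3^2 + 4^2)] = [10.0, 27.0]

Summing per unique (row, id) pair before squaring matters because repeated ids within one example contribute a single, summed gradient row to that example's norm, exactly as the equivalent dense one-hot formulation would.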
+ if ( + num_microbatches is not None and batch_size % num_microbatches != 0 + ) or ( + model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor) + ): + return + default_registry = layer_registry.make_default_layer_registry() + model_generator = get_embedding_model_generators()[model_name] + computed_norms, true_norms = ( + common_test_utils.get_computed_and_true_norms( + model_generator=model_generator, + layer_generator=None, + input_dims=x_batch.shape[1:], + output_dim=output_dim, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + is_eager=is_eager, + x_batch=x_batch, + registry=default_registry, + partial=partial, + ) + ) + self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py new file mode 100644 index 0000000..e8b7f91 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py @@ -0,0 +1,49 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A collection of type aliases used throughout the clipping library.""" + +from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union +import tensorflow as tf + + +# Tensorflow aliases. +PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] + +InputTensors = PackedTensors + +OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]] + +BatchSize = Union[int, tf.Tensor] + +LossFn = Callable[..., tf.Tensor] + +# Layer Registry aliases. +SquareNormFunction = Callable[[OutputTensors], tf.Tensor] + +RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction] + +RegistryFunction = Callable[ + [Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape], + RegistryFunctionOutput, +] + +# Clipping aliases. +GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]] + +# Testing aliases. 
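The registry-related aliases above describe the contract a layer registry entry must satisfy. Below is a hypothetical example (not part of the diff; the `ScaleLayer` class and its registry function are illustrative assumptions) of a function conforming to `RegistryFunction` for a simple elementwise-scaling layer y = w * x, ignoring microbatching for brevity:

from typing import Any, Dict, Optional, Text, Tuple

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


class ScaleLayer(tf.keras.layers.Layer):
  """Illustrative layer computing y = w * x elementwise."""

  def build(self, input_shape):
    self.w = self.add_weight(
        name="w", shape=[int(input_shape[-1])], initializer="ones"
    )

  def call(self, x):
    return self.w * x


def scale_layer_computation(
    layer_instance: ScaleLayer,
    input_args: Tuple[Any, ...],
    input_kwargs: Dict[Text, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  del input_kwargs, num_microbatches  # Ignored in this sketch.
  x = input_args[0]  # Assumes a single 2-D input of shape [batch_size, d].
  base_vars = layer_instance(x)
  tape.watch(base_vars)

  def sqr_norm_fn(base_vars_grads):
    # d(loss_i)/dw = x_i * g_i, so the squared norm is sum((x_i * g_i)**2).
    return tf.reduce_sum(tf.square(x * base_vars_grads), axis=-1)

  # The layer has no activation, so the output is `base_vars` itself.
  return base_vars, base_vars, sqr_norm_fn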
+LayerGenerator = Callable[[int, int], tf.keras.layers.Layer] + +ModelGenerator = Callable[ + [LayerGenerator, Union[int, List[int]], int], tf.keras.Model +] diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD index 2ce8eba..cb493b9 100644 --- a/tensorflow_privacy/privacy/keras_models/BUILD +++ b/tensorflow_privacy/privacy/keras_models/BUILD @@ -17,6 +17,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads", + "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils", "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry", ], diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 80132df..29cb7d5 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -16,8 +16,8 @@ from absl import logging import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils -from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr _PRIVATIZED_LOSS_NAME = 'privatized_loss' @@ -143,7 +143,8 @@ def make_dp_model_class(cls): if ( layer_registry is not None and gradient_clipping_utils.all_trainable_layers_are_registered( - self, layer_registry + self, + layer_registry, ) and gradient_clipping_utils.has_internal_compute_graph(self) ): @@ -273,10 +274,16 @@ def make_dp_model_class(cls): grads = clipped_grads else: logging.info('Computing gradients using original clipping algorithm.') + # Computes per-example clipped gradients directly. This is called # if at least one of the layers cannot use the "fast" gradient clipping # algorithm. - reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches) + def reshape_fn(z): + return common_manip_utils.maybe_add_microbatch_axis( + z, + num_microbatches, + ) + microbatched_data = tf.nest.map_structure(reshape_fn, data) clipped_grads = tf.vectorized_map( self._compute_per_example_grads,