diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
index d325680..6514711 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
@@ -36,14 +36,20 @@ def compute_true_gradient_norms(input_model, x_batch, y_batch):
   sqr_norms = []
   for var in input_model.trainable_variables:
     jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
-    reduction_axes = tf.range(1, tf.rank(jacobian))
+    reduction_axes = tf.range(1, len(jacobian.shape))
     sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes))
   sqr_norm_tsr = tf.stack(sqr_norms, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
 def get_computed_and_true_norms(
-    model_generator, layer_generator, input_dim, output_dim, is_eager, x_input
+    model_generator,
+    layer_generator,
+    input_dims,
+    output_dim,
+    is_eager,
+    x_input,
+    rng_seed=777,
 ):
   """Obtains the true and computed gradient norms for a model and batch input.
 
@@ -58,10 +64,11 @@ def get_computed_and_true_norms(
     layer_generator: A function which takes in two arguments: `idim` and
       `odim`. Returns a `tf.keras.layers.Layer` that accepts input tensors of
       dimension `idim` and returns output tensors of dimension `odim`.
-    input_dim: The input dimension of the test `tf.keras.Model` instance.
+    input_dims: The input dimension(s) of the test `tf.keras.Model` instance.
     output_dim: The output dimension of the test `tf.keras.Model` instance.
     is_eager: A `bool` that is `True` if the model should be run eagerly.
     x_input: `tf.Tensor` inputs to be tested.
+    rng_seed: An `int` used to initialize model weights.
 
   Returns:
     A `tuple` `(computed_norm, true_norms)`. The first element contains the
@@ -70,8 +77,7 @@ def get_computed_and_true_norms(
     model and layer generators. The second element contains the true clipped
     gradient norms under the aforementioned setting.
   """
-  tf.keras.utils.set_random_seed(777)
-  model = model_generator(layer_generator, input_dim, output_dim)
+  model = model_generator(layer_generator, input_dims, output_dim)
   model.compile(
       optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
       loss=tf.keras.losses.MeanSquaredError(
@@ -82,9 +88,11 @@ def get_computed_and_true_norms(
   y_pred = model(x_input)
   y_batch = tf.ones_like(y_pred)
   registry = layer_registry.make_default_layer_registry()
+  tf.keras.utils.set_random_seed(rng_seed)
   computed_norms = clip_grads.compute_gradient_norms(
       model, x_input, y_batch, layer_registry=registry
   )
+  tf.keras.utils.set_random_seed(rng_seed)
   true_norms = compute_true_gradient_norms(model, x_input, y_batch)
   return (computed_norms, true_norms)
 
@@ -140,27 +148,58 @@ def make_two_tower_model(layer_generator, input_dim, output_dim):
   return tf.keras.Model(inputs=[inputs1, inputs2], outputs=outputs)
 
 
-def make_bow_model(layer_generator, input_dim, output_dim):
+def make_bow_model(layer_generator, input_dims, output_dim):
   del layer_generator
-  inputs = tf.keras.Input(shape=(input_dim,))
+  inputs = tf.keras.Input(shape=input_dims)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in eache example.
   emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
   feature_embs = emb_layer(inputs)
-  example_embs = tf.reduce_sum(feature_embs, axis=1)
+  reduction_axes = tf.range(1, len(feature_embs.shape))
+  example_embs = tf.expand_dims(
+      tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
+  )
   return tf.keras.Model(inputs=inputs, outputs=example_embs)
 
 
-def make_dense_bow_model(layer_generator, input_dim, output_dim):
+def make_dense_bow_model(layer_generator, input_dims, output_dim):
   del layer_generator
-  inputs = tf.keras.Input(shape=(input_dim,))
+  inputs = tf.keras.Input(shape=input_dims)
   # For the Embedding layer, input_dim is the vocabulary size. This should
   # be distinguished from the input_dim argument, which is the number of ids
   # in eache example.
-  emb_layer = tf.keras.layers.Embedding(input_dim=10, output_dim=output_dim)
+  cardinality = 10
+  emb_layer = tf.keras.layers.Embedding(
+      input_dim=cardinality, output_dim=output_dim
+  )
   feature_embs = emb_layer(inputs)
-  example_embs = tf.reduce_sum(feature_embs, axis=1)
+  reduction_axes = tf.range(1, len(feature_embs.shape))
+  example_embs = tf.expand_dims(
+      tf.reduce_sum(feature_embs, axis=reduction_axes), axis=-1
+  )
+  outputs = tf.keras.layers.Dense(1)(example_embs)
+  return tf.keras.Model(inputs=inputs, outputs=outputs)
+
+
+def make_weighted_bow_model(layer_generator, input_dims, output_dim):
+  # NOTE: This model only accepts dense input tensors.
+  del layer_generator
+  inputs = tf.keras.Input(shape=input_dims)
+  # For the Embedding layer, input_dim is the vocabulary size. This should
+  # be distinguished from the input_dims argument, which is the number of ids
+  # in each example.
+  cardinality = 10
+  emb_layer = tf.keras.layers.Embedding(
+      input_dim=cardinality, output_dim=output_dim
+  )
+  feature_embs = emb_layer(inputs)
+  feature_weights = tf.random.uniform(tf.shape(feature_embs))
+  weighted_embs = feature_embs * feature_weights
+  reduction_axes = tf.range(1, len(weighted_embs.shape))
+  example_embs = tf.expand_dims(
+      tf.reduce_sum(weighted_embs, axis=reduction_axes), axis=-1
+  )
   outputs = tf.keras.layers.Dense(1)(example_embs)
   return tf.keras.Model(inputs=inputs, outputs=outputs)
 
@@ -211,6 +250,7 @@ def get_embedding_model_generators():
   return {
       'bow1': make_bow_model,
       'bow2': make_dense_bow_model,
+      'weighted_bow1': make_weighted_bow_model,
   }
 
 
@@ -267,10 +307,20 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
   # supports them for embeddings.
   @parameterized.product(
       x_batch=[
-          tf.convert_to_tensor([[0, 1], [1, 0]], dtype_hint=tf.int32),
+          # 2D inputs.
+          tf.convert_to_tensor([[0, 1]], dtype_hint=tf.int32),
           tf.convert_to_tensor([[0, 1], [1, 1], [0, 0]], dtype_hint=tf.int32),
           tf.ragged.constant(
-              [[0], [1], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
+              [[0], [1], [], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32
+          ),
+          # 3D inputs.
+          tf.convert_to_tensor([[[0, 1]]], dtype_hint=tf.int32),
+          tf.convert_to_tensor(
+              [[[0, 1]], [[1, 1]], [[0, 0]]], dtype_hint=tf.int32
+          ),
+          tf.ragged.constant(
+              [[[0]], [[1]], [], [[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]],
+              dtype=tf.int32,
           ),
       ],
      model_name=list(get_embedding_model_generators().keys()),
@@ -280,16 +330,21 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
   def test_gradient_norms_on_various_models(
       self, x_batch, model_name, output_dim, is_eager
   ):
-    model_generator = get_embedding_model_generators()[model_name]
-    (computed_norms, true_norms) = get_computed_and_true_norms(
-        model_generator=model_generator,
-        layer_generator=None,
-        input_dim=2,
-        output_dim=output_dim,
-        is_eager=is_eager,
-        x_input=x_batch,
-    )
-    self.assertAllClose(computed_norms, true_norms)
+    valid_test_input = (
+        not isinstance(x_batch, tf.RaggedTensor)
+        and model_name == 'weighted_bow1'
+    ) or (model_name != 'weighted_bow1')
+    if valid_test_input:
+      model_generator = get_embedding_model_generators()[model_name]
+      (computed_norms, true_norms) = get_computed_and_true_norms(
+          model_generator=model_generator,
+          layer_generator=None,
+          input_dims=x_batch.shape[1:],
+          output_dim=output_dim,
+          is_eager=is_eager,
+          x_input=x_batch,
+      )
+      self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
 
 if __name__ == '__main__':
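
A note on the reference computation this diff touches: compute_true_gradient_norms obtains one gradient norm per example by differentiating an unreduced (per-example) loss with tf.GradientTape.jacobian, squaring and summing each variable's Jacobian over all non-batch axes, and taking the square root of the per-example totals accumulated across variables. The sketch below is a self-contained illustration of that technique on a hypothetical single-Dense-layer model; the model, shapes, and function name are illustrative and not part of this diff.

import tensorflow as tf

def per_example_gradient_norms(model, x_batch, y_batch):
  # Unreduced loss: one scalar per example, so the Jacobian keeps a batch axis.
  loss_fn = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.NONE
  )
  # A persistent tape allows one jacobian() call per trainable variable.
  with tf.GradientTape(persistent=True) as tape:
    loss = loss_fn(y_batch, model(x_batch))
  sqr_norms = []
  for var in model.trainable_variables:
    # jacobian has shape [batch_size, *var.shape]; squaring and summing over
    # every non-batch axis leaves one squared norm per example.
    jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
    reduction_axes = tf.range(1, len(jacobian.shape))
    sqr_norms.append(tf.reduce_sum(tf.square(jacobian), axis=reduction_axes))
  # Accumulate squared norms across variables, then take the square root.
  return tf.sqrt(tf.reduce_sum(tf.stack(sqr_norms, axis=1), axis=1))

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
x_batch = tf.random.normal([4, 3])
y_batch = tf.ones([4, 1])
print(per_example_gradient_norms(model, x_batch, y_batch))  # shape [4]

On the rank change itself: tf.rank(jacobian) returns a zero-dimensional tensor, while len(jacobian.shape) is a plain Python int; the two coincide whenever the Jacobian's static shape is fully known, as it is for the test models here.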
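The seed handling also deserves a comment: make_weighted_bow_model samples tf.random.uniform weights inside its forward pass, so each of the two norm computations triggers a fresh draw. Re-seeding immediately before each computation makes both passes see identical randomness, which is what makes the computed and true norms comparable at all; the loosened rtol/atol in assertAllClose then only has to absorb floating-point differences between the two gradient pipelines. A minimal sketch of the re-seeding pattern, with a hypothetical noisy_forward standing in for the weighted model:

import tensorflow as tf

def noisy_forward(x):
  # Stands in for a model that draws randomness during its forward pass,
  # like the feature weights in make_weighted_bow_model.
  return x * tf.random.uniform(tf.shape(x))

x = tf.ones([2, 3])
rng_seed = 777  # mirrors the default of the new rng_seed argument

# Resetting the global seed before each call repeats the random draws, so
# the two nominally independent computations see identical weights.
tf.keras.utils.set_random_seed(rng_seed)
a = noisy_forward(x)
tf.keras.utils.set_random_seed(rng_seed)
b = noisy_forward(x)
tf.debugging.assert_near(a, b)

This forward-pass randomness is presumably also why the new valid_test_input guard keeps ragged batches away from 'weighted_bow1': tf.random.uniform(tf.shape(feature_embs)) needs a dense shape, which a tf.RaggedTensor does not provide.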