diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD index 8698b51..a112f83 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD @@ -5,12 +5,19 @@ licenses(["notice"]) py_library( name = "sparse_noise_utils", srcs = ["sparse_noise_utils.py"], + deps = [ + ":type_aliases", + "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils", + ], ) py_test( name = "sparse_noise_utils_test", srcs = ["sparse_noise_utils_test.py"], - deps = [":sparse_noise_utils"], + deps = [ + ":sparse_noise_utils", + "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils", + ], ) py_library( diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py index 839a559..9830150 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py @@ -16,10 +16,13 @@ For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357. """ +import collections from typing import Mapping, Optional, Sequence from scipy import stats import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils +from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases import tensorflow_probability as tfp @@ -288,15 +291,60 @@ def add_sparse_gradient_noise( ) +def extract_varname_to_contribution_counts_fns( + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + trainable_vars: Sequence[tf.Variable], +) -> Mapping[str, type_aliases.ContributionCountHistogramFn]: + """Extracts a map of contribution count fns from generator outputs. + + Args: + registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput` + instances returned by + `gradient_clipping_utils.model_forward_backward_pass`. + trainable_vars: A list of trainable variables. + + Returns: + A `dict` from varname to contribution counts functions + """ + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) + + varname_to_contribution_counts_fns = collections.defaultdict(list) + for registry_fn_output in registry_fn_outputs_list: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + if registry_fn_output.varname_to_count_contribution_fn is not None: + duplicate_varnames = set( + registry_fn_output.varname_to_count_contribution_fn.keys() + ) & set(varname_to_contribution_counts_fns.keys()) + if duplicate_varnames: + raise ValueError( + 'Duplicate varnames: {duplicate_varnames} found in contribution' + ' counts functions.' + ) + varname_to_contribution_counts_fns.update( + registry_fn_output.varname_to_count_contribution_fn + ) + return varname_to_contribution_counts_fns + + def get_contribution_counts( - trainable_vars: list[tf.Variable], - grads: list[tf.Tensor], - varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor], -) -> list[tf.Tensor | None]: + trainable_vars: Sequence[tf.Variable], + grads: Sequence[tf.Tensor], + varname_to_contribution_counts_fns: Mapping[ + str, type_aliases.ContributionCountHistogramFn + ], +) -> Sequence[type_aliases.ContributionCountHistogram | None]: """Gets the contribution counts for each variable in the Model. Args: - trainable_vars: A list of the trainable variables in the Model. + trainable_vars: A list of trainable variables. grads: A corresponding list of gradients for each trainable variable. varname_to_contribution_counts_fns: A mapping from variable name to a list of functions to get the contribution counts for that variable. diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py index 38e11b2..35b26ad 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py @@ -17,6 +17,7 @@ from absl.testing import parameterized import numpy as np from scipy import stats import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils @@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase): np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy())) ) + def test_extract_varname_to_contribution_counts_fns(self): + def fn1(_): + return 1.0 + + def fn2(_): + return 2.0 + + var1 = tf.Variable(tf.ones((1, 2)), name='var1') + var2 = tf.Variable(tf.ones((1, 2)), name='var2') + var3 = tf.Variable(tf.ones((1, 2)), name='var3') + + registry_fn_outputs_list = [ + gradient_clipping_utils.RegistryGeneratorFunctionOutput( + layer_id='layer1', + layer_vars=[var1], + layer_sqr_norm_fn=None, + layer_trainable_weights=[var1], + varname_to_count_contribution_fn=None, + ), + gradient_clipping_utils.RegistryGeneratorFunctionOutput( + layer_id='layer2', + layer_vars=[var2], + layer_sqr_norm_fn=None, + layer_trainable_weights=[var2], + varname_to_count_contribution_fn={ + 'var2:0': [fn2], + }, + ), + gradient_clipping_utils.RegistryGeneratorFunctionOutput( + layer_id='layer3', + layer_vars=[var3], + layer_sqr_norm_fn=None, + layer_trainable_weights=[var3], + varname_to_count_contribution_fn={ + 'var3:0': [fn1], + }, + ), + ] + expected_varname_to_contribution_counts_fns = { + 'var2:0': [fn2], + 'var3:0': [fn1], + } + varname_to_contribution_counts_fns = ( + sparse_noise_utils.extract_varname_to_contribution_counts_fns( + registry_fn_outputs_list, + trainable_vars=None, + ) + ) + self.assertEqual( + varname_to_contribution_counts_fns, + expected_varname_to_contribution_counts_fns, + ) + + def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self): + def fn1(_): + return 1.0 + + def fn2(_): + return 2.0 + + var1 = tf.Variable(tf.ones((1, 2)), name='var1') + var2 = tf.Variable(tf.ones((1, 2)), name='var1') + + registry_fn_outputs_list = [ + gradient_clipping_utils.RegistryGeneratorFunctionOutput( + layer_id='layer1', + layer_vars=[var1], + layer_sqr_norm_fn=None, + layer_trainable_weights=[var1], + varname_to_count_contribution_fn={ + 'var1:0': [fn1], + }, + ), + gradient_clipping_utils.RegistryGeneratorFunctionOutput( + layer_id='layer2', + layer_vars=[var2], + layer_sqr_norm_fn=None, + layer_trainable_weights=[var2], + varname_to_count_contribution_fn={ + 'var1:0': [fn2], + }, + ), + ] + + with self.assertRaises(ValueError): + sparse_noise_utils.extract_varname_to_contribution_counts_fns( + registry_fn_outputs_list, + trainable_vars=None, + ) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py index 82c7e76..48283b3 100644 --- a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py @@ -22,7 +22,7 @@ InputKwargs = Mapping[str, Any] SparseGradient = tf.IndexedSlices ContributionCountHistogram = tf.SparseTensor ContributionCountHistogramFn = Callable[ - [SparseGradient], Mapping[str, ContributionCountHistogram] + [SparseGradient], ContributionCountHistogram ] NumMicrobatches = int | tf.Tensor SparsityPreservingNoiseLayerRegistryFunction = Callable[