diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD new file mode 100644 index 0000000..e28361d --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +py_library( + name = "type_aliases", + srcs = ["type_aliases.py"], +) + +py_library( + name = "layer_registry", + srcs = ["layer_registry.py"], + deps = [":type_aliases"], +) diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py new file mode 100644 index 0000000..6c7d9a4 --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py @@ -0,0 +1,51 @@ +# Copyright 2024, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Registry of layer classes to their contribution histogram functions.""" + +from typing import Type + +import tensorflow as tf +from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases + + +# ============================================================================== +# Main class +# ============================================================================== +class LayerRegistry: + """Custom container for layer registry functions.""" + + def __init__(self): + """Basic initialization of various internal dictionaries.""" + self._layer_class_dict = {} + self._registry = {} + + def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool: + """Checks if a layer instance's class is in the registry.""" + return hash(layer_instance.__class__) in self._registry + + def lookup( + self, layer_instance: tf.keras.layers.Layer + ) -> type_aliases.SparsityPreservingNoiseLayerRegistryFunction: + """Returns the layer registry function for a given layer instance.""" + return self._registry[hash(layer_instance.__class__)] + + def insert( + self, + layer_class: Type[tf.keras.layers.Layer], + layer_registry_function: type_aliases.SparsityPreservingNoiseLayerRegistryFunction, + ): + """Inserts a layer registry function into the internal dictionaries.""" + layer_key = hash(layer_class) + self._layer_class_dict[layer_key] = layer_class + self._registry[layer_key] = layer_registry_function diff --git a/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py new file mode 100644 index 0000000..c63d083 --- /dev/null +++ b/tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py @@ -0,0 +1,31 @@ +# Copyright 2024, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Type aliases for sparsity preserving noise.""" + +from collections.abc import Callable, Mapping, Sequence +from typing import Any +import tensorflow as tf + +InputArgs = Sequence[Any] +InputKwargs = Mapping[str, Any] +SparseGradient = tf.IndexedSlices +ContributionCountHistogram = tf.SparseTensor +ContributionCountHistogramFn = Callable[ + [SparseGradient], Mapping[str, ContributionCountHistogram] +] +NumMicrobatches = int | tf.Tensor +SparsityPreservingNoiseLayerRegistryFunction = Callable[ + [tf.keras.layers.Layer, InputArgs, InputKwargs, NumMicrobatches | None], + ContributionCountHistogramFn, +]