Sparsity Preserving DP-SGD in TF Privacy [1 of 4]
Adds layer registry and type aliases. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 648747866
This commit is contained in:
parent
00384db109
commit
348895a7a3
3 changed files with 96 additions and 0 deletions
14
tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD
Normal file
14
tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "type_aliases",
|
||||||
|
srcs = ["type_aliases.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "layer_registry",
|
||||||
|
srcs = ["layer_registry.py"],
|
||||||
|
deps = [":type_aliases"],
|
||||||
|
)
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2024, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Registry of layer classes to their contribution histogram functions."""
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Main class
|
||||||
|
# ==============================================================================
|
||||||
|
class LayerRegistry:
|
||||||
|
"""Custom container for layer registry functions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Basic initialization of various internal dictionaries."""
|
||||||
|
self._layer_class_dict = {}
|
||||||
|
self._registry = {}
|
||||||
|
|
||||||
|
def is_elem(self, layer_instance: tf.keras.layers.Layer) -> bool:
|
||||||
|
"""Checks if a layer instance's class is in the registry."""
|
||||||
|
return hash(layer_instance.__class__) in self._registry
|
||||||
|
|
||||||
|
def lookup(
|
||||||
|
self, layer_instance: tf.keras.layers.Layer
|
||||||
|
) -> type_aliases.SparsityPreservingNoiseLayerRegistryFunction:
|
||||||
|
"""Returns the layer registry function for a given layer instance."""
|
||||||
|
return self._registry[hash(layer_instance.__class__)]
|
||||||
|
|
||||||
|
def insert(
|
||||||
|
self,
|
||||||
|
layer_class: Type[tf.keras.layers.Layer],
|
||||||
|
layer_registry_function: type_aliases.SparsityPreservingNoiseLayerRegistryFunction,
|
||||||
|
):
|
||||||
|
"""Inserts a layer registry function into the internal dictionaries."""
|
||||||
|
layer_key = hash(layer_class)
|
||||||
|
self._layer_class_dict[layer_key] = layer_class
|
||||||
|
self._registry[layer_key] = layer_registry_function
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2024, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Type aliases for sparsity preserving noise."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
InputArgs = Sequence[Any]
|
||||||
|
InputKwargs = Mapping[str, Any]
|
||||||
|
SparseGradient = tf.IndexedSlices
|
||||||
|
ContributionCountHistogram = tf.SparseTensor
|
||||||
|
ContributionCountHistogramFn = Callable[
|
||||||
|
[SparseGradient], Mapping[str, ContributionCountHistogram]
|
||||||
|
]
|
||||||
|
NumMicrobatches = int | tf.Tensor
|
||||||
|
SparsityPreservingNoiseLayerRegistryFunction = Callable[
|
||||||
|
[tf.keras.layers.Layer, InputArgs, InputKwargs, NumMicrobatches | None],
|
||||||
|
ContributionCountHistogramFn,
|
||||||
|
]
|
Loading…
Reference in a new issue