Compare commits

10 commits: 8294cec132 ... d965556ebb

Author | SHA1 | Date
---|---|---
 | d965556ebb |
 | e8856835a6 |
 | e3fd3afdf8 |
 | 66d05a22a3 |
 | b3963971e3 |
 | 93c7e54327 |
 | 38d80cae92 |
 | bf6cf4dec9 |
 | e42b574465 |
 | 09c68750d7 |

20 changed files with 855 additions and 203 deletions

setup.py (2 changes)

@@ -37,7 +37,7 @@ setuptools.setup(
     install_requires=[
         'absl-py>=1.0,==1.*',
         'dm-tree==0.1.8',
-        'dp-accounting==0.4.4',
+        'dp-accounting==0.4.4',  # TODO(b/364653784)
         'numpy~=1.21',
         'packaging~=22.0',
         'scikit-learn>=1.0,==1.*',

@@ -43,8 +43,11 @@ py_library(
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
     deps = [
+        ":common_manip_utils",
         ":layer_registry",
         ":type_aliases",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
     ],
 )
@@ -54,7 +57,11 @@ py_test(
     python_version = "PY3",
     shard_count = 8,
     srcs_version = "PY3",
-    deps = [":gradient_clipping_utils"],
+    deps = [
+        ":gradient_clipping_utils",
+        ":layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+    ],
 )
@@ -83,6 +90,7 @@ py_library(
 py_library(
     name = "noise_utils",
     srcs = ["noise_utils.py"],
+    deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"],
 )

 py_test(
@@ -94,6 +102,7 @@ py_test(
     deps = [
         ":clip_grads",
         ":common_test_utils",
+        ":gradient_clipping_utils",
         ":layer_registry",
         ":type_aliases",
     ],

@@ -22,7 +22,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 """

 import collections
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Optional

 import tensorflow as tf
@@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


-def _infer_per_example_loss_fn(model: tf.keras.Model):
-  """Infer the per-example loss from model config."""
-
-  def _convert(loss_fn):
-    loss_config = loss_fn.get_config()
-    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
-    return loss_fn.from_config(loss_config)
-
-  model_loss = model.loss
-  if isinstance(model_loss, tf.keras.losses.Loss):
-    return _convert(model_loss)
-  elif isinstance(model_loss, dict):
-    # Note that we cannot call the public method `.get_compile_config()` because
-    # it calls a numpy function, which is not supported inside a `tf.function`
-    # wrapped function.
-    compile_config = model._compile_config.config  # pylint: disable=protected-access
-    if compile_config is None:
-      raise ValueError('Model must be compiled for loss function conversion')
-    # Does a weighted mean of the configured losses. Note that we cannot build
-    # from the config of the compiled loss because (i) it builds a
-    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
-    # during its construction, (ii) non-unique `tf.Variables` cannot be used
-    # inside a `tf.function`, which is usually where this function is used.
-    if 'loss_weights' not in compile_config:
-      raise ValueError(
-          'Models with multiple loss must have corresponding loss weights for'
-          ' loss function conversion'
-      )
-    weights = compile_config['loss_weights']
-    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
-    num_losses = len(weights)
-
-    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
-      loss_values = []
-      if model_loss.keys() - y_pred.keys():
-        raise ValueError(
-            'y_pred must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_pred.keys(), model_loss.keys())
-        )
-      if model_loss.keys() - y_true.keys():
-        raise ValueError(
-            'y_true must contain the same keys and the model losses, but '
-            'got %s and %s' % (y_true.keys(), model_loss.keys())
-        )
-      if sample_weight is not None:
-        if model_loss.keys() - sample_weight.keys():
-          raise ValueError(
-              'sample_weight must contain the same keys and the model losses,'
-              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
-          )
-      for k in y_true.keys():
-        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
-        sgl_value = (
-            weights[k]
-            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
-            / num_losses
-        )
-        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
-      return tf.math.add_n(loss_values)
-
-    return _per_example_loss_fn
-  else:
-    raise ValueError(
-        'Unsupported type for loss function conversion: {}'.format(
-            type(model_loss)
-        )
-    )
+def _compute_gradient_norms_internal(
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
+) -> tf.Tensor:
+  """Computes the per-example loss gradient norms for given data.
+
+  Args:
+    registry_fn_outputs_list: A sequence of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    trainable_vars: The list of variables included in computing the gradient
+      norm. When a layer has multiple variables, we include all the variables if
+      any of the variables is in the list. If `trainable_vars` is None, all the
+      variables are included.
+
+  Returns:
+    A scalar vector, whose i-th entry is the norm of the gradient of the i-th
+    weighted example loss (when num_microbatches is None) or the norm of the
+    gradient of the i-th microbatch loss (defined as a mean over the microbatch).
+    Note that when the loss is weighted (`weight_batch` is not None), weights
+    are applied prior to clipping.
+
+  Raises:
+    ValueError: If `layer_grad_vars` is empty.
+    ValueError: If the number of gradients for a layer is not equal to the
+      number of squared norm functions for that layer.
+  """
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+
+  layer_sqr_norm_fns = collections.defaultdict(list)
+  # The case of shared weights:
+  # If a layer is called k times, it will appear k times in filtered_outputs,
+  # with the same id, but potentially with different v and f. The code below
+  # groups filtered_outputs by layer_id, so we can correctly compute gradient
+  # norms. The gradient norm of a layer that occurs k times is computed as
+  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
+  # occurrence. This is an over-estimate of the actual norm. For more details,
+  # see the explanation in go/dp-sgd-shared-weights.
+  for registry_fn_output in registry_fn_outputs_list:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_sqr_norm_fn
+      )
+
+  if not layer_grad_vars:
+    raise ValueError('The gradient list cannot be empty.')
+  sqr_norm_list = []
+  for layer_id in layer_sqr_norm_fns.keys():
+    fns = layer_sqr_norm_fns[layer_id]
+    grads = layer_grad_vars[layer_id]
+    # Number of duplicates for this layer in `filtered_outputs`.
+    num_passes = len(fns)
+    if len(fns) != len(grads):
+      raise ValueError(
+          'There must be as many gradients as squared norm functions.'
+      )
+    # See go/dp-sgd-shared-weights for more details.
+    for fn, grad in zip(fns, grads):
+      sqr_norm_list.append(num_passes * fn(grad))
+  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
+  gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  return gradient_norms


 def compute_gradient_norms(
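
The shared-weights comment in `_compute_gradient_norms_internal` above computes a layer's contribution as sqrt(k * sum_i c_i^2) when the layer is called k times. A minimal NumPy sketch of that aggregation (the counts and values below are made up for illustration, not taken from the change):

import numpy as np

# Per-occurrence squared-norm estimates c_i^2 for one example and one layer
# that is called k = 3 times (illustrative values only).
per_occurrence_sqr_norms = np.array([0.25, 0.16, 0.09])
k = len(per_occurrence_sqr_norms)

# Each occurrence is scaled by the number of passes, matching
# `sqr_norm_list.append(num_passes * fn(grad))` above.
scaled = k * per_occurrence_sqr_norms

# The layer's contribution to the example's gradient norm is
# sqrt(k * sum_i c_i^2), an over-estimate of the true norm.
layer_norm_over_estimate = np.sqrt(scaled.sum())
print(layer_norm_over_estimate)  # sqrt(3 * 0.5) ~= 1.2247
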
@@ -110,7 +118,7 @@ def compute_gradient_norms(
     per_example_loss_fn: Optional[type_aliases.LossFn] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     trainable_vars: Optional[Sequence[tf.Variable]] = None,
-):
+) -> tf.Tensor:
   """Computes the per-example loss gradient norms for given data.

   Applies a variant of the approach given in
@@ -154,90 +162,28 @@ def compute_gradient_norms(
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
-      tape, layer_registry, num_microbatches
+      tape=tape,
+      layer_registry=layer_registry,
+      sparse_noise_layer_registry=None,
+      num_microbatches=num_microbatches,
   )
-  # First loop computes the model outputs, summed loss, and generator outputs.
-  with tape:
-    model_outputs, generator_outputs_list = (
-        gradient_clipping_utils.model_forward_pass(
-            input_model, x_batch, generator_fn=registry_generator_fn
-        )
-    )
-
-    # Ignore the original loss function's reduction to get per-example loss.
-    if per_example_loss_fn is None:
-      per_example_loss_fn = _infer_per_example_loss_fn(input_model)
-
-    losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
-    if losses.shape is None:
-      raise NotImplementedError(
-          "The unreduced (or per-example) loss's shape cannot be `None`"
-      )
-    if len(losses.shape) != 1:
-      raise NotImplementedError(
-          'The unreduced (or per-example) loss needs to have a shape of length '
-          'one, but received an unreduced loss of shape length %s'
-          % len(losses.shape)
-      )
-    if num_microbatches is not None:
-      losses = tf.reduce_mean(
-          common_manip_utils.maybe_add_microbatch_axis(
-              losses, num_microbatches
-          ),
-          axis=1,
-      )
-    summed_loss = tf.reduce_sum(losses)
-  # Unwrap the generator outputs so that the next loop avoids duplicating
-  # backprop ops.
-  filtered_outputs = [t for t in generator_outputs_list if t is not None]
-  if trainable_vars is not None:
-    # Create a set using `ref()` for fast set membership check. tf.Variable
-    # itself is not hashable.
-    trainable_vars = set([v.ref() for v in trainable_vars])
-  layer_vars = collections.defaultdict(list)
-  layer_sqr_norm_fns = collections.defaultdict(list)
-  # The case of shared weights:
-  # If a layer is called k times, it will appear k times in filtered_outputs,
-  # with the same id, but potentially with different v and f. The code below
-  # groups filtered_outputs by layer_id, so we can correctly compute gradient
-  # norms. The gradient norm of a layer that occurs k times is computed as
-  # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th
-  # occurrence. This is an over-estimate of the actual norm. For more details,
-  # see the explanation in go/dp-sgd-shared-weights.
-  for registry_fn_output in filtered_outputs:
-    if trainable_vars is None or any(
-        w.ref() in trainable_vars
-        for w in registry_fn_output.layer_trainable_weights
-    ):
-      layer_vars[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_vars
-      )
-      layer_sqr_norm_fns[registry_fn_output.layer_id].append(
-          registry_fn_output.layer_sqr_norm_fn
-      )
-  # Second loop evaluates the squared L2 norm functions and appends the results.
-  layer_grad_vars = tape.gradient(
-      summed_loss,
-      layer_vars,
-      unconnected_gradients=tf.UnconnectedGradients.ZERO,
-  )
-  if not layer_grad_vars:
-    raise ValueError('The gradient list cannot be empty.')
-  sqr_norm_list = []
-  for layer_id in layer_sqr_norm_fns.keys():
-    fns = layer_sqr_norm_fns[layer_id]
-    grads = layer_grad_vars[layer_id]
-    # Number of duplicates for this layer in `filtered_outputs`.
-    num_passes = len(fns)
-    if len(fns) != len(grads):
-      raise ValueError(
-          'There must be as many gradients as squared norm functions.'
-      )
-    # See go/dp-sgd-shared-weights for more details.
-    for fn, grad in zip(fns, grads):
-      sqr_norm_list.append(num_passes * fn(grad))
-  sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
-  return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
+  layer_grad_vars, generator_outputs_list = (
+      gradient_clipping_utils.model_forward_backward_pass(
+          tape=tape,
+          input_model=input_model,
+          x_batch=x_batch,
+          y_batch=y_batch,
+          registry_generator_fn=registry_generator_fn,
+          weight_batch=weight_batch,
+          per_example_loss_fn=per_example_loss_fn,
+          num_microbatches=num_microbatches,
+      )
+  )
+  return _compute_gradient_norms_internal(
+      registry_fn_outputs_list=generator_outputs_list,
+      layer_grad_vars=layer_grad_vars,
+      trainable_vars=trainable_vars,
+  )


 def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
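
With this hunk, `compute_gradient_norms()` becomes a thin wrapper: it runs the instrumented forward/backward pass and hands the intermediates to `_compute_gradient_norms_internal()`. A minimal sketch of calling the public entry point, assuming a small compiled Keras model and the default layer registry (the model and data below are illustrative, not part of this changeset):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# A tiny compiled model; Dense layers are covered by the default registry.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x = tf.random.normal((4, 3))
y = tf.random.normal((4, 1))

# After this change the call below internally runs
# `model_forward_backward_pass()` and then `_compute_gradient_norms_internal()`.
norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),
    x,
    y,
)
print(norms.shape)  # one gradient norm per example -> (4,)
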
@@ -267,14 +213,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]],
     l2_norm_clip: float,
-    layer_registry: lr.LayerRegistry,
     x_batch: type_aliases.InputTensors,
     y_batch: type_aliases.OutputTensors,
     weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
     clipping_loss: Optional[type_aliases.LossFn] = None,
-) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]:
+) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]:
   """Computes the per-example clipped loss gradient and other useful outputs.

   Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main
@@ -287,15 +236,16 @@ def compute_clipped_gradients_and_outputs(

   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
+    registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput
+      containing information required to compute the gradient norms and
+      contribution counts. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
+    layer_grad_vars: A mapping of layer id to a list of gradients for each
+      trainable variable in the layer. Output from
+      `gradient_clipping_utils.model_forward_backward_pass()`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
-    layer_registry: A `dict` of layers that support "fast" gradient norm
-      computations. The key is the class of the layer and the value is a
-      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
-      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
-      squared norms of a layer's pre-activation tensor, and `vars` are relevant
-      trainable weights (see `layer_registry_factories.py` for examples).
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axes of each tensor must be the batch dimension.
     y_batch: An `OutputTensor` representing a batch of output labels. The first
@@ -330,13 +280,9 @@ def compute_clipped_gradients_and_outputs(
     )
   if clipping_loss is None:
     clipping_loss = input_model.compiled_loss
-  gradient_norms = compute_gradient_norms(
-      input_model,
-      layer_registry,
-      x_batch,
-      y_batch,
-      weight_batch,
-      num_microbatches=num_microbatches,
+  gradient_norms = _compute_gradient_norms_internal(
+      registry_fn_outputs_list=registry_fn_outputs_list,
+      layer_grad_vars=layer_grad_vars,
       trainable_vars=input_model.trainable_variables,
   )
   clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
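
After the hunks above, callers of `compute_clipped_gradients_and_outputs()` run the forward/backward pass themselves and pass in the intermediates instead of a layer registry. A sketch of the new call flow, mirroring the test helper added later in this changeset (the model and data are illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())
x = tf.random.normal((4, 3))
y = tf.random.normal((4, 1))

# 1) Run the instrumented forward/backward pass once.
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
    tape=tape,
    layer_registry=layer_registry.make_default_layer_registry(),
    sparse_noise_layer_registry=None,
    num_microbatches=None,
)
layer_grad_vars, registry_fn_outputs_list = (
    gradient_clipping_utils.model_forward_backward_pass(
        tape=tape,
        input_model=model,
        x_batch=x,
        y_batch=y,
        registry_generator_fn=registry_generator_fn,
    )
)

# 2) Feed the intermediates to the clipping routine; a layer registry is no
# longer passed here after this change.
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
    model,
    registry_fn_outputs_list,
    layer_grad_vars,
    1.0,  # l2_norm_clip
    x,
    y,
)
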

@@ -19,6 +19,7 @@ from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -122,6 +123,30 @@ class CustomLayerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)


+def _run_model_forward_backward_pass(
+    model: tf.keras.Model,
+    x_batch: type_aliases.InputTensors,
+    y_batch: type_aliases.OutputTensors,
+):
+  tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
+  registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
+      tape=tape,
+      layer_registry=layer_registry.make_default_layer_registry(),
+      sparse_noise_layer_registry=None,
+      num_microbatches=None,
+  )
+  layer_grad_vars, registry_fn_outputs_list = (
+      gradient_clipping_utils.model_forward_backward_pass(
+          tape=tape,
+          input_model=model,
+          x_batch=x_batch,
+          y_batch=y_batch,
+          registry_generator_fn=registry_generator_fn,
+      )
+  )
+  return layer_grad_vars, registry_fn_outputs_list
+
+
 class ComputeClippedGradsAndOutputsTest(
     tf.test.TestCase, parameterized.TestCase
 ):
@@ -153,13 +178,17 @@ class ComputeClippedGradsAndOutputsTest(
     y_batch = tf.reshape(
         1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
     )
+    layer_grad_vars, registry_fn_outputs_list = (
+        _run_model_forward_backward_pass(self._model, x_batch, y_batch)
+    )
     # Stop early for efficiency.
     if reduction == 'none':
       with self.assertRaises(NotImplementedError):
         clip_grads.compute_clipped_gradients_and_outputs(
             self._model,
+            registry_fn_outputs_list,
+            layer_grad_vars,
             l2_norm_clip,
-            layer_registry.make_default_layer_registry(),
             x_batch,
             y_batch,
         )
@@ -169,10 +198,12 @@ class ComputeClippedGradsAndOutputsTest(
       y_pred = self._model(x_batch)
       loss_value = loss_fn(y_pred, y_batch)
     true_grads = tape.gradient(loss_value, self._model.trainable_variables)

     clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
         self._model,
+        registry_fn_outputs_list,
+        layer_grad_vars,
         l2_norm_clip,
-        layer_registry.make_default_layer_registry(),
         x_batch,
         y_batch,
     )

@@ -13,13 +13,17 @@
 # limitations under the License.
 """Utility functions that help in the computation of per-example gradient norms."""

+import collections
 from collections.abc import Callable, Sequence, Set
 import dataclasses
 from typing import Any, Optional, Tuple

 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases


 @dataclasses.dataclass(frozen=True)
@@ -27,6 +31,9 @@ class RegistryGeneratorFunctionOutput:
   layer_id: str
   layer_vars: Optional[Sequence[tf.Variable]]
   layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
+  varname_to_count_contribution_fn: Optional[
+      dict[str, sn_type_aliases.ContributionCountHistogramFn]
+  ]
   layer_trainable_weights: Optional[Sequence[tf.Variable]]
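
A short sketch of constructing the dataclass with its new field; the callbacks and the 'kernel' key below are placeholders for illustration, not values from this change:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils

layer = tf.keras.layers.Dense(1)
layer.build((None, 2))

# Dummy per-layer callbacks standing in for what registry functions return.
def sqr_norm_fn(unused_grads):
  return None

def count_contribution_fn(unused_inputs):
  return None

output = gradient_clipping_utils.RegistryGeneratorFunctionOutput(
    layer_id=str(id(layer)),
    layer_vars=layer.trainable_weights,
    layer_sqr_norm_fn=sqr_norm_fn,
    # New field in this change: None for layers without sparse-noise support.
    varname_to_count_contribution_fn={'kernel': count_contribution_fn},
    layer_trainable_weights=layer.trainable_weights,
)
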
@@ -44,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
 def get_registry_generator_fn(
     tape: tf.GradientTape,
     layer_registry: lr.LayerRegistry,
+    sparse_noise_layer_registry: snlr.LayerRegistry,
     num_microbatches: Optional[type_aliases.BatchSize] = None,
 ) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
   """Creates the generator function for `model_forward_backward_pass()`.
@@ -56,6 +64,10 @@ def get_registry_generator_fn(
       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
       squared norms of a layer's pre-activation tensor, and `vars` are relevant
       trainable
+    sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
+      that help compute contribution counts for sparse noise. See
+      `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+      more details.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -81,6 +93,16 @@ def get_registry_generator_fn(
             'be used for efficient gradient clipping.'
             % layer_instance.__class__.__name__
         )
+      varname_to_count_contribution_fn = None
+      if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
+          layer_instance
+      ):
+        count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
+            layer_instance
+        )
+        varname_to_count_contribution_fn = count_contribution_registry_fn(
+            layer_instance, args, kwargs, num_microbatches
+        )
       registry_fn = layer_registry.lookup(layer_instance)
       (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
           layer_instance, args, kwargs, tape, num_microbatches
@@ -89,6 +111,7 @@ def get_registry_generator_fn(
           layer_id=str(id(layer_instance)),
           layer_vars=layer_vars,
           layer_sqr_norm_fn=layer_sqr_norm_fn,
+          varname_to_count_contribution_fn=varname_to_count_contribution_fn,
           layer_trainable_weights=layer_instance.trainable_weights,
       )
     else:
@@ -98,6 +121,149 @@ def get_registry_generator_fn(
   return registry_generator_fn


+def _infer_per_example_loss_fn(model: tf.keras.Model):
+  """Infer the per-example loss from model config."""
+
+  def _convert(loss_fn):
+    loss_config = loss_fn.get_config()
+    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
+    return loss_fn.from_config(loss_config)
+
+  model_loss = model.loss
+  if isinstance(model_loss, tf.keras.losses.Loss):
+    return _convert(model_loss)
+  elif isinstance(model_loss, dict):
+    # Note that we cannot call the public method `.get_compile_config()` because
+    # it calls a numpy function, which is not supported inside a `tf.function`
+    # wrapped function.
+    compile_config = model._compile_config.config  # pylint: disable=protected-access
+    if compile_config is None:
+      raise ValueError('Model must be compiled for loss function conversion')
+    # Does a weighted mean of the configured losses. Note that we cannot build
+    # from the config of the compiled loss because (i) it builds a
+    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
+    # during its construction, (ii) non-unique `tf.Variables` cannot be used
+    # inside a `tf.function`, which is usually where this function is used.
+    if 'loss_weights' not in compile_config:
+      raise ValueError(
+          'Models with multiple loss must have corresponding loss weights for'
+          ' loss function conversion'
+      )
+    weights = compile_config['loss_weights']
+    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
+    num_losses = len(weights)
+
+    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
+      loss_values = []
+      if model_loss.keys() - y_pred.keys():
+        raise ValueError(
+            'y_pred must contain the same keys and the model losses, but '
+            'got %s and %s' % (y_pred.keys(), model_loss.keys())
+        )
+      if model_loss.keys() - y_true.keys():
+        raise ValueError(
+            'y_true must contain the same keys and the model losses, but '
+            'got %s and %s' % (y_true.keys(), model_loss.keys())
+        )
+      if sample_weight is not None:
+        if model_loss.keys() - sample_weight.keys():
+          raise ValueError(
+              'sample_weight must contain the same keys and the model losses,'
+              ' but got %s and %s' % (y_true.keys(), model_loss.keys())
+          )
+      for k in y_true.keys():
+        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
+        sgl_value = (
+            weights[k]
+            * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
+            / num_losses
+        )
+        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
+      return tf.math.add_n(loss_values)
+
+    return _per_example_loss_fn
+  else:
+    raise ValueError(
+        'Unsupported type for loss function conversion: {}'.format(
+            type(model_loss)
+        )
+    )
+
+
+def model_forward_backward_pass(
+    tape: tf.GradientTape,
+    input_model: tf.keras.Model,
+    x_batch: type_aliases.InputTensors,
+    y_batch: type_aliases.OutputTensors,
+    registry_generator_fn: Optional[
+        Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]
+    ],
+    weight_batch: Optional[tf.Tensor] = None,
+    per_example_loss_fn: Optional[type_aliases.LossFn] = None,
+    num_microbatches: Optional[type_aliases.BatchSize] = None,
+    trainable_vars: Optional[Sequence[tf.Variable]] = None,
+) -> tuple[
+    dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput]
+]:
+  """Does a forward and backward pass of a model and returns useful intermediates."""
+  # First loop computes the model outputs, summed loss, and generator outputs.
+  with tape:
+    model_outputs, generator_outputs_list = model_forward_pass(
+        input_model, x_batch, generator_fn=registry_generator_fn
+    )
+
+    # Ignore the original loss function's reduction to get per-example loss.
+    if per_example_loss_fn is None:
+      per_example_loss_fn = _infer_per_example_loss_fn(input_model)
+
+    losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
+    if losses.shape is None:
+      raise NotImplementedError(
+          "The unreduced (or per-example) loss's shape cannot be `None`"
+      )
+    if len(losses.shape) != 1:
+      raise NotImplementedError(
+          'The unreduced (or per-example) loss needs to have a shape of length '
+          'one, but received an unreduced loss of shape length %s'
+          % len(losses.shape)
+      )
+    if num_microbatches is not None:
+      losses = tf.reduce_mean(
+          common_manip_utils.maybe_add_microbatch_axis(
+              losses, num_microbatches
+          ),
+          axis=1,
+      )
+    summed_loss = tf.reduce_sum(losses)
+  # Unwrap the generator outputs so that the next loop avoids duplicating
+  # backprop ops.
+  filtered_outputs = [t for t in generator_outputs_list if t is not None]
+
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+  layer_vars = collections.defaultdict(list)
+  for registry_fn_output in filtered_outputs:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      layer_vars[registry_fn_output.layer_id].append(
+          registry_fn_output.layer_vars
+      )
+
+  layer_grad_vars = tape.gradient(
+      summed_loss,
+      layer_vars,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
+  if not layer_grad_vars:
+    raise ValueError('The gradient list cannot be empty.')
+
+  return layer_grad_vars, filtered_outputs
+
+
 def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,
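
The `_infer_per_example_loss_fn` helper moved here relies on rebuilding a Keras loss with `Reduction.NONE` so it yields one loss value per example instead of a scalar. A small standalone sketch of that conversion (the loss and data below are illustrative):

import tensorflow as tf

# Rebuild a Keras loss from its config with reduction disabled, as the
# `_convert` helper above does.
loss_fn = tf.keras.losses.MeanSquaredError()
config = loss_fn.get_config()
config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = tf.keras.losses.MeanSquaredError.from_config(config)

y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [2.0], [1.0]])
print(loss_fn(y_true, y_pred))              # scalar mean: (0.25 + 0.0 + 4.0) / 3
print(per_example_loss_fn(y_true, y_pred))  # per-example: [0.25, 0.0, 4.0]
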

@@ -17,6 +17,8 @@ from typing import Any
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr


 # ==============================================================================
@@ -175,5 +177,92 @@ class GenerateOutputsUsingCoreKerasLayers(
     )


+class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
+
+  def _get_sparse_layer_registry(self):
+    def count_contribution_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return {'var': count_contribution_fn}
+
+    registry = snlr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    return registry, count_contribution_fn
+
+  def _get_layer_registry(self):
+    var = tf.Variable(1.0)
+    output = tf.ones((1, 1))
+
+    def sqr_norm_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return [var], output, sqr_norm_fn
+
+    registry = lr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    registry.insert(tf.keras.layers.Dense, registry_fn)
+    return registry, var, output, sqr_norm_fn
+
+  def test_registry_generator_fn(self):
+    inputs = tf.constant([[0, 1]])
+    model = tf.keras.Sequential([
+        tf.keras.layers.Embedding(10, 1),
+        tf.keras.layers.Dense(1),
+    ])
+
+    sparse_layer_registry, count_contribution_fn = (
+        self._get_sparse_layer_registry()
+    )
+    layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
+    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
+        tape=tf.GradientTape(),
+        layer_registry=layer_registry,
+        sparse_noise_layer_registry=sparse_layer_registry,
+        num_microbatches=None,
+    )
+    embedding_layer = model.layers[0]
+    out, embedding_registry_generator_fn_output = registry_generator_fn(
+        embedding_layer,
+        [inputs],
+        {},
+    )
+    expected_embedding_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(embedding_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn={'var': count_contribution_fn},
+            layer_trainable_weights=embedding_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        embedding_registry_generator_fn_output,
+        expected_embedding_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+    dense_layer = model.layers[1]
+    out, dense_registry_generator_fn_output = registry_generator_fn(
+        dense_layer,
+        [inputs],
+        {},
+    )
+    expected_dense_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(dense_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn=None,
+            layer_trainable_weights=dense_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        dense_registry_generator_fn_output,
+        expected_dense_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+
+
 if __name__ == '__main__':
   tf.test.main()

@@ -14,10 +14,21 @@
 """Utility functions that help in adding noise to gradients."""

 from collections.abc import Sequence
+import dataclasses
 from typing import Literal, Optional

 from absl import logging
 import tensorflow as tf
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
+
+
+@dataclasses.dataclass
+class SparsityPreservingNoiseConfig:
+  """Configuration for adding noise to gradients."""
+
+  sparse_noise_multiplier: float = 0.0
+  sparse_selection_threshold: int = 0
+  sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None


 def _infer_loss_reduction_type(model: tf.keras.Model):
@@ -44,21 +55,53 @@ def _infer_loss_reduction_type(model: tf.keras.Model):
     )


+def _add_dense_aggregate_noise(
+    grad: tf.Tensor,
+    noise_multiplier: float,
+    sensitivity: float,
+) -> tf.Tensor:
+  """Adds dense noise to a dense gradient."""
+  return grad + tf.random.normal(
+      tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
+  )
+
+
+def _add_sparse_aggregate_noise(
+    grad: tf.IndexedSlices,
+    contribution_counts: tf.SparseTensor,
+    noise_multiplier: float,
+    noise_multiplier_sparse: float,
+    sensitivity: float,
+    sparse_selection_threshold: int,
+) -> tf.IndexedSlices:
+  """Adds sparse noise to a sparse gradient."""
+  return sparse_noise_utils.add_sparse_noise(
+      grad=grad,
+      contribution_counts=contribution_counts,
+      noise_multiplier=noise_multiplier,
+      noise_multiplier_sparse=noise_multiplier_sparse,
+      l2_norm_clip=sensitivity,
+      threshold=sparse_selection_threshold,
+  )
+
+
 def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
+    clipped_grads: list[tf.Tensor | tf.IndexedSlices],
     batch_size: tf.Tensor,
     l2_norm_clip: float,
     noise_multiplier: float,
     loss_reduction: Optional[Literal['mean', 'sum']] = None,
     loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
+    sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None,
+) -> Sequence[tf.Tensor | tf.IndexedSlices]:
   """Adds noise to a collection of clipped gradients.

   The magnitude of the noise depends on the aggregation strategy of the
   input model's loss function.

   Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
+    clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing
+      the clipped gradients.
     batch_size: The batch size. Used for normalizing the noise when
       `loss_reduction` is 'sum'.
     l2_norm_clip: Clipping norm (max L2 norm of each gradient).
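
The new `_add_dense_aggregate_noise` helper draws Gaussian noise with standard deviation `noise_multiplier * sensitivity`. A quick empirical sketch of that relationship (the sizes and values below are illustrative):

import numpy as np
import tensorflow as tf

noise_multiplier = 2.0
sensitivity = 3.0
grad = tf.zeros((100000,))

# Same operation as the helper above: add N(0, (noise_multiplier * sensitivity)^2).
noised = grad + tf.random.normal(
    tf.shape(grad), mean=0.0, stddev=noise_multiplier * sensitivity
)
print(np.std(noised.numpy()))  # close to 6.0
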
@@ -68,11 +111,14 @@ def add_aggregate_noise(
       aggregation type must be inferred from `input_model.loss`.
     loss_model: An optional `tf.keras.Model` used to infer the loss reduction
       strategy from if `loss_reduction` is `None`.
+    sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing
+      the configuration for adding sparse noise. If None, all noise added is
+      dense.

   Returns:
     A list of tensors containing the clipped gradients, but with the right
-    amount of Gaussian noise added to them (depending on the reduction
-    strategy of the loss function).
+    amount of Gaussian or sparse Gaussian noise added to them (depending on
+    the reduction strategy of the loss function).

   Raises:
     ValueError: If both `loss_model` and `loss_reduction` are `None` or if
@@ -103,13 +149,36 @@ def add_aggregate_noise(
         'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
     )

+  if sparse_noise_config is None:
+    sparse_contribution_counts = tf.nest.map_structure(
+        lambda x: None, clipped_grads
+    )
+  else:
+    sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts
+
   scale = l2_norm_clip
   if loss_reduction == 'mean':
     scale /= tf.cast(batch_size, tf.float32)

-  def add_noise(g):
-    return g + tf.random.normal(
-        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
-    )
+  def add_noise(grad, contribution_counts):
+    if (
+        sparse_noise_config is not None
+        and isinstance(grad, tf.IndexedSlices)
+        and contribution_counts is not None
+    ):
+      return _add_sparse_aggregate_noise(
+          grad=grad,
+          contribution_counts=contribution_counts,
+          noise_multiplier=noise_multiplier,
+          noise_multiplier_sparse=sparse_noise_config.sparse_noise_multiplier,
+          sensitivity=scale,
+          sparse_selection_threshold=sparse_noise_config.sparse_selection_threshold,
+      )
+    else:
+      return _add_dense_aggregate_noise(
+          grad=grad, noise_multiplier=noise_multiplier, sensitivity=scale
+      )

-  return tf.nest.map_structure(add_noise, clipped_grads)
+  return tf.nest.map_structure(
+      add_noise, clipped_grads, sparse_contribution_counts
+  )
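
Putting the noise_utils changes together: `add_aggregate_noise()` now pairs each clipped gradient with an optional contribution-count entry and dispatches to sparse or dense noise per slot. A usage sketch, modeled on the new test later in this changeset (all values below are illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# A dense gradient and a sparse (IndexedSlices) gradient, e.g. from an embedding.
dense_grad = tf.ones((4, 2))
sparse_grad = tf.IndexedSlices(
    values=tf.ones((2, 2)),
    indices=tf.constant([0, 3]),
    dense_shape=tf.constant([8, 2]),
)
contribution_counts = tf.SparseTensor(
    indices=[[0], [3]], values=[5.0, 7.0], dense_shape=[8]
)

config = noise_utils.SparsityPreservingNoiseConfig(
    sparse_noise_multiplier=1.0,
    sparse_selection_threshold=4,
    # One entry per clipped gradient; None selects dense noise for that slot.
    sparse_contribution_counts=[None, contribution_counts],
)

noised = noise_utils.add_aggregate_noise(
    clipped_grads=[dense_grad, sparse_grad],
    batch_size=2,
    l2_norm_clip=1.0,
    noise_multiplier=1.0,
    loss_reduction='mean',
    sparse_noise_config=config,
)
# noised[0] is a dense tf.Tensor; noised[1] remains a tf.IndexedSlices.
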
@@ -70,3 +70,95 @@ class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
     computed_std = np.std(noised_grads[0] - clipped_grads[0])
     expected_std = l2_norm_clip * noise_multiplier * scale
     self.assertNear(computed_std, expected_std, 0.1 * expected_std)
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      sparse_noise_multiplier=[1.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_sparse_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      sparse_noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    sparse_grad = tf.IndexedSlices(
+        values=tf.ones((3, 4)),
+        indices=tf.constant([0, 3, 5]),
+        dense_shape=tf.constant([8, 4]),
+    )
+    sparse_grad_contribution_counts = tf.SparseTensor(
+        indices=[[0], [3], [5]],
+        values=[10.0, 10.0, 20.0],
+        dense_shape=[8],
+    )
+
+    sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+        sparse_noise_multiplier=sparse_noise_multiplier,
+        sparse_selection_threshold=8,
+        sparse_contribution_counts=[None, sparse_grad_contribution_counts],
+    )
+
+    dense_noised_grad, sparse_noised_grad = noise_utils.add_aggregate_noise(
+        clipped_grads=[dense_grad, sparse_grad],
+        batch_size=batch_size,
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=noise_multiplier,
+        loss_model=linear_model,
+        sparse_noise_config=sparse_noise_config,
+    )
+    self.assertContainsSubset(
+        sparse_grad.indices.numpy().tolist(),
+        sparse_noised_grad.indices.numpy().tolist(),
+    )
+    sparse_noised_grad_dense = tf.scatter_nd(
+        tf.reshape(sparse_noised_grad.indices, (-1, 1)),
+        sparse_noised_grad.values,
+        shape=(8, 4),
+    ).numpy()
+    sparse_noised_grad_valid_indices = sparse_noised_grad_dense[
+        sparse_grad.indices.numpy()
+    ]
+    sparse_grad_values = sparse_grad.values.numpy()
+    self.assertTrue(
+        np.all(
+            np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values)
+        )
+    )
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    # The only measure that varies is the standard deviation of the variation.
+    expected_std = l2_norm_clip * noise_multiplier * scale
+
+    sparse_computed_std = np.std(
+        sparse_noised_grad_valid_indices - sparse_grad_values
+    )
+    self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std)
+
+    dense_computed_std = np.std(dense_noised_grad - dense_grad)
+    self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std)

@@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(embedding_test.GradNormTest):

   def setUp(self):
+    tf.config.experimental.disable_mlir_bridge()
     super(embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])

@@ -80,8 +80,11 @@ def layer_normalization_computation(
     stacked_grads = tf.stack(grads, axis=-1)
     if num_microbatches is not None:
       stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
-          grads, num_microbatches
+          stacked_grads, num_microbatches
       )
+      # We will need to sum over the new microbatch size axis (axis=1) in order
+      # to account for microbatch aggregation.
+      stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
     reduction_axes = tf.range(1, tf.rank(stacked_grads))
     return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
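
The fix above applies `maybe_add_microbatch_axis` to the stacked per-example gradients (rather than the unstacked `grads`) and then sums over the new microbatch axis. A shape-level sketch of the reshaping and reduction, assuming the utility groups the batch axis into `[num_microbatches, batch_size // num_microbatches, ...]` (plain reshapes are used below instead of the utility, and the values are illustrative):

import tensorflow as tf

stacked_grads = tf.reshape(tf.range(12.0), [6, 2])  # [batch_size, d]
num_microbatches = 3

# Group per-example rows into microbatches: [3, 2, 2].
grouped = tf.reshape(stacked_grads, [num_microbatches, -1, 2])

# Summing over axis=1 aggregates each microbatch before the squared-norm
# reduction, which is what the added `tf.reduce_sum(..., axis=1)` does.
per_microbatch = tf.reduce_sum(grouped, axis=1)               # [3, 2]
sqr_norms = tf.reduce_sum(tf.square(per_microbatch), axis=1)  # [3]
print(sqr_norms.shape)  # one squared norm per microbatch
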

@@ -134,7 +134,7 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
     atol = 1e-1 if self.using_tpu else 1e-2

     # Each batched input is a reshape of a `tf.range()` call.
-    batch_size = 2
+    batch_size = 6
     example_size = np.prod(input_dims)
     example_values = tf.range(batch_size * example_size, dtype=tf.float32)
     x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@@ -147,7 +147,9 @@ class GradNormTest(tf.test.TestCase, parameterized.TestCase):
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]
       true_norms = true_norms.values[0]
-    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertEqual(
+        tf.shape(computed_norms)[0], num_microbatches or batch_size
+    )
     self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)

@@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
 class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):

   def setUp(self):
+    tf.config.experimental.disable_mlir_bridge()
     super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
     self.strategy = common_test_utils.create_tpu_strategy()
     self.assertIn('TPU', self.strategy.extended.worker_devices[0])

@@ -19,11 +19,13 @@ import tensorflow as tf


 # Tensorflow aliases.
-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
+Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
+
+PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]

 InputTensors = PackedTensors

-OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
+OutputTensors = Union[Tensor, Iterable[Tensor]]

 BatchSize = Union[int, tf.Tensor]
@ -18,6 +18,8 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
     ],
 )
@ -35,6 +37,7 @@ py_test(

 py_test(
     name = "dp_keras_model_distributed_test",
+    timeout = "long",
     srcs = ["dp_keras_model_distributed_test.py"],
     tags = [
         "manual",
@ -13,16 +13,37 @@
 # limitations under the License.
 """Keras Model for vectorized dpsgd with XLA acceleration."""

+import dataclasses
+
 from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


 _PRIVATIZED_LOSS_NAME = 'privatized_loss'


+@dataclasses.dataclass
+class SparsityPreservingDPSGDConfig:
+  """Config for adding sparsity preserving noise to the gradients."""
+
+  # The ratio of how the noise is split between partition selection and gradient
+  # noise.
+  sparse_selection_ratio: float = 0.0
+  # The threshold to use for private partition selection.
+  sparse_selection_threshold: int = 100
+  # A `LayerRegistry` instance containing functions that help compute
+  # contribution counts for sparse layers. See
+  # `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+  # more details.
+  sparse_selection_layer_registry: snlr.LayerRegistry | None = None
+
+
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@ -104,6 +125,9 @@ def make_dp_model_class(cls):
         num_microbatches=None,
         use_xla=True,
         layer_registry=None,
+        sparsity_preserving_dpsgd_config: (
+            SparsityPreservingDPSGDConfig | None
+        ) = None,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs,
     ):
@ -118,6 +142,9 @@ def make_dp_model_class(cls):
           help compute gradient norms quickly. See
           `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
           more details.
+        sparsity_preserving_dpsgd_config: If provided, uses partition selection
+          and sparse noise for privatizing sparse gradients for layers in
+          `sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
       """
@ -127,6 +154,8 @@ def make_dp_model_class(cls):
       self._layer_registry = layer_registry
       self._clipping_loss = None
+      self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config

       # Given that `num_microbatches` was added as an argument after the fact,
       # this check helps detect unintended calls to the earlier API.
       # In particular, boolean values supplied to `use_xla` in the earlier API
@ -274,27 +303,94 @@ def make_dp_model_class(cls):
       # trick, and uses these norms to clip the per-example gradients.
       # NOTE: Reshaping of the input according to the effective number of
       # microbatches is done here.
+      tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
+
+      sparse_noise_layer_registry = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        sparse_noise_layer_registry = (
+            self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
+        )
+      registry_generator_fn = (
+          gradient_clipping_utils.get_registry_generator_fn(
+              tape=tape,
+              layer_registry=self._layer_registry,
+              sparse_noise_layer_registry=sparse_noise_layer_registry,
+              num_microbatches=num_microbatches,
+          )
+      )
+      layer_grad_vars, registry_fn_outputs_list = (
+          gradient_clipping_utils.model_forward_backward_pass(
+              tape=tape,
+              input_model=self,
+              x_batch=x,
+              y_batch=y,
+              registry_generator_fn=registry_generator_fn,
+              weight_batch=weights,
+              num_microbatches=num_microbatches,
+              trainable_vars=self.trainable_variables,
+          )
+      )
       clipped_grads, y_pred, clipping_loss = (
           clip_grads.compute_clipped_gradients_and_outputs(
               input_model=self,
+              registry_fn_outputs_list=registry_fn_outputs_list,
+              layer_grad_vars=layer_grad_vars,
               x_batch=x,
               y_batch=y,
               weight_batch=weights,
               l2_norm_clip=self._l2_norm_clip,
-              layer_registry=self._layer_registry,
               num_microbatches=self._num_microbatches,
               clipping_loss=self._clipping_loss,
           )
       )
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
-      if self._noise_multiplier > 0:
+      noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
+      contribution_counts = None
+      if self._sparsity_preserving_dpsgd_config is not None:
+        logging.info('Using sparse noise.')
+
+        varname_to_contribution_counts_fns = (
+            sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+                registry_fn_outputs_list,
+                self.trainable_variables,
+            )
+        )
+        contribution_counts = sparse_noise_utils.get_contribution_counts(
+            self.trainable_variables,
+            clipped_grads,
+            varname_to_contribution_counts_fns,
+        )
+
+        noise_multiplier_sparse, noise_multiplier = (
+            sparse_noise_utils.split_noise_multiplier(
+                noise_multiplier,
+                self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
+                contribution_counts,
+            )
+        )
+        logging.info(
+            'Split noise multiplier for gradient noise: %s and partition'
+            ' selection: %s',
+            noise_multiplier,
+            noise_multiplier_sparse,
+        )
+
+      if noise_multiplier > 0:
+        sparse_noise_config = None
+        if self._sparsity_preserving_dpsgd_config is not None:
+          sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+              sparse_noise_multiplier=noise_multiplier_sparse,
+              sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
+              sparse_contribution_counts=contribution_counts,
+          )
         grads = noise_utils.add_aggregate_noise(
             clipped_grads,
             num_microbatches,
             self._l2_norm_clip,
-            self._noise_multiplier,
+            noise_multiplier,
             loss_reduction=None,
             loss_model=self,
+            sparse_noise_config=sparse_noise_config,
         )
       else:
         grads = clipped_grads
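For orientation, a hedged sketch of how a caller might thread the new config into a DP Keras model; the empty sparse-noise registry, the toy functional model, and the hyperparameter values are illustrative placeholders rather than part of this change:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr

# Placeholder registry; a real caller would register contribution-count
# functions for its sparse layers (e.g. embeddings) before use.
sparse_registry = snlr.LayerRegistry()

config = dp_keras_model.SparsityPreservingDPSGDConfig(
    sparse_selection_ratio=0.1,      # share of the noise used for partition selection
    sparse_selection_threshold=100,  # private partition selection threshold
    sparse_selection_layer_registry=sparse_registry,
)

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = dp_keras_model.DPModel(
    inputs=inputs,
    outputs=outputs,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layer_registry=layer_registry.make_default_layer_registry(),
    sparsity_preserving_dpsgd_config=config,
)
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError())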
@ -261,7 +261,8 @@ class AttackInputData:
   # Contains ground-truth classes. For single-label classification, classes are
   # assumed to be integers starting from 0. For multi-label classification,
   # label is assumed to be multi-hot, i.e., labels is a binary array of shape
-  # (num_examples, num_classes).
+  # (num_examples, num_classes). Additionally used to compute the loss when
+  # loss_train/test is not provided. Leave empty for non-classification models.
   labels_train: Optional[np.ndarray] = None
   labels_test: Optional[np.ndarray] = None
@ -270,7 +271,7 @@ class AttackInputData:
   sample_weight_test: Optional[np.ndarray] = None

   # Explicitly specified loss. If provided, this is used instead of deriving
-  # loss from logits and labels
+  # loss from logits and labels.
   loss_train: Optional[np.ndarray] = None
   loss_test: Optional[np.ndarray] = None
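A brief hedged example of the clarified contract: when loss_train/loss_test are omitted, the loss is derived from logits and labels, so classification callers can pass only those (the arrays below are made-up placeholders):

import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures

attack_input = data_structures.AttackInputData(
    logits_train=np.random.randn(100, 10),
    logits_test=np.random.randn(100, 10),
    labels_train=np.random.randint(0, 10, size=100),
    labels_test=np.random.randint(0, 10, size=100),
    # loss_train / loss_test omitted: derived from logits and labels.
)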
@ -5,12 +5,19 @@ licenses(["notice"])
 py_library(
     name = "sparse_noise_utils",
     srcs = ["sparse_noise_utils.py"],
+    deps = [
+        ":type_aliases",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+    ],
 )

 py_test(
     name = "sparse_noise_utils_test",
     srcs = ["sparse_noise_utils_test.py"],
-    deps = [":sparse_noise_utils"],
+    deps = [
+        ":sparse_noise_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+    ],
 )

 py_library(
@ -16,10 +16,13 @@
 For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
 """

+import collections
 from typing import Mapping, Optional, Sequence

 from scipy import stats
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
 import tensorflow_probability as tfp
@ -163,7 +166,7 @@ def sample_true_positive_indices(
           tf.shape(contribution_count_values),
           mean=0.0,
           stddev=noise_multiplier,
-          dtype=tf.float32,
+          dtype=contribution_count_values.dtype,
       )
   )
   noised_contribution_counts_indices = contribution_counts.indices[
@ -278,7 +281,7 @@ def add_sparse_gradient_noise(
   """
   filtered_grad_values = tf.gather(grad, indices)
   sparse_noise_values = tf.random.normal(
-      filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
+      tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
   )
   filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
   return tf.IndexedSlices(
@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
   )


+def extract_varname_to_contribution_counts_fns(
+    registry_fn_outputs_list: Sequence[
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput
+    ],
+    trainable_vars: Sequence[tf.Variable],
+) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
+  """Extracts a map of contribution count fns from generator outputs.
+
+  Args:
+    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
+      instances returned by
+      `gradient_clipping_utils.model_forward_backward_pass`.
+    trainable_vars: A list of trainable variables.
+
+  Returns:
+    A `dict` from varname to contribution counts functions
+  """
+  if trainable_vars is not None:
+    # Create a set using `ref()` for fast set membership check. tf.Variable
+    # itself is not hashable.
+    trainable_vars = set([v.ref() for v in trainable_vars])
+
+  varname_to_contribution_counts_fns = collections.defaultdict(list)
+  for registry_fn_output in registry_fn_outputs_list:
+    if trainable_vars is None or any(
+        w.ref() in trainable_vars
+        for w in registry_fn_output.layer_trainable_weights
+    ):
+      if registry_fn_output.varname_to_count_contribution_fn is not None:
+        duplicate_varnames = set(
+            registry_fn_output.varname_to_count_contribution_fn.keys()
+        ) & set(varname_to_contribution_counts_fns.keys())
+        if duplicate_varnames:
+          raise ValueError(
+              'Duplicate varnames: {duplicate_varnames} found in contribution'
+              ' counts functions.'
+          )
+        varname_to_contribution_counts_fns.update(
+            registry_fn_output.varname_to_count_contribution_fn
+        )
+  return varname_to_contribution_counts_fns
+
+
 def get_contribution_counts(
-    trainable_vars: list[tf.Variable],
-    grads: list[tf.Tensor],
-    varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
-) -> list[tf.Tensor | None]:
+    trainable_vars: Sequence[tf.Variable],
+    grads: Sequence[tf.Tensor],
+    varname_to_contribution_counts_fns: Mapping[
+        str, type_aliases.ContributionCountHistogramFn
+    ],
+) -> Sequence[type_aliases.ContributionCountHistogram | None]:
   """Gets the contribution counts for each variable in the Model.

   Args:
-    trainable_vars: A list of the trainable variables in the Model.
+    trainable_vars: A list of trainable variables.
     grads: A corresponding list of gradients for each trainable variable.
     varname_to_contribution_counts_fns: A mapping from variable name to a list
       of functions to get the contribution counts for that variable.
@ -314,15 +362,10 @@ def get_contribution_counts(
     if var.name not in varname_to_contribution_counts_fns:
       contribution_counts_list.append(None)
       continue
-    contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
-    if not contribution_counts_fns or not contribution_counts_fns[0]:
+    contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
+    if not contribution_counts_fn:
       contribution_counts_list.append(None)
       continue
-    if len(contribution_counts_fns) > 1:
-      raise NotImplementedError(
-          'Sparse noise is not supported for shared weight variables.'
-      )
-    contribution_counts_fn = contribution_counts_fns[0]
     contribution_counts = contribution_counts_fn(grad)
     contribution_counts_list.append(contribution_counts)
@ -17,6 +17,7 @@ from absl.testing import parameterized
 import numpy as np
 from scipy import stats
 import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
 from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
@ -368,7 +369,7 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         tf.ones((1, 2)),
     ]
     varname_to_contribution_counts_fns = {
-        'var1:0': [lambda grad: 1.0],
+        'var1:0': lambda grad: 1.0,
         'var2:0': None,
     }
     contribution_counts = sparse_noise_utils.get_contribution_counts(
@ -436,6 +437,96 @@ class SparseNoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
         np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
     )

+  def test_extract_varname_to_contribution_counts_fns(self):
+    def fn1(_):
+      return 1.0
+
+    def fn2(_):
+      return 2.0
+
+    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
+    var2 = tf.Variable(tf.ones((1, 2)), name='var2')
+    var3 = tf.Variable(tf.ones((1, 2)), name='var3')
+
+    registry_fn_outputs_list = [
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer1',
+            layer_vars=[var1],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var1],
+            varname_to_count_contribution_fn=None,
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer2',
+            layer_vars=[var2],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var2],
+            varname_to_count_contribution_fn={
+                'var2:0': [fn2],
+            },
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer3',
+            layer_vars=[var3],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var3],
+            varname_to_count_contribution_fn={
+                'var3:0': [fn1],
+            },
+        ),
+    ]
+    expected_varname_to_contribution_counts_fns = {
+        'var2:0': [fn2],
+        'var3:0': [fn1],
+    }
+    varname_to_contribution_counts_fns = (
+        sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+            registry_fn_outputs_list,
+            trainable_vars=None,
+        )
+    )
+    self.assertEqual(
+        varname_to_contribution_counts_fns,
+        expected_varname_to_contribution_counts_fns,
+    )
+
+  def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
+    def fn1(_):
+      return 1.0
+
+    def fn2(_):
+      return 2.0
+
+    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
+    var2 = tf.Variable(tf.ones((1, 2)), name='var1')
+
+    registry_fn_outputs_list = [
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer1',
+            layer_vars=[var1],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var1],
+            varname_to_count_contribution_fn={
+                'var1:0': [fn1],
+            },
+        ),
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id='layer2',
+            layer_vars=[var2],
+            layer_sqr_norm_fn=None,
+            layer_trainable_weights=[var2],
+            varname_to_count_contribution_fn={
+                'var1:0': [fn2],
+            },
+        ),
+    ]
+
+    with self.assertRaises(ValueError):
+      sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+          registry_fn_outputs_list,
+          trainable_vars=None,
+      )
+
+
 if __name__ == '__main__':
   tf.test.main()
@ -19,10 +19,10 @@ import tensorflow as tf
 InputArgs = Sequence[Any]
 InputKwargs = Mapping[str, Any]
-SparseGradient = tf.IndexedSlices
+SparseGradient = tf.IndexedSlices | tf.SparseTensor
 ContributionCountHistogram = tf.SparseTensor
 ContributionCountHistogramFn = Callable[
-    [SparseGradient], Mapping[str, ContributionCountHistogram]
+    [SparseGradient], ContributionCountHistogram
 ]
 NumMicrobatches = int | tf.Tensor
 SparsityPreservingNoiseLayerRegistryFunction = Callable[
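To illustrate the updated alias, a hypothetical function with the new ContributionCountHistogramFn shape: it maps a sparse gradient directly to a single ContributionCountHistogram (a tf.SparseTensor) rather than to a Mapping. The helper below is illustrative only and not part of this change:

import tensorflow as tf

def row_contribution_counts(grad: tf.IndexedSlices) -> tf.SparseTensor:
  # Count how many gradient slices touch each row of the variable.
  unique_rows, _, counts = tf.unique_with_counts(grad.indices)
  histogram = tf.SparseTensor(
      indices=tf.expand_dims(tf.cast(unique_rows, tf.int64), axis=1),
      values=tf.cast(counts, tf.float32),
      dense_shape=tf.cast(tf.expand_dims(grad.dense_shape[0], axis=0), tf.int64),
  )
  return tf.sparse.reorder(histogram)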