forked from 626_privacy/tensorflow_privacy
Generalize the internal API to allow for more general models + layers.
PiperOrigin-RevId: 509518753
This commit is contained in:
parent
6ee988885a
commit
410814ec39
9 changed files with 340 additions and 300 deletions
|
@ -9,8 +9,8 @@ py_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "layer_registry_factories",
|
name = "layer_registry",
|
||||||
srcs = ["layer_registry_factories.py"],
|
srcs = ["layer_registry.py"],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":gradient_clipping_utils",
|
":gradient_clipping_utils",
|
||||||
":layer_registry_factories",
|
":layer_registry",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,6 +31,6 @@ py_test(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":clip_grads",
|
":clip_grads",
|
||||||
":layer_registry_factories",
|
":layer_registry",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,40 +25,42 @@ import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
|
||||||
|
|
||||||
def combine_pre_and_post_sqr_norms(pre_sqr_norm, post_grad, layer_hash):
|
def get_registry_generator_fn(tape, layer_registry):
|
||||||
"""Combines pre and post-activation tensors for a given variable.
|
"""Creates the generator function for `compute_gradient_norms()`."""
|
||||||
|
if layer_registry is None:
|
||||||
The logic for combining norms depends on the variable's underlying layer.
|
# Needed for backwards compatibility.
|
||||||
|
registry_generator_fn = None
|
||||||
Args:
|
|
||||||
pre_sqr_norm: A `tf.Tensor` whose first dimension is the batch dimension.
|
|
||||||
Contains squared norms that are related to the pre-activation Tensor.
|
|
||||||
post_grad: A `tf.Tensor` whose first dimension is the batch dimension.
|
|
||||||
Contains gradients that are related to the post-activation Tensor.
|
|
||||||
layer_hash: A `float` that is the hash of the variable's underlying layer
|
|
||||||
class.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
|
||||||
per-example loss function with respect to the given variable.
|
|
||||||
"""
|
|
||||||
post_sqr_grads = tf.square(post_grad)
|
|
||||||
if layer_hash == hash(tf.keras.layers.Embedding):
|
|
||||||
scaled_grads = tf.expand_dims(pre_sqr_norm, axis=-1) * post_sqr_grads
|
|
||||||
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
|
||||||
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
|
||||||
else:
|
else:
|
||||||
reduction_axes = tf.range(1, tf.rank(post_sqr_grads))
|
|
||||||
post_sqr_norm = tf.reduce_sum(post_sqr_grads, axis=reduction_axes)
|
def registry_generator_fn(layer_instance, args, kwargs):
|
||||||
return pre_sqr_norm * post_sqr_norm
|
if layer_instance.trainable_variables:
|
||||||
|
# Only trainable variables factor into the gradient.
|
||||||
|
if not layer_registry.is_elem(layer_instance):
|
||||||
|
raise NotImplementedError(
|
||||||
|
'Layer %s is not in the registry of known layers that can '
|
||||||
|
'be used for efficient gradient clipping.'
|
||||||
|
% layer_instance.__class__.__name__
|
||||||
|
)
|
||||||
|
registry_fn = layer_registry.lookup(layer_instance)
|
||||||
|
(layer_vars, transform, layer_sqr_norm_fn) = registry_fn(
|
||||||
|
layer_instance, args
|
||||||
|
)
|
||||||
|
if tape is not None:
|
||||||
|
tape.watch(layer_vars)
|
||||||
|
layer_outputs = transform(layer_vars) if transform else layer_vars
|
||||||
|
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
|
||||||
|
else:
|
||||||
|
# Non-trainable layer.
|
||||||
|
return layer_instance(*args, **kwargs), None
|
||||||
|
|
||||||
|
return registry_generator_fn
|
||||||
|
|
||||||
|
|
||||||
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||||
"""Computes the per-example loss gradient norms for given data.
|
"""Computes the per-example loss gradient norms for given data.
|
||||||
|
|
||||||
Applies the approach given in https://arxiv.org/pdf/2009.03106.pdf, except
|
Applies a variant of the approach given in
|
||||||
the batch matrix multiplication operation in Algorithm 2 is replaced with
|
https://arxiv.org/pdf/2009.03106.pdf
|
||||||
the computation of two norm computations.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_model: The `tf.keras.Model` from which to obtain the layers from. The
|
input_model: The `tf.keras.Model` from which to obtain the layers from. The
|
||||||
|
@ -68,24 +70,22 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||||
must be the batch dimension. The number of examples should match the
|
must be the batch dimension. The number of examples should match the
|
||||||
number of examples in `x_batch`.
|
number of examples in `x_batch`.
|
||||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||||
computations. The key is the class of the layer and the value is a
|
compute gradient norms quickly. See
|
||||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
|
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||||
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
|
more details.
|
||||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
|
||||||
trainable weights (see `layer_registry_factories.py` for examples).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
|
||||||
per-example loss function.
|
per-example loss function.
|
||||||
"""
|
"""
|
||||||
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
|
||||||
# First loop computes the norms of the layer inputs, caches these inputs,
|
registry_generator_fn = get_registry_generator_fn(tape, layer_registry)
|
||||||
# and computes the summed loss.
|
# First loop computes the model outputs, summed loss, and generator outputs.
|
||||||
with tape:
|
with tape:
|
||||||
model_outputs, pre_norm_list, var_list, layer_hash_list = (
|
model_outputs, generator_outputs_list = (
|
||||||
gradient_clipping_utils.forward_norm_pass(
|
gradient_clipping_utils.model_forward_pass(
|
||||||
input_model, x_batch, tape, layer_registry
|
input_model, x_batch, generator_fn=registry_generator_fn
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Ignore the original loss function's reduction to get per-example loss.
|
# Ignore the original loss function's reduction to get per-example loss.
|
||||||
|
@ -94,20 +94,19 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
|
||||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||||
losses = per_example_loss_fn(y_batch, model_outputs)
|
losses = per_example_loss_fn(y_batch, model_outputs)
|
||||||
summed_loss = tf.reduce_sum(losses)
|
summed_loss = tf.reduce_sum(losses)
|
||||||
# Second loop computes the norm of the gradient of the loss with respect to
|
# Unwrap the generator outputs so that the next loop avoids duplicating
|
||||||
# the pre-activation tensors, and multiplies these norms with the results of
|
# backprop ops.
|
||||||
# the first loop.
|
filtered_outputs = [t for t in generator_outputs_list if t is not None]
|
||||||
full_norm_list = []
|
vars_list = [a for (a, b) in filtered_outputs]
|
||||||
grads = tape.gradient(summed_loss, var_list)
|
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
|
||||||
for i in range(len(var_list)):
|
# Second loop evaluates the squared L2 norm functions and appends the results.
|
||||||
full_norm = combine_pre_and_post_sqr_norms(
|
grads_list = tape.gradient(summed_loss, vars_list)
|
||||||
pre_norm_list[i], grads[i], layer_hash_list[i]
|
sqr_norm_list = []
|
||||||
)
|
for grads, f in zip(grads_list, sqr_norm_fns_list):
|
||||||
full_norm_list.append(full_norm)
|
sqr_norm_list.append(f(grads))
|
||||||
del tape
|
del tape
|
||||||
# Post-processing for compatibility with non-eager mode (very annoying).
|
sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
|
||||||
full_norm_tsr = tf.stack(full_norm_list, axis=1)
|
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
|
||||||
return tf.sqrt(tf.reduce_sum(full_norm_tsr, axis=1))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
def compute_clip_weights(l2_norm_clip, gradient_norms):
|
||||||
|
|
|
@ -17,7 +17,7 @@ from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
@ -81,7 +81,7 @@ def get_computed_and_true_norms(
|
||||||
)
|
)
|
||||||
y_pred = model(x_input)
|
y_pred = model(x_input)
|
||||||
y_batch = tf.ones_like(y_pred)
|
y_batch = tf.ones_like(y_pred)
|
||||||
registry = layer_registry_factories.make_default_layer_registry()
|
registry = layer_registry.make_default_layer_registry()
|
||||||
computed_norms = clip_grads.compute_gradient_norms(
|
computed_norms = clip_grads.compute_gradient_norms(
|
||||||
model, x_input, y_batch, layer_registry=registry
|
model, x_input, y_batch, layer_registry=registry
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,8 +28,19 @@ def has_internal_compute_graph(input_object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
def _get_internal_layers(input_layer):
|
||||||
"""Does a forward pass of a model and returns some useful intermediates.
|
"""Returns a list of layers that are nested within a given layer."""
|
||||||
|
internal_layers = []
|
||||||
|
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||||
|
for layer in input_layer.layers:
|
||||||
|
internal_layers.extend(_get_internal_layers(layer))
|
||||||
|
else:
|
||||||
|
internal_layers.append(input_layer)
|
||||||
|
return internal_layers
|
||||||
|
|
||||||
|
|
||||||
|
def model_forward_pass(input_model, inputs, generator_fn=None):
|
||||||
|
"""Does a forward pass of a model and returns useful intermediates.
|
||||||
|
|
||||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||||
_run_internal_graph() method in the functional.Functional class. Hence,
|
_run_internal_graph() method in the functional.Functional class. Hence,
|
||||||
|
@ -37,44 +48,42 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
||||||
instance is an instance of the functional.Functional class.
|
instance is an instance of the functional.Functional class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_model: A Keras functional model to compute the quantities for.
|
input_model: A `tf.keras.Model` to compute the quantities for.
|
||||||
x_batch: A collection of Tensors to be fed into the input layer of the
|
inputs: Arbitrary input to be fed into the input layer of the model. It is
|
||||||
model.
|
expected that `input_model(inputs)` returns a valid output.
|
||||||
tape: A tf.GradientTape() instance used to record certain operations.
|
generator_fn: A function with signature `(tf.keras.layers.Layer, Any, Any)
|
||||||
Assumes that this function is being called inside of this tape.
|
-> (tf.Tensor, Any)`, where we require `generator_fn(layer_instance, args,
|
||||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
kwargs)[0] == layer_instance(*args, **kwargs)`. If `None`, then
|
||||||
computations. The key is the class of the layer and the value is a
|
`layer_fn(layer_instance, args, kwargs)[1] == None`.
|
||||||
function that returns a triple (norm_list, var_list, transform). For more
|
|
||||||
details, see `layer_registry_factories.py`.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Four objects (outputs, norm_list, var_list, layer_hash_list). The first are
|
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
|
||||||
the outputs that are generated as a result of a forward pass. The second is
|
`tf.Tensor` that is generated as a result of a forward pass.
|
||||||
either `None` or a collection of squared norms for each example that helps
|
`generator_outputs_list` is a `list` whose i-th entry is the output of
|
||||||
in computing the gradient norm of an example (if such a "nice" computation
|
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
|
||||||
exists). The third is an ordered list of tf.Tensor() objects that are
|
layer when the compute graph of `input_model` is traversed in BFS order.
|
||||||
intended to be evaluated with respect to the summed loss of the model. The
|
|
||||||
fourth is a list whose i-th element is the hash of the layer class of the
|
|
||||||
i-th element of var_list.
|
|
||||||
"""
|
"""
|
||||||
# TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo
|
# TODO: Avoid or remove the references to protected methods of `input_model`. # pylint: disable=g-bad-todo
|
||||||
|
|
||||||
|
# Default generator.
|
||||||
|
generator_outputs_list = []
|
||||||
|
if generator_fn is None:
|
||||||
|
|
||||||
|
def generator_fn(layer_instance, args, kwargs):
|
||||||
|
return layer_instance(*args, **kwargs), None
|
||||||
|
|
||||||
# Prepare the inputs and BFS variables.
|
# Prepare the inputs and BFS variables.
|
||||||
flattened_inputs = input_model._flatten_to_reference_inputs(x_batch) # pylint: disable=protected-access
|
flattened_inputs = input_model._flatten_to_reference_inputs(inputs) # pylint: disable=protected-access
|
||||||
tensor_dict = {}
|
tensor_dict = {}
|
||||||
tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access
|
tensor_usage_count = input_model._tensor_usage_count # pylint: disable=protected-access
|
||||||
for x, y in zip(input_model.inputs, flattened_inputs):
|
for x, y in zip(input_model.inputs, flattened_inputs):
|
||||||
y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access
|
y = input_model._conform_to_reference_input(y, ref_input=x) # pylint: disable=protected-access
|
||||||
x_id = str(id(x))
|
x_id = str(id(x))
|
||||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||||
|
|
||||||
# Main computations.
|
|
||||||
nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access
|
nodes_by_depth = input_model._nodes_by_depth # pylint: disable=protected-access
|
||||||
depth_keys = list(nodes_by_depth.keys())
|
depth_keys = list(nodes_by_depth.keys())
|
||||||
depth_keys.sort(reverse=True)
|
depth_keys.sort(reverse=True)
|
||||||
var_list = []
|
|
||||||
norm_list = []
|
|
||||||
layer_hash_list = []
|
|
||||||
node_outputs = None
|
|
||||||
# Perform BFS feedforward computations.
|
# Perform BFS feedforward computations.
|
||||||
for depth in depth_keys:
|
for depth in depth_keys:
|
||||||
for node in nodes_by_depth[depth]:
|
for node in nodes_by_depth[depth]:
|
||||||
|
@ -85,62 +94,28 @@ def forward_norm_pass(input_model, x_batch, tape, layer_registry):
|
||||||
args, kwargs = node.map_arguments(tensor_dict)
|
args, kwargs = node.map_arguments(tensor_dict)
|
||||||
if has_internal_compute_graph(node.layer):
|
if has_internal_compute_graph(node.layer):
|
||||||
# If this node has an internal computational graph, we can recurse.
|
# If this node has an internal computational graph, we can recurse.
|
||||||
node_outputs, node_norm_list, node_var_list, node_layer_hash_list = (
|
node_layer_outputs, node_generator_outputs = model_forward_pass(
|
||||||
forward_norm_pass(node.layer, args, tape, layer_registry)
|
node.layer, args, generator_fn
|
||||||
)
|
)
|
||||||
var_list += node_var_list
|
generator_outputs_list.extend(node_generator_outputs)
|
||||||
norm_list += node_norm_list
|
|
||||||
layer_hash_list += node_layer_hash_list
|
|
||||||
else:
|
else:
|
||||||
# Either pass through or record some metadata.
|
# Otherwise, we parse the node directly.
|
||||||
if not node.layer.trainable_variables:
|
node_layers = _get_internal_layers(node.layer)
|
||||||
node_outputs = node.layer(*args, **kwargs)
|
for layer in node_layers:
|
||||||
else:
|
node_layer_outputs, layer_generator_outputs = generator_fn(
|
||||||
lyr_hash = hash(node.layer.__class__)
|
layer, args, kwargs
|
||||||
if lyr_hash not in layer_registry:
|
|
||||||
raise NotImplementedError(
|
|
||||||
'Layer %s is not in the registry of known layers that can'
|
|
||||||
'be used for efficient gradient clipping.'
|
|
||||||
% node.layer.__class__.__name__
|
|
||||||
)
|
)
|
||||||
lyr = layer_registry[lyr_hash]
|
generator_outputs_list.append(layer_generator_outputs)
|
||||||
node_norms, node_vars, transform = lyr(node.layer, args)
|
args = (node_layer_outputs,)
|
||||||
tape.watch(node_vars)
|
kwargs = {}
|
||||||
node_outputs = transform(node_vars) if transform else node_vars
|
|
||||||
var_list.append(node_vars)
|
# Update the current dictionary of inputs for the next node.
|
||||||
norm_list.append(node_norms)
|
for x_id, y in zip(
|
||||||
layer_hash_list.append(lyr_hash)
|
node.flat_output_ids, tf.nest.flatten(node_layer_outputs)
|
||||||
# update the current dictionary of inputs for the next node.
|
):
|
||||||
for x_id, y in zip(node.flat_output_ids, tf.nest.flatten(node_outputs)):
|
|
||||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||||
|
|
||||||
return node_outputs, norm_list, var_list, layer_hash_list
|
return node_layer_outputs, generator_outputs_list
|
||||||
|
|
||||||
|
|
||||||
def get_trainable_hidden_layers(input_model):
|
|
||||||
"""Obtains the trainable hidden layers of a Keras model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_model: The Keras model to obtain the layers from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of Keras layers where the tensorflow.keras.layers.Layer
|
|
||||||
ancestor class MUST precede any existing tensorflow.keras.models.Model
|
|
||||||
ancestor class.
|
|
||||||
"""
|
|
||||||
hidden_layers = []
|
|
||||||
for l in input_model.layers:
|
|
||||||
for c in l.__class__.__mro__:
|
|
||||||
if c == tf.keras.models.Model:
|
|
||||||
hidden_layers += get_trainable_hidden_layers(l)
|
|
||||||
break
|
|
||||||
elif c == tf.keras.layers.InputLayer:
|
|
||||||
break
|
|
||||||
elif c == tf.keras.layers.Layer:
|
|
||||||
if l.trainable_variables:
|
|
||||||
hidden_layers.append(l)
|
|
||||||
break
|
|
||||||
return hidden_layers
|
|
||||||
|
|
||||||
|
|
||||||
def all_trainable_layers_are_registered(input_model, layer_registry):
|
def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||||
|
@ -148,20 +123,18 @@ def all_trainable_layers_are_registered(input_model, layer_registry):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_model: The Keras model from which to obtain the layers from.
|
input_model: The Keras model from which to obtain the layers from.
|
||||||
layer_registry: A dictionary of layers that support "fast" gradient norm
|
layer_registry: A `LayerRegistry` instance containing functions that help
|
||||||
computations. The key is the class of the layer and the value is a
|
compute gradient norms quickly. See
|
||||||
function that returns a triple (output, sqr_grad_norms, vars), where
|
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||||
output is the pre-activator tensor, sqr_grad_norms is the square of the
|
more details.
|
||||||
norm of the layer's input, and vars is an ordered list of the trainable
|
|
||||||
weights.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if all the trainable layers in `input_model` are in `layer_registry`.
|
True if all the trainable layers in `input_model` are in `layer_registry`.
|
||||||
False otherwise.
|
False otherwise.
|
||||||
"""
|
"""
|
||||||
hidden_layers = get_trainable_hidden_layers(input_model)
|
for layer in input_model.layers:
|
||||||
for l in hidden_layers:
|
for sublayer in _get_internal_layers(layer):
|
||||||
if hash(l.__class__) not in layer_registry:
|
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -212,3 +185,19 @@ def add_aggregate_noise(
|
||||||
)
|
)
|
||||||
|
|
||||||
return tf.nest.map_structure(add_noise, clipped_grads)
|
return tf.nest.map_structure(add_noise, clipped_grads)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_model_outputs_using_core_keras_layers(input_model):
|
||||||
|
"""Returns the model outputs generated by only core Keras layers."""
|
||||||
|
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||||
|
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||||
|
|
||||||
|
def generator_fn(layer_instance, args, kwargs):
|
||||||
|
if hash(layer_instance.__class__) in cust_hash_set:
|
||||||
|
# Using `.call()` does not register the layer in the compute graph of
|
||||||
|
# a forward pass.
|
||||||
|
return layer_instance.call(*args, **kwargs), None
|
||||||
|
else:
|
||||||
|
return layer_instance(*args, **kwargs), None
|
||||||
|
|
||||||
|
return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
|
||||||
|
|
|
@ -0,0 +1,195 @@
|
||||||
|
# Copyright 2022, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Defines the layer registry class and useful factory functions.
|
||||||
|
|
||||||
|
Defines "fast" gradient norm layer registry functions for use in the "fast"
|
||||||
|
gradient clipping algorithm. Specifically, each registry function takes
|
||||||
|
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
|
||||||
|
outputs: (a) a differentiable `tf.Tensor` `Z`, (b) either `None` or a function
|
||||||
|
that maps the object in (a) to the layer instance's output when using the
|
||||||
|
inputs in (ii), and (c) a function `F` that generates the per-example
|
||||||
|
squared gradient norms when it is fed an object representing the gradient of
|
||||||
|
the summed loss with respect to `Z` in (a). If (b) is `None`, then (a) is
|
||||||
|
expected to contain the layer outputs.
|
||||||
|
|
||||||
|
When a layer registry function is defined, it is generally assumed that the
|
||||||
|
following relation holds:
|
||||||
|
|
||||||
|
`|dL/dW|^2 == F(grad_Z)`
|
||||||
|
|
||||||
|
where `gradient_Z` is the gradient of the summed loss with respect to `Z`.
|
||||||
|
|
||||||
|
For example, if the layer instance is tf.keras.layers.Dense, Z contains the
|
||||||
|
pre-activation tensors, i.e., `z = X * w` for input `X`, and `g` is a tensor
|
||||||
|
whose i-th entry is the L2 norm of the i-th input vector, then
|
||||||
|
|
||||||
|
`F(grad_Z) = g^2 * l2_row_norm(grad_Z)^2`,
|
||||||
|
|
||||||
|
where `l2_row_norm(y)` computes the L2 norm for each row of an input `y`.
|
||||||
|
Details of this decomposition can be found in https://arxiv.org/abs/1510.01799
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Main class
|
||||||
|
# ==============================================================================
|
||||||
|
class LayerRegistry:
|
||||||
|
"""Custom container for layer registry functions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Basic initialization of various internal dictionaries."""
|
||||||
|
self._layer_class_dict = {}
|
||||||
|
self._registry = {}
|
||||||
|
|
||||||
|
def is_elem(self, layer_instance):
|
||||||
|
"""Checks if a layer instance's class is in the registry."""
|
||||||
|
return hash(layer_instance.__class__) in self._registry
|
||||||
|
|
||||||
|
def lookup(self, layer_instance):
|
||||||
|
"""Returns the layer registry function for a given layer instance."""
|
||||||
|
return self._registry[hash(layer_instance.__class__)]
|
||||||
|
|
||||||
|
def insert(self, layer_class, layer_registry_function):
|
||||||
|
"""Inserts a layer registry function into the internal dictionaries."""
|
||||||
|
layer_key = hash(layer_class)
|
||||||
|
self._layer_class_dict[layer_key] = layer_class
|
||||||
|
self._registry[layer_key] = layer_registry_function
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Supported Keras layers
|
||||||
|
# ==============================================================================
|
||||||
|
def dense_layer_computation(layer_instance, inputs):
|
||||||
|
"""Registry function for `tf.keras.layers.Dense`.
|
||||||
|
|
||||||
|
The logic for this computation is based on the following paper:
|
||||||
|
https://arxiv.org/abs/1510.01799
|
||||||
|
|
||||||
|
For the sake of efficiency, we fuse the variables and square grad norms
|
||||||
|
for the kernel weights and bias vector together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_instance: A `tf.keras.layers.Dense` instance.
|
||||||
|
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||||
|
`layer_instance(inputs)` returns a valid output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `tuple` `(base_vars, transform, sqr_norm_fn)`. `base_vars` is the
|
||||||
|
intermediate Tensor used in the chain-rule / "fast" clipping trick,
|
||||||
|
`transform` is a function that maps `base_vars` to the layer outputs, and
|
||||||
|
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
||||||
|
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
||||||
|
where `tape` is a `tf.GradientTape` instance that records the dense
|
||||||
|
layer computation and `summed_loss` is the sum of the per-example losses
|
||||||
|
of the underlying model. This function then returns the per-example squared
|
||||||
|
L2 gradient norms of the trainable variables in `layer_instance`. These
|
||||||
|
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
||||||
|
"""
|
||||||
|
orig_activation = layer_instance.activation
|
||||||
|
layer_instance.activation = None
|
||||||
|
base_vars = layer_instance(*inputs)
|
||||||
|
layer_instance.activation = orig_activation
|
||||||
|
def sqr_norm_fn(base_vars_grads):
|
||||||
|
sqr_inputs = tf.square(*inputs)
|
||||||
|
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
||||||
|
input_sqr_norms = tf.reduce_sum(sqr_inputs, axis=inputs_reduction_axes)
|
||||||
|
if layer_instance.use_bias:
|
||||||
|
# Adding a bias term is equivalent to a layer with no bias term and which
|
||||||
|
# adds an additional variable to the layer input that only takes a
|
||||||
|
# constant value of 1.0. This is thus equivalent to adding 1.0 to the sum
|
||||||
|
# of the squared values of the inputs.
|
||||||
|
input_sqr_norms += tf.cast(1.0, dtype=input_sqr_norms.dtype)
|
||||||
|
reduction_axes = tf.range(1, tf.rank(base_vars_grads))
|
||||||
|
base_vars_sqr_norms = tf.reduce_sum(
|
||||||
|
tf.square(base_vars_grads), axis=reduction_axes
|
||||||
|
)
|
||||||
|
return input_sqr_norms * base_vars_sqr_norms
|
||||||
|
|
||||||
|
return base_vars, layer_instance.activation, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_layer_computation(layer_instance, inputs):
|
||||||
|
"""Registry function for `tf.keras.layers.Embedding`.
|
||||||
|
|
||||||
|
The logic of this computation is based on the `tf.keras.layers.Dense`
|
||||||
|
computation and the fact that an embedding layer is just a dense layer
|
||||||
|
with no activation function and an output vector of the form X*W for input
|
||||||
|
X, where the i-th row of W is the i-th embedding vector and the j-th row of
|
||||||
|
X is a one-hot vector representing the input of example j.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_instance: A `tf.keras.layers.Embedding` instance.
|
||||||
|
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
||||||
|
`layer_instance(inputs)` returns a valid output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
|
||||||
|
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
||||||
|
`sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
|
||||||
|
represents the output of the call `tape.gradient(summed_loss, base_vars)`
|
||||||
|
where `tape` is a `tf.GradientTape` instance that records the dense
|
||||||
|
layer computation and `summed_loss` is the sum of the per-example losses
|
||||||
|
of the underlying model. This function then returns the per-example squared
|
||||||
|
L2 gradient norms of the trainable variables in `layer_instance`. These
|
||||||
|
squared norms should be a 1D `tf.Tensor` of length `batch_size`.
|
||||||
|
"""
|
||||||
|
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
||||||
|
if layer_instance.sparse:
|
||||||
|
raise NotImplementedError("Sparse output vectors are not supported.")
|
||||||
|
if len(inputs[0].shape) != 2:
|
||||||
|
raise NotImplementedError("Only 2D embedding inputs are supported.")
|
||||||
|
# The logic below is applied to properly handle repeated embedding indices.
|
||||||
|
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
||||||
|
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
||||||
|
# function in clip_grads.py). An example is as follows:
|
||||||
|
#
|
||||||
|
# inputs =
|
||||||
|
# [[0 0 0 1 2 2],
|
||||||
|
# [0 2 2 2 1 1]]
|
||||||
|
#
|
||||||
|
# counts =
|
||||||
|
# [[3 1 2]
|
||||||
|
# [1 2 3]]
|
||||||
|
#
|
||||||
|
# input_counts =
|
||||||
|
# [[3 3 3 1 2 2],
|
||||||
|
# [1 3 3 3 2 2]]
|
||||||
|
#
|
||||||
|
base_vars = layer_instance(*inputs)
|
||||||
|
def sqr_norm_fn(base_vars_grads):
|
||||||
|
indices = tf.cast(*inputs, tf.int32)
|
||||||
|
if isinstance(indices, tf.SparseTensor):
|
||||||
|
indices = tf.sparse.to_dense(indices)
|
||||||
|
counts = tf.math.bincount(indices, axis=-1)
|
||||||
|
input_counts = tf.expand_dims(
|
||||||
|
tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
scaled_grads = input_counts * tf.square(base_vars_grads)
|
||||||
|
reduction_axes = tf.range(1, tf.rank(scaled_grads))
|
||||||
|
return tf.reduce_sum(scaled_grads, axis=reduction_axes)
|
||||||
|
|
||||||
|
return base_vars, None, sqr_norm_fn
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Main factory methods
|
||||||
|
# ==============================================================================
|
||||||
|
def make_default_layer_registry():
|
||||||
|
registry = LayerRegistry()
|
||||||
|
registry.insert(tf.keras.layers.Dense, dense_layer_computation)
|
||||||
|
registry.insert(tf.keras.layers.Embedding, embedding_layer_computation)
|
||||||
|
return registry
|
|
@ -1,143 +0,0 @@
|
||||||
# Copyright 2022, The TensorFlow Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Generates a default layer registry.
|
|
||||||
|
|
||||||
Defines "fast" gradient norm layer registry functions for use in the "fast"
|
|
||||||
gradient clipping algorithm. Specifically, each registry function takes
|
|
||||||
in two inputs (i) a layer instance and (ii) `tf.Tensor` inputs to produce three
|
|
||||||
outputs: (a) a `tf.Tensor` `G` of gradient norms, (b) a differentiable
|
|
||||||
`tf.Tensor` `Z`, and (c) either `None` or a function that maps the object in
|
|
||||||
(b) to the layer instance's output when using the inputs in (ii). If (c) is
|
|
||||||
`None`, then (b) contains the layer outputs.
|
|
||||||
|
|
||||||
When a layer registry function is defined, it is generally assumed that the
|
|
||||||
following relation holds for each pair `(g, z)` in `zip(G, Z)`:
|
|
||||||
|
|
||||||
`|dL/dw|^2 == |dL/dz|^2 * g^2`
|
|
||||||
|
|
||||||
where `L` is any per-example loss function and `w` are the trainable variables
|
|
||||||
corresponding to `(g, z)`.
|
|
||||||
|
|
||||||
For example, this relation holds if the layer instance is tf.keras.layers.Dense,
|
|
||||||
Z contains the pre-activation tensors, i.e., `z = X * w` for input `X`, and `g`
|
|
||||||
is the norm of the input corresponding to the given per-example loss (see the
|
|
||||||
formulae in https://arxiv.org/abs/1510.01799 for more details).
|
|
||||||
|
|
||||||
The registry functions are registered in a `dict` (registry) whose key is the
|
|
||||||
hash of the layer class and whose value is the registry function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Supported Keras layers
|
|
||||||
# ==============================================================================
|
|
||||||
def dense_layer_computation(layer_instance, inputs):
|
|
||||||
"""Registry function for `tf.keras.layers.Dense`.
|
|
||||||
|
|
||||||
The logic for this computation is based on the following paper:
|
|
||||||
https://arxiv.org/abs/1510.01799
|
|
||||||
|
|
||||||
For the sake of efficiency, we fuse the variables and square grad norms
|
|
||||||
for the kernel weights and bias vector together.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_instance: A `tf.keras.layers.Dense` instance.
|
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
|
||||||
`layer_instance(inputs)` returns a valid output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A `tuple` `(sqr_grad_norms, base_vars, transform)`, where `norms` is a 1D
|
|
||||||
`tf.Tensor` of the squared l2-norms of the input tensors, `base_vars` is the
|
|
||||||
intermediate Tensor used in the chain-rule / "fast" clipping trick, and
|
|
||||||
`transform` is a function that maps `base_vars` to the layer outputs.
|
|
||||||
"""
|
|
||||||
orig_activation = layer_instance.activation
|
|
||||||
layer_instance.activation = None
|
|
||||||
base_vars = layer_instance(*inputs)
|
|
||||||
sqr_inputs = tf.square(*inputs)
|
|
||||||
inputs_reduction_axes = tf.range(1, tf.rank(sqr_inputs))
|
|
||||||
sqr_grad_norms = tf.reduce_sum(tf.square(*inputs), axis=inputs_reduction_axes)
|
|
||||||
if layer_instance.use_bias:
|
|
||||||
# Adding a bias term is equivalent to a layer with no bias term and which
|
|
||||||
# adds an additional variable to the layer input that only takes a constant
|
|
||||||
# value of 1.0. This is thus equivalent to adding 1.0 to the sum of the
|
|
||||||
# squared values of the inputs.
|
|
||||||
sqr_grad_norms += tf.cast(1.0, dtype=sqr_grad_norms.dtype)
|
|
||||||
layer_instance.activation = orig_activation
|
|
||||||
return sqr_grad_norms, base_vars, layer_instance.activation
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_layer_computation(layer_instance, inputs):
|
|
||||||
"""Registry function for `tf.keras.layers.Embedding`.
|
|
||||||
|
|
||||||
The logic of this computation is based on the `tf.keras.layers.Dense`
|
|
||||||
computation and the fact that an embedding layer is just a dense layer
|
|
||||||
with no activation function and an output vector of the form X*W for input
|
|
||||||
X, where the i-th row of W is the i-th embedding vector and the j-th row of
|
|
||||||
X is a one-hot vector representing the input of example j.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_instance: A `tf.keras.layers.Embedding` instance.
|
|
||||||
inputs: A `tf.Tensor` which can be passed into the layer instance, i.e.,
|
|
||||||
`layer_instance(inputs)` returns a valid output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A `tuple` `(sqr_grad_norms, base_vars, None)`, where `sqr_grad_norms` is
|
|
||||||
a `tf.Tensor` that is related to the squared l2-norms of the input tensors
|
|
||||||
and `base_vars` is the intermediate Tensor used in the chain-rule / "fast"
|
|
||||||
clipping trick.
|
|
||||||
"""
|
|
||||||
if hasattr(layer_instance, "sparse"): # for backwards compatibility
|
|
||||||
if layer_instance.sparse:
|
|
||||||
raise NotImplementedError("Sparse output vectors are not supported.")
|
|
||||||
if tf.rank(*inputs) != 2:
|
|
||||||
raise NotImplementedError("Only 2D embedding inputs are supported.")
|
|
||||||
# The logic below is applied to properly handle repeated embedding indices.
|
|
||||||
# Specifically, sqr_grad_norms will contain the total counts of each embedding
|
|
||||||
# index (see how it is processed in the combine_pre_and_post_sqr_norms()
|
|
||||||
# function in clip_grads.py). An example is as follows:
|
|
||||||
#
|
|
||||||
# inputs =
|
|
||||||
# [[0 0 0 1 2 2],
|
|
||||||
# [0 2 2 2 1 1]]
|
|
||||||
#
|
|
||||||
# counts =
|
|
||||||
# [[3 1 2]
|
|
||||||
# [1 2 3]]
|
|
||||||
#
|
|
||||||
# sqr_grad_norms =
|
|
||||||
# [[3 3 3 1 2 2],
|
|
||||||
# [1 3 3 3 2 2]]
|
|
||||||
#
|
|
||||||
base_vars = layer_instance(*inputs)
|
|
||||||
indices = tf.cast(*inputs, tf.int32)
|
|
||||||
if isinstance(indices, tf.SparseTensor):
|
|
||||||
indices = tf.sparse.to_dense(indices)
|
|
||||||
counts = tf.math.bincount(indices, axis=-1)
|
|
||||||
sqr_grad_norms = tf.cast(
|
|
||||||
tf.gather(counts, indices, batch_dims=1), base_vars.dtype
|
|
||||||
)
|
|
||||||
return sqr_grad_norms, base_vars, None
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main factory methods
|
|
||||||
# ==============================================================================
|
|
||||||
def make_default_layer_registry():
|
|
||||||
registry = {}
|
|
||||||
registry[hash(tf.keras.layers.Dense)] = dense_layer_computation
|
|
||||||
registry[hash(tf.keras.layers.Embedding)] = embedding_layer_computation
|
|
||||||
return registry
|
|
|
@ -18,7 +18,7 @@ py_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ py_test(
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry_factories",
|
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
|
||||||
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
|
"//tensorflow_privacy/privacy/keras_models:dp_keras_model",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -96,13 +96,10 @@ def make_dp_model_class(cls):
|
||||||
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
noise_multiplier: Ratio of the standard deviation to the clipping norm.
|
||||||
num_microbatches: Number of microbatches.
|
num_microbatches: Number of microbatches.
|
||||||
use_xla: If `True`, compiles train_step to XLA.
|
use_xla: If `True`, compiles train_step to XLA.
|
||||||
layer_registry: A `dict` of layers that support "fast" gradient norm
|
layer_registry: A `LayerRegistry` instance containing functions that
|
||||||
computations. The key is the class of the layer and the value is a
|
help compute gradient norms quickly. See
|
||||||
function that returns a `tuple` `(output, sqr_grad_norms, vars)`,
|
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
|
||||||
where `output` is the pre-activator tensor, `sqr_grad_norms` is
|
more details.
|
||||||
related to the squared norms of a layer's pre-activation tensor, and
|
|
||||||
`vars` are relevant trainable weights (see
|
|
||||||
`layer_registry_factories.py` for examples).
|
|
||||||
*args: These will be passed on to the base class `__init__` method.
|
*args: These will be passed on to the base class `__init__` method.
|
||||||
**kwargs: These will be passed on to the base class `__init__` method.
|
**kwargs: These will be passed on to the base class `__init__` method.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry_factories
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
|
||||||
from tensorflow_privacy.privacy.keras_models import dp_keras_model
|
from tensorflow_privacy.privacy.keras_models import dp_keras_model
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,10 @@ def get_layer_registries():
|
||||||
# Outputs a list of testable layer registries.
|
# Outputs a list of testable layer registries.
|
||||||
# The empty registry {} tests the behavior of the standard approach,
|
# The empty registry {} tests the behavior of the standard approach,
|
||||||
# while the other one tests the fast gradient clipping algorithm.
|
# while the other one tests the fast gradient clipping algorithm.
|
||||||
return [{}, layer_registry_factories.make_default_layer_registry()]
|
return [
|
||||||
|
layer_registry.LayerRegistry(),
|
||||||
|
layer_registry.make_default_layer_registry(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
Loading…
Reference in a new issue