forked from 626_privacy/tensorflow_privacy
Generalize generate_model_outputs_using_core_keras_layers()
.
This change adds the following two new features to the above function: (i) it supports nested custom layers of depth >2; (ii) it allows the caller to exclude certain layers from the expansion. Feature (ii) will be needed for the development of DP models that use Trasformer or BERT-type layers. PiperOrigin-RevId: 520919934
This commit is contained in:
parent
abb0c3f9f6
commit
ee1abe6930
2 changed files with 154 additions and 39 deletions
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Utility functions that help in the computation of per-example gradient norms."""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
@ -36,19 +36,6 @@ def has_internal_compute_graph(input_object: Any):
|
|||
)
|
||||
|
||||
|
||||
def _get_internal_layers(
|
||||
input_layer: tf.keras.layers.Layer,
|
||||
) -> List[tf.keras.layers.Layer]:
|
||||
"""Returns a list of layers that are nested within a given layer."""
|
||||
internal_layers = []
|
||||
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
|
||||
for layer in input_layer.layers:
|
||||
internal_layers.extend(_get_internal_layers(layer))
|
||||
else:
|
||||
internal_layers.append(input_layer)
|
||||
return internal_layers
|
||||
|
||||
|
||||
def model_forward_pass(
|
||||
input_model: tf.keras.Model,
|
||||
inputs: PackedTensors,
|
||||
|
@ -114,18 +101,10 @@ def model_forward_pass(
|
|||
generator_outputs_list.extend(node_generator_outputs)
|
||||
else:
|
||||
# Otherwise, we parse the node directly.
|
||||
node_layers = _get_internal_layers(node.layer)
|
||||
for layer in node_layers:
|
||||
node_layer_outputs, layer_generator_outputs = generator_fn(
|
||||
layer, args, kwargs
|
||||
)
|
||||
generator_outputs_list.append(layer_generator_outputs)
|
||||
args = (
|
||||
node_layer_outputs
|
||||
if isinstance(node_layer_outputs, tuple)
|
||||
else (node_layer_outputs,)
|
||||
)
|
||||
kwargs = {}
|
||||
node_layer_outputs, layer_generator_outputs = generator_fn(
|
||||
node.layer, args, kwargs
|
||||
)
|
||||
generator_outputs_list.append(layer_generator_outputs)
|
||||
|
||||
# Update the current dictionary of inputs for the next node.
|
||||
for x_id, y in zip(
|
||||
|
@ -163,9 +142,8 @@ def all_trainable_layers_are_registered(
|
|||
False otherwise.
|
||||
"""
|
||||
for layer in input_model.layers:
|
||||
for sublayer in _get_internal_layers(layer):
|
||||
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
|
||||
return False
|
||||
if not layer_registry.is_elem(layer) and layer.trainable_variables:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
@ -213,17 +191,53 @@ def add_aggregate_noise(
|
|||
|
||||
def generate_model_outputs_using_core_keras_layers(
|
||||
input_model: tf.keras.Model,
|
||||
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic
|
||||
) -> PackedTensors:
|
||||
"""Returns the model outputs generated by only core Keras layers."""
|
||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||
"""Returns the model outputs generated by only core Keras layers.
|
||||
|
||||
Args:
|
||||
input_model: A `tf.keras.Model` instance to obtain outputs from.
|
||||
custom_layer_set: An optional `set` of custom layers to expand. If `None`,
|
||||
then this is the set of all registered custom Keras layers.
|
||||
|
||||
Returns:
|
||||
A `tf.Tensor` that is the result of `input_model(input_model.inputs)`
|
||||
using only Keras layers that are not in `custom_layer_set`.
|
||||
"""
|
||||
# Set up helper variables and functions.
|
||||
custom_layer_set = (
|
||||
custom_layer_set or tf.keras.utils.get_custom_objects().values()
|
||||
)
|
||||
|
||||
def _is_core(layer_instance):
|
||||
return type(layer_instance) not in custom_layer_set
|
||||
|
||||
def generator_fn(layer_instance, args, kwargs):
|
||||
if hash(layer_instance.__class__) in cust_hash_set:
|
||||
# Using `.call()` does not register the layer in the compute graph of
|
||||
# a forward pass.
|
||||
return layer_instance.call(*args, **kwargs), None
|
||||
else:
|
||||
return layer_instance(*args, **kwargs), None
|
||||
# Using `.call()` does not register the layer in the compute graph of
|
||||
# a forward pass.
|
||||
layer_outputs = (
|
||||
layer_instance(*args, **kwargs)
|
||||
if _is_core(layer_instance)
|
||||
else layer_instance.call(*args, **kwargs)
|
||||
)
|
||||
return layer_outputs, None
|
||||
|
||||
return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
|
||||
# Return early if all the existing layers contain only core layers.
|
||||
if all(_is_core(layer) for layer in input_model.layers):
|
||||
return model_forward_pass(input_model, input_model.inputs)[0]
|
||||
|
||||
# Do a forward pass to expand the outermost layers.
|
||||
candidate_outputs, _ = model_forward_pass(
|
||||
input_model, input_model.inputs, generator_fn
|
||||
)
|
||||
|
||||
# The following recursion is inefficient because it recursively builds `n`
|
||||
# Keras model graphs, where `n` is the number of recursive calls. However,
|
||||
# it appears to be the only valid approach without accessing Keras's internal
|
||||
# functions (e.g., `keras.engine.functional._map_graph_network()`).
|
||||
cleaned_model = tf.keras.Model(
|
||||
inputs=input_model.inputs, outputs=candidate_outputs
|
||||
)
|
||||
return generate_model_outputs_using_core_keras_layers(
|
||||
cleaned_model, custom_layer_set
|
||||
)
|
||||
|
|
|
@ -12,12 +12,72 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Helper functions and classes.
|
||||
# ==============================================================================
|
||||
@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
|
||||
class DoubleDense(tf.keras.layers.Layer):
|
||||
"""Generates two dense layers nested together."""
|
||||
|
||||
def __init__(self, units: int):
|
||||
super().__init__()
|
||||
self.dense1 = tf.keras.layers.Dense(units, name='DDense_ext_1')
|
||||
self.dense2 = tf.keras.layers.Dense(1, name='DDense_ext_2')
|
||||
|
||||
def call(self, inputs: Any):
|
||||
x = self.dense1(inputs)
|
||||
return self.dense2(x)
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
|
||||
class TripleDense(tf.keras.layers.Layer):
|
||||
"""Generates three dense layers nested together."""
|
||||
|
||||
def __init__(self, units: int):
|
||||
super().__init__()
|
||||
self.dense1 = tf.keras.layers.Dense(units, name='TDense_ext_1')
|
||||
self.dense2 = tf.keras.layers.Dense(units, name='TDense_ext_2')
|
||||
self.dense3 = tf.keras.layers.Dense(1, name='TDense_ext_3')
|
||||
|
||||
def call(self, inputs: Any):
|
||||
x1 = self.dense1(inputs)
|
||||
x2 = self.dense2(x1)
|
||||
return self.dense3(x2)
|
||||
|
||||
|
||||
def get_reduced_model(sample_inputs, hidden_layer_list, new_custom_layers=None):
|
||||
"""Reduces a set of layers to only core Keras layers in a model."""
|
||||
sample_outputs = sample_inputs
|
||||
for l in hidden_layer_list:
|
||||
sample_outputs = l(sample_outputs)
|
||||
custom_model = tf.keras.Model(inputs=sample_inputs, outputs=sample_outputs)
|
||||
if new_custom_layers:
|
||||
reduced_outputs = (
|
||||
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
|
||||
custom_model,
|
||||
custom_layer_set=new_custom_layers,
|
||||
)
|
||||
)
|
||||
else:
|
||||
reduced_outputs = (
|
||||
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
|
||||
custom_model
|
||||
)
|
||||
)
|
||||
return tf.keras.Model(inputs=custom_model.inputs, outputs=reduced_outputs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main tests.
|
||||
# ==============================================================================
|
||||
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
|
@ -75,5 +135,46 @@ class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertAllClose(computed_outputs, true_outputs)
|
||||
|
||||
|
||||
class GenerateOutputsUsingCoreKerasLayers(
|
||||
tf.test.TestCase, parameterized.TestCase
|
||||
):
|
||||
|
||||
def test_single_custom_layer_is_reduced(self):
|
||||
num_units = 5
|
||||
num_dims = 3
|
||||
reduced_model = get_reduced_model(
|
||||
tf.keras.Input(num_dims),
|
||||
[DoubleDense(num_units)],
|
||||
)
|
||||
# Ignore the input layer.
|
||||
for l in reduced_model.layers[1:]:
|
||||
self.assertIsInstance(l, tf.keras.layers.Dense)
|
||||
|
||||
def test_two_distinct_custom_layers_are_reduced(self):
|
||||
num_units = 5
|
||||
num_dims = 3
|
||||
reduced_model = get_reduced_model(
|
||||
tf.keras.Input(num_dims),
|
||||
[DoubleDense(num_units), TripleDense(num_units)],
|
||||
)
|
||||
# Ignore the input layer.
|
||||
for l in reduced_model.layers[1:]:
|
||||
self.assertIsInstance(l, tf.keras.layers.Dense)
|
||||
|
||||
def test_new_custom_layer_spec(self):
|
||||
num_units = 5
|
||||
num_dims = 3
|
||||
reduced_model = get_reduced_model(
|
||||
tf.keras.Input(num_dims),
|
||||
[DoubleDense(num_units), TripleDense(num_units)],
|
||||
new_custom_layers=set([DoubleDense]),
|
||||
)
|
||||
# Ignore the input layer.
|
||||
for l in reduced_model.layers[1:]:
|
||||
self.assertTrue(
|
||||
isinstance(l, tf.keras.layers.Dense) or isinstance(l, TripleDense)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue