Generalize model_forward_pass()
to allow input models with multiple outputs.
PiperOrigin-RevId: 517145254
This commit is contained in:
parent
043e8b5272
commit
7ae50c5ca5
3 changed files with 103 additions and 6 deletions
|
@ -9,6 +9,14 @@ py_library(
|
||||||
deps = [":layer_registry"],
|
deps = [":layer_registry"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "gradient_clipping_utils_test",
|
||||||
|
srcs = ["gradient_clipping_utils_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":gradient_clipping_utils"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "layer_registry",
|
name = "layer_registry",
|
||||||
srcs = ["layer_registry.py"],
|
srcs = ["layer_registry.py"],
|
||||||
|
|
|
@ -20,7 +20,7 @@ import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||||
|
|
||||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||||
|
|
||||||
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
|
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
|
||||||
|
|
||||||
|
@ -51,9 +51,9 @@ def _get_internal_layers(
|
||||||
|
|
||||||
def model_forward_pass(
|
def model_forward_pass(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
inputs: InputTensor,
|
inputs: PackedTensors,
|
||||||
generator_fn: GeneratorFunction = None,
|
generator_fn: GeneratorFunction = None,
|
||||||
) -> Tuple[tf.Tensor, List[Any]]:
|
) -> Tuple[PackedTensors, List[Any]]:
|
||||||
"""Does a forward pass of a model and returns useful intermediates.
|
"""Does a forward pass of a model and returns useful intermediates.
|
||||||
|
|
||||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||||
|
@ -72,7 +72,7 @@ def model_forward_pass(
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
|
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
|
||||||
`tf.Tensor` that is generated as a result of a forward pass.
|
`PackedTensor` that is generated as a result of a forward pass.
|
||||||
`generator_outputs_list` is a `list` whose i-th entry is the output of
|
`generator_outputs_list` is a `list` whose i-th entry is the output of
|
||||||
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
|
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
|
||||||
layer when the compute graph of `input_model` is traversed in BFS order.
|
layer when the compute graph of `input_model` is traversed in BFS order.
|
||||||
|
@ -133,7 +133,17 @@ def model_forward_pass(
|
||||||
):
|
):
|
||||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||||
|
|
||||||
return node_layer_outputs, generator_outputs_list
|
# Gather outputs (in case there are multiple) and return.
|
||||||
|
output_tensors = []
|
||||||
|
for x in input_model.outputs:
|
||||||
|
x_id = str(id(x))
|
||||||
|
output_tensors.append(tensor_dict[x_id].pop())
|
||||||
|
model_outputs = tf.nest.pack_sequence_as(
|
||||||
|
input_model._nested_outputs, # pylint: disable=protected-access
|
||||||
|
output_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_outputs, generator_outputs_list
|
||||||
|
|
||||||
|
|
||||||
def all_trainable_layers_are_registered(
|
def all_trainable_layers_are_registered(
|
||||||
|
@ -203,7 +213,7 @@ def add_aggregate_noise(
|
||||||
|
|
||||||
def generate_model_outputs_using_core_keras_layers(
|
def generate_model_outputs_using_core_keras_layers(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
) -> tf.Tensor:
|
) -> PackedTensors:
|
||||||
"""Returns the model outputs generated by only core Keras layers."""
|
"""Returns the model outputs generated by only core Keras layers."""
|
||||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright 2023, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.product(
|
||||||
|
input_packing_type=[None, tuple, list, dict],
|
||||||
|
output_packing_type=[None, tuple, list, dict],
|
||||||
|
)
|
||||||
|
def test_outputs_are_consistent(
|
||||||
|
self, input_packing_type, output_packing_type
|
||||||
|
):
|
||||||
|
num_dims = 3
|
||||||
|
num_inputs = 1 if input_packing_type is None else 2
|
||||||
|
num_outputs = 1 if output_packing_type is None else 2
|
||||||
|
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
|
||||||
|
temp_sum = tf.stack(sample_inputs, axis=0)
|
||||||
|
sample_outputs = [
|
||||||
|
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
|
||||||
|
]
|
||||||
|
sample_x_batch = [
|
||||||
|
tf.multiply(tf.range(num_dims, dtype=tf.float32), float(i + 1.0))
|
||||||
|
for i in range(num_inputs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Pack inputs.
|
||||||
|
if input_packing_type is None:
|
||||||
|
inputs = sample_inputs[0]
|
||||||
|
x_batch = sample_x_batch[0]
|
||||||
|
elif input_packing_type is not dict:
|
||||||
|
inputs = input_packing_type(sample_inputs)
|
||||||
|
x_batch = input_packing_type(sample_x_batch)
|
||||||
|
else:
|
||||||
|
inputs = {}
|
||||||
|
x_batch = {}
|
||||||
|
keys = [str(i) for i in range(len(sample_inputs))]
|
||||||
|
for k, v1, v2 in zip(keys, sample_inputs, sample_x_batch):
|
||||||
|
inputs[k] = v1
|
||||||
|
x_batch[k] = v2
|
||||||
|
|
||||||
|
# Pack outputs.
|
||||||
|
if output_packing_type is None:
|
||||||
|
outputs = sample_outputs[0]
|
||||||
|
elif output_packing_type is not dict:
|
||||||
|
outputs = output_packing_type(sample_outputs)
|
||||||
|
else:
|
||||||
|
outputs = {}
|
||||||
|
keys = [str(i) for i in range(len(sample_outputs))]
|
||||||
|
for k, v in zip(keys, sample_outputs):
|
||||||
|
outputs[k] = v
|
||||||
|
|
||||||
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
computed_outputs, _ = gradient_clipping_utils.model_forward_pass(
|
||||||
|
model,
|
||||||
|
x_batch,
|
||||||
|
)
|
||||||
|
true_outputs = model(x_batch)
|
||||||
|
self.assertAllClose(computed_outputs, true_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in a new issue