Generalize model_forward_pass() to allow input models with multiple outputs.

PiperOrigin-RevId: 517145254
This commit is contained in:
A. Unique TensorFlower 2023-03-16 09:35:34 -07:00
parent 043e8b5272
commit 7ae50c5ca5
3 changed files with 103 additions and 6 deletions

View file

@@ -9,6 +9,14 @@ py_library(
deps = [":layer_registry"],
)
py_test(
name = "gradient_clipping_utils_test",
srcs = ["gradient_clipping_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":gradient_clipping_utils"],
)
py_library(
name = "layer_registry",
srcs = ["layer_registry.py"],

View file

@@ -20,7 +20,7 @@ import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
@@ -51,9 +51,9 @@ def _get_internal_layers(
def model_forward_pass(
input_model: tf.keras.Model,
inputs: InputTensor,
inputs: PackedTensors,
generator_fn: GeneratorFunction = None,
) -> Tuple[tf.Tensor, List[Any]]:
) -> Tuple[PackedTensors, List[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
NOTE: the graph traversal algorithm is an adaptation of the logic in the
@@ -72,7 +72,7 @@ def model_forward_pass(
Returns:
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
`tf.Tensor` that is generated as a result of a forward pass.
`PackedTensor` that is generated as a result of a forward pass.
`generator_outputs_list` is a `list` whose i-th entry is the output of
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
layer when the compute graph of `input_model` is traversed in BFS order.
@ -133,7 +133,17 @@ def model_forward_pass(
):
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
return node_layer_outputs, generator_outputs_list
# Gather outputs (in case there are multiple) and return.
output_tensors = []
for x in input_model.outputs:
x_id = str(id(x))
output_tensors.append(tensor_dict[x_id].pop())
model_outputs = tf.nest.pack_sequence_as(
input_model._nested_outputs, # pylint: disable=protected-access
output_tensors,
)
return model_outputs, generator_outputs_list
def all_trainable_layers_are_registered(
@@ -203,7 +213,7 @@ def add_aggregate_noise(
def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
) -> tf.Tensor:
) -> PackedTensors:
"""Returns the model outputs generated by only core Keras layers."""
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])

View file

@@ -0,0 +1,79 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `gradient_clipping_utils.model_forward_pass`."""

  @parameterized.product(
      input_packing_type=[None, tuple, list, dict],
      output_packing_type=[None, tuple, list, dict],
  )
  def test_outputs_are_consistent(
      self, input_packing_type, output_packing_type
  ):
    """Checks the forward-pass output matches a direct model call.

    Exercises every combination of input/output packing: a bare tensor
    (`None`), a `tuple`, a `list`, and a string-keyed `dict`.
    """
    num_dims = 3
    # A `None` packing type means a single unwrapped tensor; otherwise we
    # build two tensors so the container actually holds multiple entries.
    num_inputs = 1 if input_packing_type is None else 2
    num_outputs = 1 if output_packing_type is None else 2

    sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
    stacked_inputs = tf.stack(sample_inputs, axis=0)
    # Each output is a distinct scalar multiple of the stacked inputs so
    # that outputs are distinguishable from one another.
    sample_outputs = [
        tf.multiply(stacked_inputs, float(k + 1.0)) for k in range(num_outputs)
    ]
    # Concrete batch values: each input gets a distinct scaled range vector.
    sample_x_batch = [
        tf.multiply(tf.range(num_dims, dtype=tf.float32), float(k + 1.0))
        for k in range(num_inputs)
    ]

    def pack(packing_type, values):
      # None -> the single bare tensor; dict -> keys '0', '1', ...;
      # tuple/list -> the corresponding sequence type.
      if packing_type is None:
        return values[0]
      if packing_type is dict:
        return {str(k): v for k, v in enumerate(values)}
      return packing_type(values)

    inputs = pack(input_packing_type, sample_inputs)
    x_batch = pack(input_packing_type, sample_x_batch)
    outputs = pack(output_packing_type, sample_outputs)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    computed_outputs, _ = gradient_clipping_utils.model_forward_pass(
        model,
        x_batch,
    )
    true_outputs = model(x_batch)
    self.assertAllClose(computed_outputs, true_outputs)
# Standard TensorFlow test entry point; runs all tests in this module.
if __name__ == '__main__':
  tf.test.main()