diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index e50781d..dd24122 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -9,6 +9,14 @@ py_library( deps = [":layer_registry"], ) +py_test( + name = "gradient_clipping_utils_test", + srcs = ["gradient_clipping_utils_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":gradient_clipping_utils"], +) + py_library( name = "layer_registry", srcs = ["layer_registry.py"], diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 896fc3c..23de260 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -20,7 +20,7 @@ import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr -InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] +PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]] @@ -51,9 +51,9 @@ def _get_internal_layers( def model_forward_pass( input_model: tf.keras.Model, - inputs: InputTensor, + inputs: PackedTensors, generator_fn: GeneratorFunction = None, -) -> Tuple[tf.Tensor, List[Any]]: +) -> Tuple[PackedTensors, List[Any]]: """Does a forward pass of a model and returns useful intermediates. NOTE: the graph traversal algorithm is an adaptation of the logic in the @@ -72,7 +72,7 @@ def model_forward_pass( Returns: A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the - `tf.Tensor` that is generated as a result of a forward pass. + `PackedTensor` that is generated as a result of a forward pass. `generator_outputs_list` is a `list` whose i-th entry is the output of `generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th layer when the compute graph of `input_model` is traversed in BFS order. @@ -133,7 +133,17 @@ def model_forward_pass( ): tensor_dict[x_id] = [y] * tensor_usage_count[x_id] - return node_layer_outputs, generator_outputs_list + # Gather outputs (in case there are multiple) and return. + output_tensors = [] + for x in input_model.outputs: + x_id = str(id(x)) + output_tensors.append(tensor_dict[x_id].pop()) + model_outputs = tf.nest.pack_sequence_as( + input_model._nested_outputs, # pylint: disable=protected-access + output_tensors, + ) + + return model_outputs, generator_outputs_list def all_trainable_layers_are_registered( @@ -203,7 +213,7 @@ def add_aggregate_noise( def generate_model_outputs_using_core_keras_layers( input_model: tf.keras.Model, -) -> tf.Tensor: +) -> PackedTensors: """Returns the model outputs generated by only core Keras layers.""" cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects()) cust_hash_set = set([hash(v) for v in cust_obj_dict.values()]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py new file mode 100644 index 0000000..f3c84b4 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py @@ -0,0 +1,79 @@ +# Copyright 2023, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils + + +class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + input_packing_type=[None, tuple, list, dict], + output_packing_type=[None, tuple, list, dict], + ) + def test_outputs_are_consistent( + self, input_packing_type, output_packing_type + ): + num_dims = 3 + num_inputs = 1 if input_packing_type is None else 2 + num_outputs = 1 if output_packing_type is None else 2 + sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)] + temp_sum = tf.stack(sample_inputs, axis=0) + sample_outputs = [ + tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs) + ] + sample_x_batch = [ + tf.multiply(tf.range(num_dims, dtype=tf.float32), float(i + 1.0)) + for i in range(num_inputs) + ] + + # Pack inputs. + if input_packing_type is None: + inputs = sample_inputs[0] + x_batch = sample_x_batch[0] + elif input_packing_type is not dict: + inputs = input_packing_type(sample_inputs) + x_batch = input_packing_type(sample_x_batch) + else: + inputs = {} + x_batch = {} + keys = [str(i) for i in range(len(sample_inputs))] + for k, v1, v2 in zip(keys, sample_inputs, sample_x_batch): + inputs[k] = v1 + x_batch[k] = v2 + + # Pack outputs. + if output_packing_type is None: + outputs = sample_outputs[0] + elif output_packing_type is not dict: + outputs = output_packing_type(sample_outputs) + else: + outputs = {} + keys = [str(i) for i in range(len(sample_outputs))] + for k, v in zip(keys, sample_outputs): + outputs[k] = v + + model = tf.keras.Model(inputs=inputs, outputs=outputs) + computed_outputs, _ = gradient_clipping_utils.model_forward_pass( + model, + x_batch, + ) + true_outputs = model(x_batch) + self.assertAllClose(computed_outputs, true_outputs) + + +if __name__ == '__main__': + tf.test.main()