forked from 626_privacy/tensorflow_privacy
Generalize model_forward_pass()
to allow input models with multiple outputs.
PiperOrigin-RevId: 517145254
This commit is contained in:
parent
043e8b5272
commit
7ae50c5ca5
3 changed files with 103 additions and 6 deletions
|
@ -9,6 +9,14 @@ py_library(
|
|||
deps = [":layer_registry"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gradient_clipping_utils_test",
|
||||
srcs = ["gradient_clipping_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":gradient_clipping_utils"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_registry",
|
||||
srcs = ["layer_registry.py"],
|
||||
|
|
|
@ -20,7 +20,7 @@ import tensorflow as tf
|
|||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
|
||||
|
||||
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
|
||||
|
||||
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
|
||||
|
||||
|
@ -51,9 +51,9 @@ def _get_internal_layers(
|
|||
|
||||
def model_forward_pass(
|
||||
input_model: tf.keras.Model,
|
||||
inputs: InputTensor,
|
||||
inputs: PackedTensors,
|
||||
generator_fn: GeneratorFunction = None,
|
||||
) -> Tuple[tf.Tensor, List[Any]]:
|
||||
) -> Tuple[PackedTensors, List[Any]]:
|
||||
"""Does a forward pass of a model and returns useful intermediates.
|
||||
|
||||
NOTE: the graph traversal algorithm is an adaptation of the logic in the
|
||||
|
@ -72,7 +72,7 @@ def model_forward_pass(
|
|||
|
||||
Returns:
|
||||
A `tuple` `(outputs, generator_outputs_list)`. `outputs` is the
|
||||
`tf.Tensor` that is generated as a result of a forward pass.
|
||||
`PackedTensor` that is generated as a result of a forward pass.
|
||||
`generator_outputs_list` is a `list` whose i-th entry is the output of
|
||||
`generator_fn(lyr, args, kwargs)[1]` where `lyr` is the i-th
|
||||
layer when the compute graph of `input_model` is traversed in BFS order.
|
||||
|
@ -133,7 +133,17 @@ def model_forward_pass(
|
|||
):
|
||||
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||||
|
||||
return node_layer_outputs, generator_outputs_list
|
||||
# Gather outputs (in case there are multiple) and return.
|
||||
output_tensors = []
|
||||
for x in input_model.outputs:
|
||||
x_id = str(id(x))
|
||||
output_tensors.append(tensor_dict[x_id].pop())
|
||||
model_outputs = tf.nest.pack_sequence_as(
|
||||
input_model._nested_outputs, # pylint: disable=protected-access
|
||||
output_tensors,
|
||||
)
|
||||
|
||||
return model_outputs, generator_outputs_list
|
||||
|
||||
|
||||
def all_trainable_layers_are_registered(
|
||||
|
@ -203,7 +213,7 @@ def add_aggregate_noise(
|
|||
|
||||
def generate_model_outputs_using_core_keras_layers(
|
||||
input_model: tf.keras.Model,
|
||||
) -> tf.Tensor:
|
||||
) -> PackedTensors:
|
||||
"""Returns the model outputs generated by only core Keras layers."""
|
||||
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
|
||||
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
|
||||
|
||||
|
||||
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
input_packing_type=[None, tuple, list, dict],
|
||||
output_packing_type=[None, tuple, list, dict],
|
||||
)
|
||||
def test_outputs_are_consistent(
|
||||
self, input_packing_type, output_packing_type
|
||||
):
|
||||
num_dims = 3
|
||||
num_inputs = 1 if input_packing_type is None else 2
|
||||
num_outputs = 1 if output_packing_type is None else 2
|
||||
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
|
||||
temp_sum = tf.stack(sample_inputs, axis=0)
|
||||
sample_outputs = [
|
||||
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
|
||||
]
|
||||
sample_x_batch = [
|
||||
tf.multiply(tf.range(num_dims, dtype=tf.float32), float(i + 1.0))
|
||||
for i in range(num_inputs)
|
||||
]
|
||||
|
||||
# Pack inputs.
|
||||
if input_packing_type is None:
|
||||
inputs = sample_inputs[0]
|
||||
x_batch = sample_x_batch[0]
|
||||
elif input_packing_type is not dict:
|
||||
inputs = input_packing_type(sample_inputs)
|
||||
x_batch = input_packing_type(sample_x_batch)
|
||||
else:
|
||||
inputs = {}
|
||||
x_batch = {}
|
||||
keys = [str(i) for i in range(len(sample_inputs))]
|
||||
for k, v1, v2 in zip(keys, sample_inputs, sample_x_batch):
|
||||
inputs[k] = v1
|
||||
x_batch[k] = v2
|
||||
|
||||
# Pack outputs.
|
||||
if output_packing_type is None:
|
||||
outputs = sample_outputs[0]
|
||||
elif output_packing_type is not dict:
|
||||
outputs = output_packing_type(sample_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
keys = [str(i) for i in range(len(sample_outputs))]
|
||||
for k, v in zip(keys, sample_outputs):
|
||||
outputs[k] = v
|
||||
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
computed_outputs, _ = gradient_clipping_utils.model_forward_pass(
|
||||
model,
|
||||
x_batch,
|
||||
)
|
||||
true_outputs = model(x_batch)
|
||||
self.assertAllClose(computed_outputs, true_outputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue