Fix bug in DPModel that shows up in distributed training.

PiperOrigin-RevId: 528026372
This commit is contained in:
Walid Krichene 2023-04-28 17:30:48 -07:00 committed by A. Unique TensorFlower
parent 9710a4acc7
commit e65e14b2d6
3 changed files with 202 additions and 181 deletions

View file

@ -21,13 +21,14 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
`compute_gradient_norms()` function).
"""
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
LossFn = Callable[..., tf.Tensor]
def get_registry_generator_fn(
@ -71,7 +72,7 @@ def compute_gradient_norms(
x_batch: InputTensor,
y_batch: tf.Tensor,
layer_registry: lr.LayerRegistry,
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
per_example_loss_fn: Optional[LossFn] = None,
num_microbatches: Optional[lr.BatchSize] = None,
trainable_vars: Optional[List[tf.Variable]] = None,
):
@ -92,9 +93,9 @@ def compute_gradient_norms(
compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
per_example_loss_fn: If not None, used as the function to compute the
vectorized per example loss. Otherwise, we derive it from `input_model`'s
loss function.
per_example_loss_fn: takes as input predictions, labels and weights, and
outputs a vector of per-example losses. If None, derived from
`input_model.loss` by disabling its reduction.
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
@ -193,7 +194,8 @@ def compute_clipped_gradients_and_outputs(
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
num_microbatches: Optional[lr.BatchSize] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
clipping_loss: Optional[LossFn] = None,
) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
"""Computes the per-example clipped loss gradient and other useful outputs.
Given a batch of observations `(x_batch, y_batch)`, the main steps of this
@ -224,14 +226,21 @@ def compute_clipped_gradients_and_outputs(
microbatches. If not None, indicates that the loss is grouped into
num_microbatches (in this case, the batch dimension needs to be a multiple
of num_microbatches).
clipping_loss: If provided, used for the clipping computation. Defaults to
`input_model.compiled_loss`. Specifying a `clipping_loss` can be useful to
avoid calling `input_model.compiled_loss`, as this will append the value
of the clipped loss to the reported metrics, and this can be misleading as
the value of the clipped loss does not reflect the true loss.
Returns:
A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
clipped gradient of the loss function, the second is the result of
applying `input_model` to `x_batch`, and the third is loss value of
`input_model`, weighted by the loss weights generated by a specific
`compute_clip_weights()` call.
"""
if clipping_loss is None:
clipping_loss = input_model.compiled_loss
gradient_norms = compute_gradient_norms(
input_model,
x_batch,
@ -260,19 +269,10 @@ def compute_clipped_gradients_and_outputs(
if num_microbatches is None
else lr.add_microbatch_axis(y_pred, num_microbatches)
)
# NOTE: We do not log the loss values here. The caller should invoke
# `input_model.compute_loss()` to log loss values. Specifically,
# calling `input_model.compute_loss()` performs the following steps:
#
# (i) sums `input_model.loss` with the regularization losses given in
# `input_model.losses` to obtain the total loss
# (ii) evaluates the total loss with sample weights (if given)
weighted_loss_value = input_model.loss(
loss_y_batch, loss_y_pred, loss_weights
)
clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
clipped_grads = tape.gradient(
weighted_loss_value,
clipping_loss_value,
input_model.trainable_variables,
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
return clipped_grads, y_pred, weighted_loss_value
return clipped_grads, y_pred, clipping_loss_value

View file

@ -15,10 +15,13 @@
from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
@ -122,6 +125,7 @@ def make_dp_model_class(cls):
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._layer_registry = layer_registry
self._clipping_loss = None
# Given that `num_microbatches` was added as an argument after the fact,
# this check helps detect unintended calls to the earlier API.
@ -176,15 +180,34 @@ def make_dp_model_class(cls):
)
def _compute_per_example_grads(self, data):
if self._clipping_loss is None:
self._make_clipping_loss()
microbatched_x, microbatched_y = data
with tf.GradientTape() as tape:
microbatched_y_pred = self(microbatched_x, training=True)
# NOTE: Calling `self.loss()` neither logs the total loss nor does it
# include any regularization terms.
microbatched_loss = self.loss(microbatched_y, microbatched_y_pred)
# NOTE: `self._clipping_loss` does not include any regularization terms.
microbatched_loss = self._clipping_loss(
microbatched_y, microbatched_y_pred
)
grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
clipped_grads = self._process_per_example_grads(grads_list)
return microbatched_loss, clipped_grads
return clipped_grads
def _make_clipping_loss(self):
"""Creates a LossesContainer to be used for clipping.
To compute the privatized loss, we wrap the model's compiled_loss inside a
new LossesContainer. This lets us avoid calling model.compiled_loss, which
appends the loss value to the returned metrics (we want to avoid this as
the privatized loss does not reflect the true loss and can be misleading).
"""
losses_container_cls = self.compiled_loss.__class__
self._clipping_loss = losses_container_cls(
self.compiled_loss._user_losses, # pylint:disable=protected-access
loss_weights=self.compiled_loss._user_loss_weights, # pylint:disable=protected-access
output_names=self.output_names,
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
)
def train_step(self, data):
"""DP-SGD version of base class method.
@ -205,11 +228,16 @@ def make_dp_model_class(cls):
Returns:
See the base class.
"""
if self._clipping_loss is None:
self._make_clipping_loss()
output_metrics = {}
x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
if weights is not None:
raise NotImplementedError(
'DPModel does not currently support weighted losses.'
)
batch_size = tf.shape(y)[0]
eff_num_microbatches = self._num_microbatches or batch_size
privatized_loss_name = 'privatized_loss'
# Branch based on gradient clipping algorithm.
if self._enable_fast_peg_computation:
@ -221,7 +249,7 @@ def make_dp_model_class(cls):
# trick, and uses these norms to clip the per-example gradients.
# NOTE: Reshaping of the input according to the effective number of
# microbatches is done here.
clipped_grads, y_pred, weighted_loss = (
clipped_grads, y_pred, clipping_loss = (
clip_grads.compute_clipped_gradients_and_outputs(
self,
x,
@ -229,8 +257,10 @@ def make_dp_model_class(cls):
self._l2_norm_clip,
self._layer_registry,
self._num_microbatches,
self._clipping_loss,
)
)
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
if self._noise_multiplier > 0:
grads = gradient_clipping_utils.add_aggregate_noise(
self,
@ -241,7 +271,6 @@ def make_dp_model_class(cls):
)
else:
grads = clipped_grads
output_metrics[privatized_loss_name] = weighted_loss
else:
logging.info('Computing gradients using original clipping algorithm.')
# Computes per-example clipped gradients directly. This is called
@ -249,7 +278,7 @@ def make_dp_model_class(cls):
# algorithm.
reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches)
microbatched_data = tf.nest.map_structure(reshape_fn, data)
microbatched_losses, clipped_grads = tf.vectorized_map(
clipped_grads = tf.vectorized_map(
self._compute_per_example_grads,
microbatched_data,
)
@ -257,11 +286,6 @@ def make_dp_model_class(cls):
grads = tf.nest.map_structure(
self._reduce_per_example_grads, clipped_grads
)
if self.loss.reduction == tf.keras.losses.Reduction.SUM:
microbatched_loss = tf.reduce_sum(microbatched_losses)
else:
microbatched_loss = tf.reduce_mean(microbatched_losses)
output_metrics[privatized_loss_name] = microbatched_loss
# Add the values and gradients contributed by regularization losses.
if self.losses:
@ -277,9 +301,10 @@ def make_dp_model_class(cls):
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
grads = [a + b for (a, b) in zip(grads, regularization_grads)]
output_metrics[privatized_loss_name] += summed_regularization_loss
if self._enable_fast_peg_computation:
output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss
# Log the true loss.
# Log the true loss, including regularization losses.
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Forward the private gradients to the optimizer and return the results.

View file

@ -28,16 +28,6 @@ def get_data():
return data, labels
def get_layer_registries():
# Outputs a list of testable layer registries.
# The empty registry {} tests the behavior of the standard approach,
# while the other one tests the fast gradient clipping algorithm.
return [
layer_registry.LayerRegistry(),
layer_registry.make_default_layer_registry(),
]
class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
def testBaseline(self):
@ -65,44 +55,49 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(model_weights[0], [[0.90], [1.20]])
self.assertAllClose(model_weights[1], [0.30])
@parameterized.named_parameters(
('l2_norm_clip 10.0', 10.0),
('l2_norm_clip 40.0', 40.0),
('l2_norm_clip 200.0', 200.0),
@parameterized.product(
l2_norm_clip=(10.0, 40.0, 200.0),
fast_clipping=(True, False),
)
def testClippingNorm(self, l2_norm_clip):
def testClippingNorm(self, l2_norm_clip, fast_clipping):
"""Tests that clipping norm works."""
train_data, train_labels = get_data()
for test_reg in get_layer_registries():
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=1)
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss)
expected_loss = loss(train_labels, model(train_data))
results = model.fit(train_data, train_labels, epochs=1, batch_size=1)
model_weights = model.get_weights()
model_weights = model.get_weights()
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
scale = min(1.0, l2_norm_clip / unclipped_gradient)
expected_weights = np.array([[90], [120]]) * scale * learning_rate
expected_bias = np.array([30]) * scale * learning_rate
unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2)
scale = min(1.0, l2_norm_clip / unclipped_gradient)
expected_weights = np.array([[90], [120]]) * scale * learning_rate
expected_bias = np.array([30]) * scale * learning_rate
# Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
# Check parameters are as expected, taking into account the learning rate.
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
# Check the value of the loss.
actual_loss = results.history['loss'][0]
self.assertAllClose(expected_loss, actual_loss)
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
num_microbatches):
@ -127,64 +122,61 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
final_grads = np.mean(mb_grads, axis=0)
return final_grads
@parameterized.named_parameters(
('mb_test 0', 1.0, None),
('mb_test 1', 1.0, 1),
('mb_test 2', 1.0, 2),
('mb_test 4', 1.0, 4),
@parameterized.product(
num_microbatches=(None, 1, 2, 4),
fast_clipping=(False, True),
)
def testMicrobatches(self, l2_norm_clip, num_microbatches):
def testMicrobatches(self, num_microbatches, fast_clipping):
l2_norm_clip = 1.0
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
w = np.zeros((2))
train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
learning_rate = 1.0
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
# Simple linear model returns w * x.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model_weights = np.squeeze(model.get_weights())
model_weights = np.squeeze(model.get_weights())
effective_num_microbatches = (
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
effective_num_microbatches = (
train_data.shape[0]
if model._num_microbatches is None
else num_microbatches
)
expected_grads = self._compute_expected_gradients(
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
)
expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights)
expected_grads = self._compute_expected_gradients(
train_data, train_labels, w, l2_norm_clip, effective_num_microbatches
)
expected_weights = np.squeeze(-learning_rate * expected_grads)
self.assertAllClose(model_weights, expected_weights)
@parameterized.named_parameters(
('noise_multiplier 3 2 None', 3.0, 2.0, None),
('noise_multiplier 5 4 None', 5.0, 4.0, None),
('noise_multiplier 3 2 1', 3.0, 2.0, 1),
('noise_multiplier 5 4 1', 5.0, 4.0, 1),
('noise_multiplier 3 2 2', 3.0, 2.0, 2),
('noise_multiplier 5 4 2', 5.0, 4.0, 2),
('noise_multiplier 3 2 4', 3.0, 2.0, 4),
('noise_multiplier 5 4 4', 5.0, 4.0, 4),
@parameterized.product(
l2_norm_clip=(3.0, 5.0),
noise_multiplier=(2.0, 4.0),
num_microbatches=(None, 1, 2, 4),
fast_clipping=(False, True),
)
def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
num_microbatches):
def testNoiseMultiplier(
self, l2_norm_clip, noise_multiplier, num_microbatches, fast_clipping
):
# The idea behind this test is to start with a model whose parameters
# are set to zero. We then run one step of a model that produces
# an un-noised gradient of zero, and then compute the standard deviation
@ -197,69 +189,69 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
learning_rate = 1.0
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4)
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros'
),
],
)
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4)
effective_num_microbatches = num_microbatches or train_data.shape[0]
effective_num_microbatches = num_microbatches or train_data.shape[0]
model_weights = model.get_weights()
measured_std = np.std(model_weights[0])
expected_std = (
l2_norm_clip * noise_multiplier / effective_num_microbatches
)
model_weights = model.get_weights()
measured_std = np.std(model_weights[0])
expected_std = l2_norm_clip * noise_multiplier / effective_num_microbatches
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(measured_std, expected_std, 0.1 * expected_std)
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(measured_std, expected_std, 0.1 * expected_std)
# Simple check to make sure dimensions are correct when output has
# dimension > 1.
@parameterized.named_parameters(
('mb_test None 2', None, 2),
('mb_test 1 2', 1, 2),
('mb_test 2 2', 2, 2),
('mb_test 4 4', 4, 4),
@parameterized.product(
num_microbatches=(None, 1, 2),
output_dimension=(2, 4),
fast_clipping=(False, True),
)
def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
def testMultiDimensionalOutput(
self, num_microbatches, output_dimension, fast_clipping
):
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
train_labels = np.array([[0], [1], [1], [0]])
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
for test_reg in get_layer_registries():
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=test_reg,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
output_dimension, use_bias=False, kernel_initializer='zeros'
),
tf.keras.layers.Dense(1),
],
)
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
output_dimension, use_bias=False, kernel_initializer='zeros'
),
tf.keras.layers.Dense(1),
],
)
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
# Checks that calls to earlier API using `use_xla` as a positional argument
# raise an exception.
@ -285,10 +277,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
# Simple test to check that regularizer gradients are contributing to the
# final gradient.
@parameterized.named_parameters(
('no_registry', None),
('default_registry', layer_registry.make_default_layer_registry()),
('fast_clipping', True),
('no_fast_clipping', False),
)
def testRegularizationGradient(self, registry):
def testRegularizationGradient(self, fast_clipping):
input_dim = 10
batch_size = 2
regularizer_multiplier = 0.025
@ -306,7 +298,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
outputs=outputs,
l2_norm_clip=1e9,
noise_multiplier=0.0,
layer_registry=registry,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
)
model.compile(
loss=tf.keras.losses.MeanSquaredError(),
@ -331,10 +325,10 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
# Simple test to check that custom input regularization does NOT contribute
# to the gradient.
@parameterized.named_parameters(
('no_registry', None),
('default_registry', layer_registry.make_default_layer_registry()),
('fast_clipping', True),
('no_fast_clipping', False),
)
def testCustomRegularizationZeroGradient(self, registry):
def testCustomRegularizationZeroGradient(self, fast_clipping):
input_dim = 10
batch_size = 2
inputs = tf.keras.layers.Input((input_dim,))
@ -350,7 +344,9 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
outputs=outputs,
l2_norm_clip=1e9,
noise_multiplier=0.0,
layer_registry=registry,
layer_registry=layer_registry.make_default_layer_registry()
if fast_clipping
else None,
)
model.add_loss(tf.reduce_sum(inputs))
model.compile(