Add support for multi-headed models that use fast gradient clipping.

PiperOrigin-RevId: 627942683
This commit is contained in:
William Kong 2024-04-24 20:45:59 -07:00 committed by A. Unique TensorFlower
parent fefad2190e
commit 3fa0a2d362
6 changed files with 208 additions and 32 deletions

View file

@ -69,11 +69,80 @@ def get_registry_generator_fn(
return registry_generator_fn return registry_generator_fn
def _infer_per_example_loss_fn(model: tf.keras.Model):
  """Infers a per-example (unreduced) loss function from a model's losses.

  Two kinds of `model.loss` are handled:

  * A single `tf.keras.losses.Loss`: the loss is rebuilt from its config with
    the reduction set to `NONE`.
  * A `dict` of losses (multi-headed model): every loss is rebuilt with the
    reduction set to `NONE`, and the returned callable combines them per
    example as `sum_k(weights[k] * loss_k(...) / num_losses)`, where `weights`
    are the `loss_weights` found in the model's compile config.

  Args:
    model: The model whose `loss` attribute is inspected.

  Returns:
    A callable taking `(y_true, y_pred, sample_weight)`. In the `dict` case the
    arguments must be dicts that contain at least the keys of `model.loss`
    (`sample_weight` may also be `None`), and the result is a rank-1 tensor.

  Raises:
    ValueError: If `model.loss` is neither a `Loss` nor a `dict`, if the
      compile config is missing, or if no loss weights were configured for a
      model with multiple losses.
  """

  def _convert(loss_fn):
    # Rebuild the same loss from its config, but without any reduction.
    loss_config = loss_fn.get_config()
    loss_config['reduction'] = tf.keras.losses.Reduction.NONE
    return loss_fn.from_config(loss_config)

  model_loss = model.loss
  if isinstance(model_loss, tf.keras.losses.Loss):
    return _convert(model_loss)
  elif isinstance(model_loss, dict):
    # Note that we cannot call the public method `.get_compile_config()` because
    # it calls a numpy function, which is not supported inside a `tf.function`
    # wrapped function.
    compile_config = model._compile_config.config  # pylint: disable=protected-access
    if compile_config is None:
      raise ValueError('Model must be compiled for loss function conversion')
    # Does a weighted mean of the configured losses. Note that we cannot build
    # from the config of the compiled loss because (i) it builds a
    # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
    # during its construction, (ii) non-unique `tf.Variables` cannot be used
    # inside a `tf.function`, which is usually where this function is used.
    #
    # The entry is also rejected when it is present but `None`; otherwise the
    # `len(weights)` call below would fail with an unrelated `TypeError`.
    # NOTE(review): presumably Keras stores `loss_weights=None` when the model
    # is compiled without loss weights -- confirm against the Keras version in
    # use.
    if (
        'loss_weights' not in compile_config
        or compile_config['loss_weights'] is None
    ):
      raise ValueError(
          'Models with multiple loss must have corresponding loss weights for'
          ' loss function conversion'
      )
    weights = compile_config['loss_weights']
    per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
    num_losses = len(weights)

    def _check_has_loss_keys(name, tensors):
      # Every key of the model losses must be present in `tensors`.
      if model_loss.keys() - tensors.keys():
        raise ValueError(
            '%s must contain the same keys and the model losses, but got %s'
            ' and %s' % (name, tensors.keys(), model_loss.keys())
        )

    def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
      _check_has_loss_keys('y_pred', y_pred)
      _check_has_loss_keys('y_true', y_true)
      if sample_weight is not None:
        # Report the keys of `sample_weight` itself (not those of `y_true`).
        _check_has_loss_keys('sample_weight', sample_weight)
      loss_values = []
      # Iterate over the loss keys, which are exactly the keys validated above;
      # additional entries in `y_true` have no associated loss and are skipped.
      for k, loss_fn in per_example_losses.items():
        sgl_sample_weight = None if sample_weight is None else sample_weight[k]
        sgl_value = (
            weights[k]
            * loss_fn(y_true[k], y_pred[k], sgl_sample_weight)
            / num_losses
        )
        # Flatten so that the per-head losses can be summed elementwise.
        loss_values.append(tf.reshape(sgl_value, shape=[-1]))
      return tf.math.add_n(loss_values)

    return _per_example_loss_fn
  else:
    raise ValueError(
        'Unsupported type for loss function conversion: {}'.format(
            type(model_loss)
        )
    )
def compute_gradient_norms( def compute_gradient_norms(
input_model: tf.keras.Model, input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors, x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor, y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None, weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[type_aliases.LossFn] = None, per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None, num_microbatches: Optional[type_aliases.BatchSize] = None,
@ -94,9 +163,9 @@ def compute_gradient_norms(
more details. more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension. first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis y_batch: An `OutputTensor` representing a batch of output labels. The first
must be the batch dimension. The number of examples should match the axes of the tensors must be the batch dimension. The number of examples
number of examples in `x_batch`. should match the number of examples in `x_batch`.
weight_batch: Optional batch of weights, passed to the loss function. weight_batch: Optional batch of weights, passed to the loss function.
Weights apply to the loss prior to clipping. Weights apply to the loss prior to clipping.
per_example_loss_fn: takes as input predictions, labels and weights, and per_example_loss_fn: takes as input predictions, labels and weights, and
@ -131,11 +200,11 @@ def compute_gradient_norms(
input_model, x_batch, generator_fn=registry_generator_fn input_model, x_batch, generator_fn=registry_generator_fn
) )
) )
# Ignore the original loss function's reduction to get per-example loss. # Ignore the original loss function's reduction to get per-example loss.
if per_example_loss_fn is None: if per_example_loss_fn is None:
loss_config = input_model.loss.get_config() per_example_loss_fn = _infer_per_example_loss_fn(input_model)
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None: if losses.shape is None:
raise NotImplementedError( raise NotImplementedError(
@ -233,7 +302,7 @@ def compute_clipped_gradients_and_outputs(
l2_norm_clip: float, l2_norm_clip: float,
layer_registry: lr.LayerRegistry, layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors, x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor, y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None, weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None, num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None, clipping_loss: Optional[type_aliases.LossFn] = None,
@ -260,10 +329,10 @@ def compute_clipped_gradients_and_outputs(
squared norms of a layer's pre-activation tensor, and `vars` are relevant squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples). trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension. first axes of each tensor must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis y_batch: An `OutputTensor` representing a batch of output labels. The first
must be the batch dimension. The number of examples should match the axes of each tensor must be the batch dimension. The number of examples
number of examples in `x_batch`. should match the number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before passing it to to [num_microbatches, batch_size/num_microbatches] before passing it to
@ -285,6 +354,7 @@ def compute_clipped_gradients_and_outputs(
clipping_loss_value: the loss value weighted in such a way that its gradient clipping_loss_value: the loss value weighted in such a way that its gradient
is `clipped_grad`. is `clipped_grad`.
""" """
if hasattr(input_model.loss, 'reduction'):
if input_model.loss.reduction == 'none': if input_model.loss.reduction == 'none':
raise NotImplementedError( raise NotImplementedError(
'Fast gradient clipping does not support ' 'Fast gradient clipping does not support '
@ -311,11 +381,13 @@ def compute_clipped_gradients_and_outputs(
weight_batch, num_microbatches weight_batch, num_microbatches
) )
if num_microbatches is None: if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [num_microbatches] c = clip_weights # shape [num_microbatches]
else: else:
# In this case, weight_batch is of shape [batch_size, microbatch_size], # In this case, weight_batch is of shape [batch_size, microbatch_size],
# we multiply by the clip_weights (which is of shape [num_microbatches]) # we multiply by the clip_weights (which is of shape [num_microbatches])
clip_weights = clip_weights[:, tf.newaxis] * weight_batch c = clip_weights[:, tf.newaxis]
clip_weights = tf.nest.map_structure(lambda w: c * w, weight_batch)
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that # WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches # `compute_loss` always computes the mean over the microbatches

View file

@ -20,13 +20,13 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def maybe_add_microbatch_axis( def maybe_add_microbatch_axis(
x: tf.Tensor, x: type_aliases.PackedTensors,
num_microbatches: Optional[type_aliases.BatchSize], num_microbatches: Optional[type_aliases.BatchSize],
) -> tf.Tensor: ) -> type_aliases.PackedTensors:
"""Adds the microbatch axis. """Adds the microbatch axis to a collection of tensors.
Args: Args:
x: the input tensor. x: Model output or input tensors.
num_microbatches: If None, x is returned unchanged. Otherwise, must divide num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size. the batch size.
@ -36,9 +36,13 @@ def maybe_add_microbatch_axis(
""" """
if num_microbatches is None: if num_microbatches is None:
return x return x
def _expand(t):
with tf.control_dependencies( with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] [tf.assert_equal(tf.math.floormod(tf.shape(t)[0], num_microbatches), 0)]
): ):
return tf.reshape( return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0) t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0)
) )
return tf.nest.map_structure(_expand, x)

View file

@ -144,6 +144,30 @@ def all_trainable_layers_are_registered(
return True return True
def _infer_loss_reduction_type(model: tf.keras.Model):
  """Determines which loss reduction the losses of `model` are configured with.

  Args:
    model: The model whose `loss` (and, for a dict of losses, `compiled_loss`)
      is inspected.

  Returns:
    The `reduction` attribute of a single `tf.keras.losses.Loss`, or, for a
    dict of losses, the one reduction value shared by all entries of the
    compiled loss config.

  Raises:
    ValueError: If the loss is neither a `Loss` nor a `dict`, if the model has
      no compiled loss, or if the configured losses use different reductions.
  """
  model_loss = model.loss
  if isinstance(model_loss, tf.keras.losses.Loss):
    return model_loss.reduction
  if not isinstance(model_loss, dict):
    raise ValueError(
        'Unsupported type for adding noise: {}'.format(type(model_loss))
    )

  compiled_loss = model.compiled_loss
  if compiled_loss is None:
    raise ValueError('Model must be compiled for adding noise')

  # Collect the distinct reduction values across all configured losses.
  distinct_reductions = {
      entry['config']['reduction']
      for entry in compiled_loss.get_config()['losses']
  }
  if len(distinct_reductions) > 1:
    raise ValueError(
        'Reductions in models with multiple losses must all be the same'
    )
  return distinct_reductions.pop()
def add_aggregate_noise( def add_aggregate_noise(
clipped_grads: list[tf.Tensor], clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor, batch_size: tf.Tensor,
@ -194,7 +218,7 @@ def add_aggregate_noise(
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO, tf.keras.losses.Reduction.AUTO,
] ]
model_reduction = loss_model.loss.reduction model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = ( loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum' 'mean' if model_reduction in implicit_mean_reductions else 'sum'
) )

View file

@ -82,7 +82,6 @@ def layer_normalization_computation(
stacked_grads = common_manip_utils.maybe_add_microbatch_axis( stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
grads, num_microbatches grads, num_microbatches
) )
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
reduction_axes = tf.range(1, tf.rank(stacked_grads)) reduction_axes = tf.range(1, tf.rank(stacked_grads))
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes) return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)

View file

@ -213,6 +213,22 @@ def make_dp_model_class(cls):
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME), total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
) )
def _infer_batch_size(self, t):
"""Infers the batch size from a tensor or a container of tensors."""
if t is None:
return None
elif isinstance(t, tf.Tensor):
t0 = t
elif isinstance(t, list):
t0 = t[0]
elif isinstance(t, dict):
t0 = list(t.values())[0]
else:
raise ValueError(
'Unsupported type for batch size inference: {}'.format(type(t))
)
return tf.shape(t0)[0]
def train_step(self, data): def train_step(self, data):
"""DP-SGD version of base class method. """DP-SGD version of base class method.
@ -236,7 +252,7 @@ def make_dp_model_class(cls):
self._make_clipping_loss() self._make_clipping_loss()
output_metrics = {} output_metrics = {}
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data) x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
batch_size = tf.shape(y)[0] batch_size = self._infer_batch_size(y)
num_microbatches = self._num_microbatches or batch_size num_microbatches = self._num_microbatches or batch_size
# Branch based on gradient clipping algorithm. # Branch based on gradient clipping algorithm.

View file

@ -425,6 +425,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model.trainable_variables, tf.ones_like(model.trainable_variables) model.trainable_variables, tf.ones_like(model.trainable_variables)
) )
# Runs one epoch on a two-headed model with weighted losses and checks that
# the DP model ends up with the same variables as the plain Keras model.
@parameterized.product(
    fast_clipping=[True, False],
    num_microbatches=[None, 2],
)
def testWeightedLoss(self, fast_clipping, num_microbatches):
  """Compares a DP model against a non-DP model for dict-valued losses.

  Both models are compiled with a dict of losses and a dict of loss weights.
  The DP model uses `l2_norm_clip=1e9` and `noise_multiplier=0.0`, and the
  test asserts that its trainable variables are close to those of the plain
  Keras model after fitting both on the same batch.
  """
  k1, k2 = 'foo', 'bar'
  head_keys = (k1, k2)
  input_dim = 10
  batch_size = 4

  # A layer registry is only supplied when fast gradient clipping is tested.
  registry = None
  if fast_clipping:
    registry = layer_registry.make_default_layer_registry()

  def _make_base_model(use_dp: bool):
    # One input and one single-unit dense head per key.
    inputs = {k: tf.keras.layers.Input((input_dim,)) for k in head_keys}
    heads = {
        k: tf.keras.layers.Dense(1, kernel_initializer='ones')
        for k in head_keys
    }
    outputs = {k: heads[k](inputs[k]) for k in head_keys}
    if not use_dp:
      built_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    else:
      built_model = dp_keras_model.DPModel(
          inputs=inputs,
          outputs=outputs,
          l2_norm_clip=1e9,
          noise_multiplier=0.0,
          num_microbatches=num_microbatches,
          layer_registry=registry,
      )
    built_model.compile(
        loss={
            k1: tf.keras.losses.MeanSquaredError(),
            k2: tf.keras.losses.MeanAbsoluteError(),
        },
        loss_weights={k1: 0.4, k2: 0.6},
        optimizer=tf.keras.optimizers.SGD(1.0),
    )
    return built_model

  dp_model = _make_base_model(use_dp=True)
  model = _make_base_model(use_dp=False)

  # The second head receives the first head's features scaled by two.
  features = tf.reshape(
      tf.range(input_dim * batch_size, dtype=tf.float32),
      [batch_size, input_dim],
  )
  x_batch = {k1: features, k2: 2.0 * features}
  y_batch = {k1: tf.zeros([batch_size, 1]), k2: tf.ones([batch_size, 1])}

  # Checks that gradients align.
  fit_kwargs = dict(
      x=x_batch, y=y_batch, epochs=1, batch_size=batch_size, shuffle=False
  )
  model.fit(**fit_kwargs)
  dp_model.fit(**fit_kwargs)
  self.assertAllClose(dp_model.trainable_variables, model.trainable_variables)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()