Add support for multi-headed models that use fast gradient clipping.

PiperOrigin-RevId: 627942683
William Kong 2024-04-24 20:45:59 -07:00 committed by A. Unique TensorFlower
parent fefad2190e
commit 3fa0a2d362
6 changed files with 208 additions and 32 deletions


@@ -69,11 +69,80 @@ def get_registry_generator_fn(
return registry_generator_fn
def _infer_per_example_loss_fn(model: tf.keras.Model):
"""Infer the per-example loss from model config."""
def _convert(loss_fn):
loss_config = loss_fn.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
return loss_fn.from_config(loss_config)
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return _convert(model_loss)
elif isinstance(model_loss, dict):
# Note that we cannot call the public method `.get_compile_config()` because
# it calls a numpy function, which is not supported inside a
# `tf.function`-wrapped function.
compile_config = model._compile_config.config # pylint: disable=protected-access
if compile_config is None:
raise ValueError('Model must be compiled for loss function conversion')
# Computes a weighted mean of the configured losses. Note that we cannot build
# from the config of the compiled loss because (i) it builds a
# `keras.metrics.Mean` instance, which generates non-unique `tf.Variable`s
# during its construction, and (ii) non-unique `tf.Variable`s cannot be used
# inside a `tf.function`, which is usually where this function is used.
if 'loss_weights' not in compile_config:
raise ValueError(
'Models with multiple losses must have corresponding loss weights for'
' loss function conversion'
)
weights = compile_config['loss_weights']
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
num_losses = len(weights)
def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
loss_values = []
if model_loss.keys() - y_pred.keys():
raise ValueError(
'y_pred must contain the same keys as the model losses, but '
'got %s and %s' % (y_pred.keys(), model_loss.keys())
)
if model_loss.keys() - y_true.keys():
raise ValueError(
'y_true must contain the same keys as the model losses, but '
'got %s and %s' % (y_true.keys(), model_loss.keys())
)
if sample_weight is not None:
if model_loss.keys() - sample_weight.keys():
raise ValueError(
'sample_weight must contain the same keys as the model losses,'
' but got %s and %s' % (sample_weight.keys(), model_loss.keys())
)
for k in y_true.keys():
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
sgl_value = (
weights[k]
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
/ num_losses
)
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
return tf.math.add_n(loss_values)
return _per_example_loss_fn
else:
raise ValueError(
'Unsupported type for loss function conversion: {}'.format(
type(model_loss)
)
)
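# Illustrative sketch (editor's example, not part of this diff): the conversion
# idea behind `_infer_per_example_loss_fn`. Cloning a Keras loss from its
# config with `Reduction.NONE` yields one loss value per example; a dict of
# such losses is then combined as a weighted mean. Values below are made up.
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()
loss_config = mse.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_mse = mse.from_config(loss_config)

y_true = tf.constant([[0.0], [0.0]])
y_pred = tf.constant([[1.0], [3.0]])
print(per_example_mse(y_true, y_pred))  # per-example values [1.0, 9.0], no reduction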
def compute_gradient_norms(
input_model: tf.keras.Model,
layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor,
y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None,
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
@@ -94,9 +163,9 @@ def compute_gradient_norms(
more details.
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
y_batch: An `OutputTensor` representing a batch of output labels. The first
axis of each tensor must be the batch dimension. The number of examples
should match the number of examples in `x_batch`.
weight_batch: Optional batch of weights, passed to the loss function.
Weights apply to the loss prior to clipping.
per_example_loss_fn: takes as input predictions, labels and weights, and
@@ -131,11 +200,11 @@ def compute_gradient_norms(
input_model, x_batch, generator_fn=registry_generator_fn
)
)
# Ignore the original loss function's reduction to get per-example loss.
if per_example_loss_fn is None:
loss_config = input_model.loss.get_config()
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
per_example_loss_fn = input_model.loss.from_config(loss_config)
per_example_loss_fn = _infer_per_example_loss_fn(input_model)
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
if losses.shape is None:
raise NotImplementedError(
@@ -233,7 +302,7 @@ def compute_clipped_gradients_and_outputs(
l2_norm_clip: float,
layer_registry: lr.LayerRegistry,
x_batch: type_aliases.InputTensors,
y_batch: tf.Tensor,
y_batch: type_aliases.OutputTensors,
weight_batch: Optional[tf.Tensor] = None,
num_microbatches: Optional[type_aliases.BatchSize] = None,
clipping_loss: Optional[type_aliases.LossFn] = None,
@@ -260,10 +329,10 @@ def compute_clipped_gradients_and_outputs(
squared norms of a layer's pre-activation tensor, and `vars` are relevant
trainable weights (see `layer_registry_factories.py` for examples).
x_batch: An `InputTensor` representing a batch of inputs to the model. The
first axis must be the batch dimension.
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
must be the batch dimension. The number of examples should match the
number of examples in `x_batch`.
first axis of each tensor must be the batch dimension.
y_batch: An `OutputTensor` representing a batch of output labels. The first
axis of each tensor must be the batch dimension. The number of examples
should match the number of examples in `x_batch`.
weight_batch: Optional vector of weights, passed to the loss function. Must
be of size [batch_size]. In case of microbatching, this will be reshaped
to [num_microbatches, batch_size/num_microbatches] before passing it to
@@ -285,6 +354,7 @@ def compute_clipped_gradients_and_outputs(
clipping_loss_value: the loss value weighted in such a way that its gradient
is `clipped_grad`.
"""
if hasattr(input_model.loss, 'reduction'):
if input_model.loss.reduction == 'none':
raise NotImplementedError(
'Fast gradient clipping does not support '
@@ -311,11 +381,13 @@ def compute_clipped_gradients_and_outputs(
weight_batch, num_microbatches
)
if num_microbatches is None:
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
c = clip_weights # shape [num_microbatches]
else:
# In this case, the reshaped weight_batch has shape
# [num_microbatches, microbatch_size]; we multiply it by clip_weights
# (which has shape [num_microbatches]) after adding a broadcasting axis.
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
c = clip_weights[:, tf.newaxis]
clip_weights = tf.nest.map_structure(lambda w: c * w, weight_batch)
with tf.GradientTape() as tape:
# WARNING: When num_microbatches is not None, we need to be sure that
# `compute_loss` always computes the mean over the microbatches
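# Illustrative sketch (shapes assumed, not part of this diff) of the weight
# broadcasting above: per-microbatch clip weights of shape [num_microbatches]
# gain a trailing axis so they scale a reshaped weight batch of shape
# [num_microbatches, microbatch_size].
import tensorflow as tf

clip_weights = tf.constant([0.5, 2.0])  # shape [num_microbatches] = [2]
weight_batch = tf.ones([2, 3])          # shape [num_microbatches, microbatch_size]
scaled = tf.nest.map_structure(lambda w: clip_weights[:, tf.newaxis] * w, weight_batch)
print(scaled.numpy())  # first row scaled by 0.5, second row by 2.0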


@@ -20,13 +20,13 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
def maybe_add_microbatch_axis(
x: tf.Tensor,
x: type_aliases.PackedTensors,
num_microbatches: Optional[type_aliases.BatchSize],
) -> tf.Tensor:
"""Adds the microbatch axis.
) -> type_aliases.PackedTensors:
"""Adds the microbatch axis to a collection of tensors.
Args:
x: the input tensor.
x: Model output or input tensors.
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
the batch size.
@@ -36,9 +36,13 @@ def maybe_add_microbatch_axis(
"""
if num_microbatches is None:
return x
def _expand(t):
with tf.control_dependencies(
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
[tf.assert_equal(tf.math.floormod(tf.shape(t)[0], num_microbatches), 0)]
):
return tf.reshape(
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0)
)
return tf.nest.map_structure(_expand, x)
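# Illustrative sketch (tensor names made up, not part of this diff) of the
# reshaping now done by `maybe_add_microbatch_axis` on a structure of tensors:
# each leaf keeps its trailing dimensions and gains a leading microbatch axis.
import tensorflow as tf

num_microbatches = 2
batch = {'head_a': tf.ones([4, 3]), 'head_b': tf.ones([4, 5])}
reshaped = tf.nest.map_structure(
    lambda t: tf.reshape(
        t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0)
    ),
    batch,
)
print(reshaped['head_a'].shape, reshaped['head_b'].shape)  # (2, 2, 3) (2, 2, 5)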


@@ -144,6 +144,30 @@ def all_trainable_layers_are_registered(
return True
def _infer_loss_reduction_type(model: tf.keras.Model):
"""Infers what type of loss reduction is being performed."""
model_loss = model.loss
if isinstance(model_loss, tf.keras.losses.Loss):
return model_loss.reduction
elif isinstance(model.loss, dict):
reductions = set()
compiled_loss = model.compiled_loss
if compiled_loss is None:
raise ValueError('Model must be compiled for adding noise')
new_config_list = compiled_loss.get_config()['losses']
for loss_config in new_config_list:
reductions.add(loss_config['config']['reduction'])
if len(reductions) > 1:
raise ValueError(
'Reductions in models with multiple losses must all be the same'
)
return reductions.pop()
else:
raise ValueError(
'Unsupported type for adding noise: {}'.format(type(model_loss))
)
def add_aggregate_noise(
clipped_grads: list[tf.Tensor],
batch_size: tf.Tensor,
@@ -194,7 +218,7 @@ def add_aggregate_noise(
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
tf.keras.losses.Reduction.AUTO,
]
model_reduction = loss_model.loss.reduction
model_reduction = _infer_loss_reduction_type(loss_model)
loss_reduction = (
'mean' if model_reduction in implicit_mean_reductions else 'sum'
)
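# Illustrative sketch (losses chosen arbitrarily, not part of this diff) of
# what `_infer_loss_reduction_type` enforces and how its result is used: all
# losses of a multi-loss model must share one reduction, and reductions that
# implicitly average over the batch map to a 'mean' noise normalization.
import tensorflow as tf

losses = {
    'head_a': tf.keras.losses.MeanSquaredError(),
    'head_b': tf.keras.losses.MeanAbsoluteError(),
}
reductions = {loss.reduction for loss in losses.values()}
if len(reductions) > 1:
  raise ValueError('Reductions in models with multiple losses must all be the same')
model_reduction = reductions.pop()

implicit_mean_reductions = [
    tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
    tf.keras.losses.Reduction.AUTO,
]
loss_reduction = 'mean' if model_reduction in implicit_mean_reductions else 'sum'
print(model_reduction, '->', loss_reduction)  # default Keras losses give 'mean'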


@@ -82,7 +82,6 @@ def layer_normalization_computation(
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
grads, num_microbatches
)
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
reduction_axes = tf.range(1, tf.rank(stacked_grads))
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
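# Illustrative sketch (gradient values invented, not part of this diff) of the
# per-example squared-norm computation above: square the gradients and sum over
# every axis except the leading batch (or microbatch) axis.
import tensorflow as tf

grads = tf.constant([[[1.0, 2.0], [3.0, 4.0]],
                     [[1.0, 1.0], [1.0, 1.0]]])  # shape [batch=2, 2, 2]
reduction_axes = tf.range(1, tf.rank(grads))     # all axes except axis 0
print(tf.reduce_sum(tf.square(grads), axis=reduction_axes).numpy())  # [30.  4.]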


@@ -213,6 +213,22 @@ def make_dp_model_class(cls):
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
)
def _infer_batch_size(self, t):
"""Infers the batch size from a tensor or a container of tensors."""
if t is None:
return None
elif isinstance(t, tf.Tensor):
t0 = t
elif isinstance(t, list):
t0 = t[0]
elif isinstance(t, dict):
t0 = list(t.values())[0]
else:
raise ValueError(
'Unsupported type for batch size inference: {}'.format(type(t))
)
return tf.shape(t0)[0]
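# Illustrative sketch (label values made up, not part of this diff) of
# `_infer_batch_size` for the multi-headed case: with a dict of labels, any
# entry carries the batch dimension on its first axis.
import tensorflow as tf

y = {'head_a': tf.zeros([4, 1]), 'head_b': tf.ones([4, 1])}
t0 = list(y.values())[0]
print(tf.shape(t0)[0])  # tf.Tensor(4, shape=(), dtype=int32)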
def train_step(self, data):
"""DP-SGD version of base class method.
@@ -236,7 +252,7 @@ def make_dp_model_class(cls):
self._make_clipping_loss()
output_metrics = {}
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
batch_size = tf.shape(y)[0]
batch_size = self._infer_batch_size(y)
num_microbatches = self._num_microbatches or batch_size
# Branch based on gradient clipping algorithm.


@@ -425,6 +425,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model.trainable_variables, tf.ones_like(model.trainable_variables)
)
# Checks single optimizer update and consistency with non-DP model.
@parameterized.product(
fast_clipping=[True, False],
num_microbatches=[None, 2],
)
def testWeightedLoss(self, fast_clipping, num_microbatches):
k1 = 'foo'
k2 = 'bar'
input_dim = 10
default_layer_registry = (
layer_registry.make_default_layer_registry() if fast_clipping else None
)
def _make_base_model(use_dp: bool):
inputs = {
k1: tf.keras.layers.Input((input_dim,)),
k2: tf.keras.layers.Input((input_dim,)),
}
dense1 = tf.keras.layers.Dense(1, kernel_initializer='ones')
dense2 = tf.keras.layers.Dense(1, kernel_initializer='ones')
outputs = {k1: dense1(inputs[k1]), k2: dense2(inputs[k2])}
if use_dp:
base_model = dp_keras_model.DPModel(
inputs=inputs,
outputs=outputs,
l2_norm_clip=1e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layer_registry=default_layer_registry,
)
else:
base_model = tf.keras.Model(inputs=inputs, outputs=outputs)
losses = {
k1: tf.keras.losses.MeanSquaredError(),
k2: tf.keras.losses.MeanAbsoluteError(),
}
loss_weights = {k1: 0.4, k2: 0.6}
sgd = tf.keras.optimizers.SGD(1.0)
base_model.compile(loss=losses, loss_weights=loss_weights, optimizer=sgd)
return base_model
dp_model = _make_base_model(use_dp=True)
model = _make_base_model(use_dp=False)
batch_size = 4
x_batch = {
k1: tf.reshape(
tf.range(input_dim * batch_size, dtype=tf.float32),
[batch_size, input_dim],
),
k2: 2.0 * tf.reshape(
tf.range(input_dim * batch_size, dtype=tf.float32),
[batch_size, input_dim],
),
}
y_batch = {k1: tf.zeros([batch_size, 1]), k2: tf.ones([batch_size, 1])}
# Checks that gradients align.
model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
dp_model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
self.assertAllClose(dp_model.trainable_variables, model.trainable_variables)
if __name__ == '__main__':
tf.test.main()