Add support for multi-headed models that use fast gradient clipping.
PiperOrigin-RevId: 627942683
This commit is contained in:
parent
fefad2190e
commit
3fa0a2d362
6 changed files with 208 additions and 32 deletions
|
@ -69,11 +69,80 @@ def get_registry_generator_fn(
|
|||
return registry_generator_fn
|
||||
|
||||
|
||||
def _infer_per_example_loss_fn(model: tf.keras.Model):
|
||||
"""Infer the per-example loss from model config."""
|
||||
|
||||
def _convert(loss_fn):
|
||||
loss_config = loss_fn.get_config()
|
||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
return loss_fn.from_config(loss_config)
|
||||
|
||||
model_loss = model.loss
|
||||
if isinstance(model_loss, tf.keras.losses.Loss):
|
||||
return _convert(model_loss)
|
||||
elif isinstance(model_loss, dict):
|
||||
# Note that we cannot call the public method `.get_compile_config()` because
|
||||
# it calls a numpy function, which is not supported inside a `tf.function`
|
||||
# wrapped function.
|
||||
compile_config = model._compile_config.config # pylint: disable=protected-access
|
||||
if compile_config is None:
|
||||
raise ValueError('Model must be compiled for loss function conversion')
|
||||
# Does a weighted mean of the configured losses. Note that we cannot build
|
||||
# from the config of the compiled loss because (i) it builds a
|
||||
# `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
|
||||
# during its construction, (ii) non-unique `tf.Variables` cannot be used
|
||||
# inside a `tf.function`, which is usually where this function is used.
|
||||
if 'loss_weights' not in compile_config:
|
||||
raise ValueError(
|
||||
'Models with multiple loss must have corresponding loss weights for'
|
||||
' loss function conversion'
|
||||
)
|
||||
weights = compile_config['loss_weights']
|
||||
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
|
||||
num_losses = len(weights)
|
||||
|
||||
def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
|
||||
loss_values = []
|
||||
if model_loss.keys() - y_pred.keys():
|
||||
raise ValueError(
|
||||
'y_pred must contain the same keys and the model losses, but '
|
||||
'got %s and %s' % (y_pred.keys(), model_loss.keys())
|
||||
)
|
||||
if model_loss.keys() - y_true.keys():
|
||||
raise ValueError(
|
||||
'y_true must contain the same keys and the model losses, but '
|
||||
'got %s and %s' % (y_true.keys(), model_loss.keys())
|
||||
)
|
||||
if sample_weight is not None:
|
||||
if model_loss.keys() - sample_weight.keys():
|
||||
raise ValueError(
|
||||
'sample_weight must contain the same keys and the model losses,'
|
||||
' but got %s and %s' % (y_true.keys(), model_loss.keys())
|
||||
)
|
||||
for k in y_true.keys():
|
||||
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
|
||||
sgl_value = (
|
||||
weights[k]
|
||||
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
|
||||
/ num_losses
|
||||
)
|
||||
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
|
||||
return tf.math.add_n(loss_values)
|
||||
|
||||
return _per_example_loss_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unsupported type for loss function conversion: {}'.format(
|
||||
type(model_loss)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def compute_gradient_norms(
|
||||
input_model: tf.keras.Model,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
x_batch: type_aliases.InputTensors,
|
||||
y_batch: tf.Tensor,
|
||||
y_batch: type_aliases.OutputTensors,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
|
||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||
|
@ -94,9 +163,9 @@ def compute_gradient_norms(
|
|||
more details.
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
y_batch: An `OutputTensor` representing a batch of output labels. The first
|
||||
axes of the tensors must be the batch dimension. The number of examples
|
||||
should match the number of examples in `x_batch`.
|
||||
weight_batch: Optional batch of weights, passed to the loss function.
|
||||
Weights apply to the loss prior to clipping.
|
||||
per_example_loss_fn: takes as input predictions, labels and weights, and
|
||||
|
@ -131,11 +200,11 @@ def compute_gradient_norms(
|
|||
input_model, x_batch, generator_fn=registry_generator_fn
|
||||
)
|
||||
)
|
||||
|
||||
# Ignore the original loss function's reduction to get per-example loss.
|
||||
if per_example_loss_fn is None:
|
||||
loss_config = input_model.loss.get_config()
|
||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
||||
per_example_loss_fn = _infer_per_example_loss_fn(input_model)
|
||||
|
||||
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
|
||||
if losses.shape is None:
|
||||
raise NotImplementedError(
|
||||
|
@ -233,7 +302,7 @@ def compute_clipped_gradients_and_outputs(
|
|||
l2_norm_clip: float,
|
||||
layer_registry: lr.LayerRegistry,
|
||||
x_batch: type_aliases.InputTensors,
|
||||
y_batch: tf.Tensor,
|
||||
y_batch: type_aliases.OutputTensors,
|
||||
weight_batch: Optional[tf.Tensor] = None,
|
||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||
clipping_loss: Optional[type_aliases.LossFn] = None,
|
||||
|
@ -260,10 +329,10 @@ def compute_clipped_gradients_and_outputs(
|
|||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||
trainable weights (see `layer_registry_factories.py` for examples).
|
||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||
first axis must be the batch dimension.
|
||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
||||
must be the batch dimension. The number of examples should match the
|
||||
number of examples in `x_batch`.
|
||||
first axes of each tensor must be the batch dimension.
|
||||
y_batch: An `OutputTensor` representing a batch of output labels. The first
|
||||
axes of each tensor must be the batch dimension. The number of examples
|
||||
should match the number of examples in `x_batch`.
|
||||
weight_batch: Optional vector of weights, passed to the loss function. Must
|
||||
be of size [batch_size]. In case of microbatching, this will be reshaped
|
||||
to [num_microbatches, batch_size/num_microbatches] before passing it to
|
||||
|
@ -285,11 +354,12 @@ def compute_clipped_gradients_and_outputs(
|
|||
clipping_loss_value: the loss value weighted in such a way that its gradient
|
||||
is `clipped_grad`.
|
||||
"""
|
||||
if input_model.loss.reduction == 'none':
|
||||
raise NotImplementedError(
|
||||
'Fast gradient clipping does not support '
|
||||
'models with unreduced loss functions.'
|
||||
)
|
||||
if hasattr(input_model.loss, 'reduction'):
|
||||
if input_model.loss.reduction == 'none':
|
||||
raise NotImplementedError(
|
||||
'Fast gradient clipping does not support '
|
||||
'models with unreduced loss functions.'
|
||||
)
|
||||
if clipping_loss is None:
|
||||
clipping_loss = input_model.compiled_loss
|
||||
gradient_norms = compute_gradient_norms(
|
||||
|
@ -311,11 +381,13 @@ def compute_clipped_gradients_and_outputs(
|
|||
weight_batch, num_microbatches
|
||||
)
|
||||
if num_microbatches is None:
|
||||
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
|
||||
c = clip_weights # shape [num_microbatches]
|
||||
else:
|
||||
# In this case, weight_batch is of shape [batch_size, microbatch_size],
|
||||
# we multiply by the clip_weights (which is of shape [num_microbatches])
|
||||
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
|
||||
c = clip_weights[:, tf.newaxis]
|
||||
clip_weights = tf.nest.map_structure(lambda w: c * w, weight_batch)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
# WARNING: When num_microbatches is not None, we need to be sure that
|
||||
# `compute_loss` always computes the mean over the microbatches
|
||||
|
|
|
@ -20,13 +20,13 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
|||
|
||||
|
||||
def maybe_add_microbatch_axis(
|
||||
x: tf.Tensor,
|
||||
x: type_aliases.PackedTensors,
|
||||
num_microbatches: Optional[type_aliases.BatchSize],
|
||||
) -> tf.Tensor:
|
||||
"""Adds the microbatch axis.
|
||||
) -> type_aliases.PackedTensors:
|
||||
"""Adds the microbatch axis to a collection of tensors.
|
||||
|
||||
Args:
|
||||
x: the input tensor.
|
||||
x: Model output or input tensors.
|
||||
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
|
||||
the batch size.
|
||||
|
||||
|
@ -36,9 +36,13 @@ def maybe_add_microbatch_axis(
|
|||
"""
|
||||
if num_microbatches is None:
|
||||
return x
|
||||
with tf.control_dependencies(
|
||||
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
|
||||
):
|
||||
return tf.reshape(
|
||||
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
|
||||
)
|
||||
|
||||
def _expand(t):
|
||||
with tf.control_dependencies(
|
||||
[tf.assert_equal(tf.math.floormod(tf.shape(t)[0], num_microbatches), 0)]
|
||||
):
|
||||
return tf.reshape(
|
||||
t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0)
|
||||
)
|
||||
|
||||
return tf.nest.map_structure(_expand, x)
|
||||
|
|
|
@ -144,6 +144,30 @@ def all_trainable_layers_are_registered(
|
|||
return True
|
||||
|
||||
|
||||
def _infer_loss_reduction_type(model: tf.keras.Model):
|
||||
"""Infers what type of loss reduction is being performed."""
|
||||
model_loss = model.loss
|
||||
if isinstance(model_loss, tf.keras.losses.Loss):
|
||||
return model_loss.reduction
|
||||
elif isinstance(model.loss, dict):
|
||||
reductions = set()
|
||||
compiled_loss = model.compiled_loss
|
||||
if compiled_loss is None:
|
||||
raise ValueError('Model must be compiled for adding noise')
|
||||
new_config_list = compiled_loss.get_config()['losses']
|
||||
for loss_config in new_config_list:
|
||||
reductions.add(loss_config['config']['reduction'])
|
||||
if len(reductions) > 1:
|
||||
raise ValueError(
|
||||
'Reductions in models with multiple losses must all be the same'
|
||||
)
|
||||
return reductions.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unsupported type for adding noise: {}'.format(type(model_loss))
|
||||
)
|
||||
|
||||
|
||||
def add_aggregate_noise(
|
||||
clipped_grads: list[tf.Tensor],
|
||||
batch_size: tf.Tensor,
|
||||
|
@ -194,7 +218,7 @@ def add_aggregate_noise(
|
|||
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
tf.keras.losses.Reduction.AUTO,
|
||||
]
|
||||
model_reduction = loss_model.loss.reduction
|
||||
model_reduction = _infer_loss_reduction_type(loss_model)
|
||||
loss_reduction = (
|
||||
'mean' if model_reduction in implicit_mean_reductions else 'sum'
|
||||
)
|
||||
|
|
|
@ -82,7 +82,6 @@ def layer_normalization_computation(
|
|||
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
|
||||
grads, num_microbatches
|
||||
)
|
||||
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
|
||||
reduction_axes = tf.range(1, tf.rank(stacked_grads))
|
||||
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
|
||||
|
||||
|
|
|
@ -213,6 +213,22 @@ def make_dp_model_class(cls):
|
|||
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
|
||||
)
|
||||
|
||||
def _infer_batch_size(self, t):
|
||||
"""Infers the batch size from a tensor or a container of tensors."""
|
||||
if t is None:
|
||||
return None
|
||||
elif isinstance(t, tf.Tensor):
|
||||
t0 = t
|
||||
elif isinstance(t, list):
|
||||
t0 = t[0]
|
||||
elif isinstance(t, dict):
|
||||
t0 = list(t.values())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unsupported type for batch size inference: {}'.format(type(t))
|
||||
)
|
||||
return tf.shape(t0)[0]
|
||||
|
||||
def train_step(self, data):
|
||||
"""DP-SGD version of base class method.
|
||||
|
||||
|
@ -236,7 +252,7 @@ def make_dp_model_class(cls):
|
|||
self._make_clipping_loss()
|
||||
output_metrics = {}
|
||||
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||
batch_size = tf.shape(y)[0]
|
||||
batch_size = self._infer_batch_size(y)
|
||||
num_microbatches = self._num_microbatches or batch_size
|
||||
|
||||
# Branch based on gradient clipping algorithm.
|
||||
|
|
|
@ -425,6 +425,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
|||
model.trainable_variables, tf.ones_like(model.trainable_variables)
|
||||
)
|
||||
|
||||
# Checks single optimizer update and consistency with non-DP model.
|
||||
@parameterized.product(
|
||||
fast_clipping=[True, False],
|
||||
num_microbatches=[None, 2],
|
||||
)
|
||||
def testWeightedLoss(self, fast_clipping, num_microbatches):
|
||||
k1 = 'foo'
|
||||
k2 = 'bar'
|
||||
input_dim = 10
|
||||
default_layer_registry = (
|
||||
layer_registry.make_default_layer_registry() if fast_clipping else None
|
||||
)
|
||||
|
||||
def _make_base_model(use_dp: bool):
|
||||
inputs = {
|
||||
k1: tf.keras.layers.Input((input_dim,)),
|
||||
k2: tf.keras.layers.Input((input_dim,)),
|
||||
}
|
||||
dense1 = tf.keras.layers.Dense(1, kernel_initializer='ones')
|
||||
dense2 = tf.keras.layers.Dense(1, kernel_initializer='ones')
|
||||
outputs = {k1: dense1(inputs[k1]), k2: dense2(inputs[k2])}
|
||||
if use_dp:
|
||||
base_model = dp_keras_model.DPModel(
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
l2_norm_clip=1e9,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
layer_registry=default_layer_registry,
|
||||
)
|
||||
else:
|
||||
base_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
losses = {
|
||||
k1: tf.keras.losses.MeanSquaredError(),
|
||||
k2: tf.keras.losses.MeanAbsoluteError(),
|
||||
}
|
||||
loss_weights = {k1: 0.4, k2: 0.6}
|
||||
sgd = tf.keras.optimizers.SGD(1.0)
|
||||
base_model.compile(loss=losses, loss_weights=loss_weights, optimizer=sgd)
|
||||
return base_model
|
||||
|
||||
dp_model = _make_base_model(use_dp=True)
|
||||
model = _make_base_model(use_dp=False)
|
||||
batch_size = 4
|
||||
x_batch = {
|
||||
k1: tf.reshape(
|
||||
tf.range(input_dim * batch_size, dtype=tf.float32),
|
||||
[batch_size, input_dim],
|
||||
),
|
||||
k2: 2.0 * tf.reshape(
|
||||
tf.range(input_dim * batch_size, dtype=tf.float32),
|
||||
[batch_size, input_dim],
|
||||
),
|
||||
}
|
||||
y_batch = {k1: tf.zeros([batch_size, 1]), k2: tf.ones([batch_size, 1])}
|
||||
|
||||
# Checks that gradients align.
|
||||
model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
|
||||
dp_model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
|
||||
self.assertAllClose(dp_model.trainable_variables, model.trainable_variables)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in a new issue