Add support for multi-headed models that use fast gradient clipping.
PiperOrigin-RevId: 627942683
This commit is contained in:
parent
fefad2190e
commit
3fa0a2d362
6 changed files with 208 additions and 32 deletions
|
@ -69,11 +69,80 @@ def get_registry_generator_fn(
|
||||||
return registry_generator_fn
|
return registry_generator_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_per_example_loss_fn(model: tf.keras.Model):
|
||||||
|
"""Infer the per-example loss from model config."""
|
||||||
|
|
||||||
|
def _convert(loss_fn):
|
||||||
|
loss_config = loss_fn.get_config()
|
||||||
|
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
||||||
|
return loss_fn.from_config(loss_config)
|
||||||
|
|
||||||
|
model_loss = model.loss
|
||||||
|
if isinstance(model_loss, tf.keras.losses.Loss):
|
||||||
|
return _convert(model_loss)
|
||||||
|
elif isinstance(model_loss, dict):
|
||||||
|
# Note that we cannot call the public method `.get_compile_config()` because
|
||||||
|
# it calls a numpy function, which is not supported inside a `tf.function`
|
||||||
|
# wrapped function.
|
||||||
|
compile_config = model._compile_config.config # pylint: disable=protected-access
|
||||||
|
if compile_config is None:
|
||||||
|
raise ValueError('Model must be compiled for loss function conversion')
|
||||||
|
# Does a weighted mean of the configured losses. Note that we cannot build
|
||||||
|
# from the config of the compiled loss because (i) it builds a
|
||||||
|
# `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s
|
||||||
|
# during its construction, (ii) non-unique `tf.Variables` cannot be used
|
||||||
|
# inside a `tf.function`, which is usually where this function is used.
|
||||||
|
if 'loss_weights' not in compile_config:
|
||||||
|
raise ValueError(
|
||||||
|
'Models with multiple loss must have corresponding loss weights for'
|
||||||
|
' loss function conversion'
|
||||||
|
)
|
||||||
|
weights = compile_config['loss_weights']
|
||||||
|
per_example_losses = {k: _convert(v) for k, v in model_loss.items()}
|
||||||
|
num_losses = len(weights)
|
||||||
|
|
||||||
|
def _per_example_loss_fn(y_true, y_pred, sample_weight=None):
|
||||||
|
loss_values = []
|
||||||
|
if model_loss.keys() - y_pred.keys():
|
||||||
|
raise ValueError(
|
||||||
|
'y_pred must contain the same keys and the model losses, but '
|
||||||
|
'got %s and %s' % (y_pred.keys(), model_loss.keys())
|
||||||
|
)
|
||||||
|
if model_loss.keys() - y_true.keys():
|
||||||
|
raise ValueError(
|
||||||
|
'y_true must contain the same keys and the model losses, but '
|
||||||
|
'got %s and %s' % (y_true.keys(), model_loss.keys())
|
||||||
|
)
|
||||||
|
if sample_weight is not None:
|
||||||
|
if model_loss.keys() - sample_weight.keys():
|
||||||
|
raise ValueError(
|
||||||
|
'sample_weight must contain the same keys and the model losses,'
|
||||||
|
' but got %s and %s' % (y_true.keys(), model_loss.keys())
|
||||||
|
)
|
||||||
|
for k in y_true.keys():
|
||||||
|
sgl_sample_weight = None if sample_weight is None else sample_weight[k]
|
||||||
|
sgl_value = (
|
||||||
|
weights[k]
|
||||||
|
* per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight)
|
||||||
|
/ num_losses
|
||||||
|
)
|
||||||
|
loss_values.append(tf.reshape(sgl_value, shape=[-1]))
|
||||||
|
return tf.math.add_n(loss_values)
|
||||||
|
|
||||||
|
return _per_example_loss_fn
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Unsupported type for loss function conversion: {}'.format(
|
||||||
|
type(model_loss)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_gradient_norms(
|
def compute_gradient_norms(
|
||||||
input_model: tf.keras.Model,
|
input_model: tf.keras.Model,
|
||||||
layer_registry: lr.LayerRegistry,
|
layer_registry: lr.LayerRegistry,
|
||||||
x_batch: type_aliases.InputTensors,
|
x_batch: type_aliases.InputTensors,
|
||||||
y_batch: tf.Tensor,
|
y_batch: type_aliases.OutputTensors,
|
||||||
weight_batch: Optional[tf.Tensor] = None,
|
weight_batch: Optional[tf.Tensor] = None,
|
||||||
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
|
per_example_loss_fn: Optional[type_aliases.LossFn] = None,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||||
|
@ -94,9 +163,9 @@ def compute_gradient_norms(
|
||||||
more details.
|
more details.
|
||||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||||
first axis must be the batch dimension.
|
first axis must be the batch dimension.
|
||||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
y_batch: An `OutputTensor` representing a batch of output labels. The first
|
||||||
must be the batch dimension. The number of examples should match the
|
axes of the tensors must be the batch dimension. The number of examples
|
||||||
number of examples in `x_batch`.
|
should match the number of examples in `x_batch`.
|
||||||
weight_batch: Optional batch of weights, passed to the loss function.
|
weight_batch: Optional batch of weights, passed to the loss function.
|
||||||
Weights apply to the loss prior to clipping.
|
Weights apply to the loss prior to clipping.
|
||||||
per_example_loss_fn: takes as input predictions, labels and weights, and
|
per_example_loss_fn: takes as input predictions, labels and weights, and
|
||||||
|
@ -131,11 +200,11 @@ def compute_gradient_norms(
|
||||||
input_model, x_batch, generator_fn=registry_generator_fn
|
input_model, x_batch, generator_fn=registry_generator_fn
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ignore the original loss function's reduction to get per-example loss.
|
# Ignore the original loss function's reduction to get per-example loss.
|
||||||
if per_example_loss_fn is None:
|
if per_example_loss_fn is None:
|
||||||
loss_config = input_model.loss.get_config()
|
per_example_loss_fn = _infer_per_example_loss_fn(input_model)
|
||||||
loss_config['reduction'] = tf.keras.losses.Reduction.NONE
|
|
||||||
per_example_loss_fn = input_model.loss.from_config(loss_config)
|
|
||||||
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
|
losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
|
||||||
if losses.shape is None:
|
if losses.shape is None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -233,7 +302,7 @@ def compute_clipped_gradients_and_outputs(
|
||||||
l2_norm_clip: float,
|
l2_norm_clip: float,
|
||||||
layer_registry: lr.LayerRegistry,
|
layer_registry: lr.LayerRegistry,
|
||||||
x_batch: type_aliases.InputTensors,
|
x_batch: type_aliases.InputTensors,
|
||||||
y_batch: tf.Tensor,
|
y_batch: type_aliases.OutputTensors,
|
||||||
weight_batch: Optional[tf.Tensor] = None,
|
weight_batch: Optional[tf.Tensor] = None,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
num_microbatches: Optional[type_aliases.BatchSize] = None,
|
||||||
clipping_loss: Optional[type_aliases.LossFn] = None,
|
clipping_loss: Optional[type_aliases.LossFn] = None,
|
||||||
|
@ -260,10 +329,10 @@ def compute_clipped_gradients_and_outputs(
|
||||||
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
squared norms of a layer's pre-activation tensor, and `vars` are relevant
|
||||||
trainable weights (see `layer_registry_factories.py` for examples).
|
trainable weights (see `layer_registry_factories.py` for examples).
|
||||||
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
x_batch: An `InputTensor` representing a batch of inputs to the model. The
|
||||||
first axis must be the batch dimension.
|
first axes of each tensor must be the batch dimension.
|
||||||
y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
|
y_batch: An `OutputTensor` representing a batch of output labels. The first
|
||||||
must be the batch dimension. The number of examples should match the
|
axes of each tensor must be the batch dimension. The number of examples
|
||||||
number of examples in `x_batch`.
|
should match the number of examples in `x_batch`.
|
||||||
weight_batch: Optional vector of weights, passed to the loss function. Must
|
weight_batch: Optional vector of weights, passed to the loss function. Must
|
||||||
be of size [batch_size]. In case of microbatching, this will be reshaped
|
be of size [batch_size]. In case of microbatching, this will be reshaped
|
||||||
to [num_microbatches, batch_size/num_microbatches] before passing it to
|
to [num_microbatches, batch_size/num_microbatches] before passing it to
|
||||||
|
@ -285,6 +354,7 @@ def compute_clipped_gradients_and_outputs(
|
||||||
clipping_loss_value: the loss value weighted in such a way that its gradient
|
clipping_loss_value: the loss value weighted in such a way that its gradient
|
||||||
is `clipped_grad`.
|
is `clipped_grad`.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(input_model.loss, 'reduction'):
|
||||||
if input_model.loss.reduction == 'none':
|
if input_model.loss.reduction == 'none':
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Fast gradient clipping does not support '
|
'Fast gradient clipping does not support '
|
||||||
|
@ -311,11 +381,13 @@ def compute_clipped_gradients_and_outputs(
|
||||||
weight_batch, num_microbatches
|
weight_batch, num_microbatches
|
||||||
)
|
)
|
||||||
if num_microbatches is None:
|
if num_microbatches is None:
|
||||||
clip_weights = clip_weights * weight_batch # shape [num_microbatches]
|
c = clip_weights # shape [num_microbatches]
|
||||||
else:
|
else:
|
||||||
# In this case, weight_batch is of shape [batch_size, microbatch_size],
|
# In this case, weight_batch is of shape [batch_size, microbatch_size],
|
||||||
# we multiply by the clip_weights (which is of shape [num_microbatches])
|
# we multiply by the clip_weights (which is of shape [num_microbatches])
|
||||||
clip_weights = clip_weights[:, tf.newaxis] * weight_batch
|
c = clip_weights[:, tf.newaxis]
|
||||||
|
clip_weights = tf.nest.map_structure(lambda w: c * w, weight_batch)
|
||||||
|
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
# WARNING: When num_microbatches is not None, we need to be sure that
|
# WARNING: When num_microbatches is not None, we need to be sure that
|
||||||
# `compute_loss` always computes the mean over the microbatches
|
# `compute_loss` always computes the mean over the microbatches
|
||||||
|
|
|
@ -20,13 +20,13 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
|
||||||
|
|
||||||
|
|
||||||
def maybe_add_microbatch_axis(
|
def maybe_add_microbatch_axis(
|
||||||
x: tf.Tensor,
|
x: type_aliases.PackedTensors,
|
||||||
num_microbatches: Optional[type_aliases.BatchSize],
|
num_microbatches: Optional[type_aliases.BatchSize],
|
||||||
) -> tf.Tensor:
|
) -> type_aliases.PackedTensors:
|
||||||
"""Adds the microbatch axis.
|
"""Adds the microbatch axis to a collection of tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: the input tensor.
|
x: Model output or input tensors.
|
||||||
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
|
num_microbatches: If None, x is returned unchanged. Otherwise, must divide
|
||||||
the batch size.
|
the batch size.
|
||||||
|
|
||||||
|
@ -36,9 +36,13 @@ def maybe_add_microbatch_axis(
|
||||||
"""
|
"""
|
||||||
if num_microbatches is None:
|
if num_microbatches is None:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _expand(t):
|
||||||
with tf.control_dependencies(
|
with tf.control_dependencies(
|
||||||
[tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)]
|
[tf.assert_equal(tf.math.floormod(tf.shape(t)[0], num_microbatches), 0)]
|
||||||
):
|
):
|
||||||
return tf.reshape(
|
return tf.reshape(
|
||||||
x, tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
|
t, tf.concat([[num_microbatches, -1], tf.shape(t)[1:]], axis=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return tf.nest.map_structure(_expand, x)
|
||||||
|
|
|
@ -144,6 +144,30 @@ def all_trainable_layers_are_registered(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_loss_reduction_type(model: tf.keras.Model):
|
||||||
|
"""Infers what type of loss reduction is being performed."""
|
||||||
|
model_loss = model.loss
|
||||||
|
if isinstance(model_loss, tf.keras.losses.Loss):
|
||||||
|
return model_loss.reduction
|
||||||
|
elif isinstance(model.loss, dict):
|
||||||
|
reductions = set()
|
||||||
|
compiled_loss = model.compiled_loss
|
||||||
|
if compiled_loss is None:
|
||||||
|
raise ValueError('Model must be compiled for adding noise')
|
||||||
|
new_config_list = compiled_loss.get_config()['losses']
|
||||||
|
for loss_config in new_config_list:
|
||||||
|
reductions.add(loss_config['config']['reduction'])
|
||||||
|
if len(reductions) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'Reductions in models with multiple losses must all be the same'
|
||||||
|
)
|
||||||
|
return reductions.pop()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Unsupported type for adding noise: {}'.format(type(model_loss))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_aggregate_noise(
|
def add_aggregate_noise(
|
||||||
clipped_grads: list[tf.Tensor],
|
clipped_grads: list[tf.Tensor],
|
||||||
batch_size: tf.Tensor,
|
batch_size: tf.Tensor,
|
||||||
|
@ -194,7 +218,7 @@ def add_aggregate_noise(
|
||||||
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||||
tf.keras.losses.Reduction.AUTO,
|
tf.keras.losses.Reduction.AUTO,
|
||||||
]
|
]
|
||||||
model_reduction = loss_model.loss.reduction
|
model_reduction = _infer_loss_reduction_type(loss_model)
|
||||||
loss_reduction = (
|
loss_reduction = (
|
||||||
'mean' if model_reduction in implicit_mean_reductions else 'sum'
|
'mean' if model_reduction in implicit_mean_reductions else 'sum'
|
||||||
)
|
)
|
||||||
|
|
|
@ -82,7 +82,6 @@ def layer_normalization_computation(
|
||||||
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
|
stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
|
||||||
grads, num_microbatches
|
grads, num_microbatches
|
||||||
)
|
)
|
||||||
stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
|
|
||||||
reduction_axes = tf.range(1, tf.rank(stacked_grads))
|
reduction_axes = tf.range(1, tf.rank(stacked_grads))
|
||||||
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
|
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
|
||||||
|
|
||||||
|
|
|
@ -213,6 +213,22 @@ def make_dp_model_class(cls):
|
||||||
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
|
total_loss_mean=tf.keras.metrics.Mean(name=_PRIVATIZED_LOSS_NAME),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _infer_batch_size(self, t):
|
||||||
|
"""Infers the batch size from a tensor or a container of tensors."""
|
||||||
|
if t is None:
|
||||||
|
return None
|
||||||
|
elif isinstance(t, tf.Tensor):
|
||||||
|
t0 = t
|
||||||
|
elif isinstance(t, list):
|
||||||
|
t0 = t[0]
|
||||||
|
elif isinstance(t, dict):
|
||||||
|
t0 = list(t.values())[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Unsupported type for batch size inference: {}'.format(type(t))
|
||||||
|
)
|
||||||
|
return tf.shape(t0)[0]
|
||||||
|
|
||||||
def train_step(self, data):
|
def train_step(self, data):
|
||||||
"""DP-SGD version of base class method.
|
"""DP-SGD version of base class method.
|
||||||
|
|
||||||
|
@ -236,7 +252,7 @@ def make_dp_model_class(cls):
|
||||||
self._make_clipping_loss()
|
self._make_clipping_loss()
|
||||||
output_metrics = {}
|
output_metrics = {}
|
||||||
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
|
x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||||
batch_size = tf.shape(y)[0]
|
batch_size = self._infer_batch_size(y)
|
||||||
num_microbatches = self._num_microbatches or batch_size
|
num_microbatches = self._num_microbatches or batch_size
|
||||||
|
|
||||||
# Branch based on gradient clipping algorithm.
|
# Branch based on gradient clipping algorithm.
|
||||||
|
|
|
@ -425,6 +425,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model.trainable_variables, tf.ones_like(model.trainable_variables)
|
model.trainable_variables, tf.ones_like(model.trainable_variables)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Checks single optimizer update and consistency with non-DP model.
|
||||||
|
@parameterized.product(
|
||||||
|
fast_clipping=[True, False],
|
||||||
|
num_microbatches=[None, 2],
|
||||||
|
)
|
||||||
|
def testWeightedLoss(self, fast_clipping, num_microbatches):
|
||||||
|
k1 = 'foo'
|
||||||
|
k2 = 'bar'
|
||||||
|
input_dim = 10
|
||||||
|
default_layer_registry = (
|
||||||
|
layer_registry.make_default_layer_registry() if fast_clipping else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_base_model(use_dp: bool):
|
||||||
|
inputs = {
|
||||||
|
k1: tf.keras.layers.Input((input_dim,)),
|
||||||
|
k2: tf.keras.layers.Input((input_dim,)),
|
||||||
|
}
|
||||||
|
dense1 = tf.keras.layers.Dense(1, kernel_initializer='ones')
|
||||||
|
dense2 = tf.keras.layers.Dense(1, kernel_initializer='ones')
|
||||||
|
outputs = {k1: dense1(inputs[k1]), k2: dense2(inputs[k2])}
|
||||||
|
if use_dp:
|
||||||
|
base_model = dp_keras_model.DPModel(
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
l2_norm_clip=1e9,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=num_microbatches,
|
||||||
|
layer_registry=default_layer_registry,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
losses = {
|
||||||
|
k1: tf.keras.losses.MeanSquaredError(),
|
||||||
|
k2: tf.keras.losses.MeanAbsoluteError(),
|
||||||
|
}
|
||||||
|
loss_weights = {k1: 0.4, k2: 0.6}
|
||||||
|
sgd = tf.keras.optimizers.SGD(1.0)
|
||||||
|
base_model.compile(loss=losses, loss_weights=loss_weights, optimizer=sgd)
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
dp_model = _make_base_model(use_dp=True)
|
||||||
|
model = _make_base_model(use_dp=False)
|
||||||
|
batch_size = 4
|
||||||
|
x_batch = {
|
||||||
|
k1: tf.reshape(
|
||||||
|
tf.range(input_dim * batch_size, dtype=tf.float32),
|
||||||
|
[batch_size, input_dim],
|
||||||
|
),
|
||||||
|
k2: 2.0 * tf.reshape(
|
||||||
|
tf.range(input_dim * batch_size, dtype=tf.float32),
|
||||||
|
[batch_size, input_dim],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
y_batch = {k1: tf.zeros([batch_size, 1]), k2: tf.ones([batch_size, 1])}
|
||||||
|
|
||||||
|
# Checks that gradients align.
|
||||||
|
model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
|
||||||
|
dp_model.fit(x=x_batch, y=y_batch, epochs=1, batch_size=4, shuffle=False)
|
||||||
|
self.assertAllClose(dp_model.trainable_variables, model.trainable_variables)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue