Report the true loss in DPModel instead of the norm-adjusted loss.
PiperOrigin-RevId: 517112812
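Background: the fast gradient clipping path differentiates a clip-weighted ("norm-adjusted") loss, and DPModel previously reported that weighted value as its loss. With this change, the weighted value is reported under a separate 'privatized_loss' metric while the standard loss metric reflects the true, unweighted loss. The sketch below is illustrative only (it is not code from this change) and assumes a Keras MeanSquaredError loss with its default mean reduction:

    import tensorflow as tf

    loss_fn = tf.keras.losses.MeanSquaredError()
    y_true = tf.constant([[0.0], [0.0]])
    y_pred = tf.constant([[1.0], [3.0]])
    # Hypothetical per-example clip weights, as produced by compute_clip_weights();
    # a weight below 1.0 marks an example whose gradient was clipped.
    clip_weights = tf.constant([1.0, 0.5])

    true_loss = loss_fn(y_true, y_pred)  # (1 + 9) / 2 = 5.0
    weighted_loss = loss_fn(y_true, y_pred, sample_weight=clip_weights)  # (1 + 0.5 * 9) / 2 = 2.75

The weighted value is what the clipped gradients differentiate; the true value is what users expect to see logged during training.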
Parent: 8f4ab1a8bb
Commit: 043e8b5272
3 changed files with 215 additions and 73 deletions
File: tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -21,7 +21,7 @@ of the approach given in https://arxiv.org/pdf/2009.03106.pdf (see the
 `compute_gradient_norms()` function).
 """
 
-from typing import Any, Callable, Dict, Iterable, Optional, Text, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
@@ -129,11 +129,14 @@ def compute_gradient_norms(
   vars_list = [a for (a, b) in filtered_outputs]
   sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
   # Second loop evaluates the squared L2 norm functions and appends the results.
-  grads_list = tape.gradient(summed_loss, vars_list)
+  grads_list = tape.gradient(
+      summed_loss,
+      vars_list,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
   sqr_norm_list = []
   for grads, f in zip(grads_list, sqr_norm_fns_list):
     sqr_norm_list.append(f(grads))
-  del tape
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
@@ -163,24 +166,23 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)
 
 
-def compute_pred_and_clipped_gradients(
+def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
     x_batch: InputTensor,
     y_batch: tf.Tensor,
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
     num_microbatches: Optional[lr.BatchSize] = None,
-):
-  """Computes the per-example predictions and per-example clipped loss gradient.
+) -> Tuple[List[tf.Tensor], tf.Tensor, float]:
+  """Computes the per-example clipped loss gradient and other useful outputs.
 
   Given a batch of observations `(x_batch, y_batch)`, the main steps of this
   function are: (i) compute the l2-norm of the gradients of the trainable
   variables of `input_model` for each example in the batch; (ii) use the norms
   computed in (i) to obtain "clip_weights" that are used to generate a weighted
   loss function whose gradient for each example has l2-norm at most
-  `l2_norm_clip`; (iii) output the clipped gradients in (ii) and the
-  `tf.Tensor` generated by `input_model` when it is given `x_batch` as its
-  input.
+  `l2_norm_clip`; (iii) output the clipped gradients in (ii) and other useful
+  outputs to the caller.
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
@@ -204,9 +206,11 @@ def compute_pred_and_clipped_gradients(
      of num_microbatches).
 
  Returns:
-    A `tuple` `(y_pred, grad)`. The first element is the prediction generated by
-    the model on the input `x_batch`. The second element is the clipped
-    gradient of the loss function.
+    A `tuple` `(grad, y_pred, weighted_loss_value)`. The first element is the
+    clipped gradient of the loss function, the second is the result of
+    applying `input_model` to `x_batch`, and the third is loss value of
+    `input_model`, weighted by the loss weights generated by a specific
+    `compute_clip_weights()` call.
  """
  gradient_norms = compute_gradient_norms(
      input_model,
@@ -217,18 +221,37 @@ def compute_pred_and_clipped_gradients(
  )
  loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
  with tf.GradientTape() as tape:
-    y_pred = input_model(x_batch, training=True)
-    if num_microbatches is not None:
-      y_batch = lr.add_microbatch_axis(y_batch, num_microbatches)
-      y_pred = lr.add_microbatch_axis(y_pred, num_microbatches)
-    # Warning: When num_microbatches is not None, we need to be sure that
+    # WARNING: When num_microbatches is not None, we need to be sure that
    # `compute_loss` always computes the mean over the microbatches
    # as it is the assumption made when computing the gradient norm.
    # It is indeed the case for multiple keras loss functions
    # (e.g. mean_squared_error and binary_crossentropy). However it
    # is not defined in the contract so may not hold, especially for
    # custom losses.
-    loss_value = input_model.compute_loss(
-        x_batch, y_batch, y_pred, loss_weights
+    y_pred = input_model(x_batch, training=True)
+    loss_y_batch = (
+        y_batch
+        if num_microbatches is None
+        else lr.add_microbatch_axis(y_batch, num_microbatches)
    )
-  return y_pred, tape.gradient(loss_value, input_model.trainable_variables)
+    loss_y_pred = (
+        y_pred
+        if num_microbatches is None
+        else lr.add_microbatch_axis(y_pred, num_microbatches)
+    )
+    # NOTE: We do not log the loss values here. The caller should invoke
+    # `input_model.compute_loss()` to log loss values. Specifically,
+    # calling `input_model.compute_loss()` performs the following steps:
+    #
+    # (i) sums `input_model.loss` with the regularization losses given in
+    #     `input_model.losses` to obtain the total loss
+    # (ii) evaluates the total loss with sample weights (if given)
+    weighted_loss_value = input_model.loss(
+        loss_y_batch, loss_y_pred, loss_weights
+    )
+  clipped_grads = tape.gradient(
+      weighted_loss_value,
+      input_model.trainable_variables,
+      unconnected_gradients=tf.UnconnectedGradients.ZERO,
+  )
+  return clipped_grads, y_pred, weighted_loss_value
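Note on the new contract: the function now returns the clipped gradients first, followed by the model outputs and the clip-weighted loss value, and it no longer logs anything itself. A hedged usage sketch (variable names and argument values are illustrative; `model`, `x_batch`, and `y_batch` are assumed to be defined elsewhere, and the real caller is `DPModel.train_step` in the next file):

    from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

    clipped_grads, y_pred, weighted_loss = clip_grads.compute_clipped_gradients_and_outputs(
        model,
        x_batch,
        y_batch,
        l2_norm_clip=1.0,
        layer_registry=lr.make_default_layer_registry(),
        num_microbatches=None,
    )
    # `weighted_loss` is the quantity the clipped gradients differentiate; the
    # caller is responsible for logging the true loss separately, e.g. via
    # `model.compute_loss()` or `model.compiled_loss()`.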
File: dp_keras_model.py

@@ -17,14 +17,14 @@ from absl import logging
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
 
 def make_dp_model_class(cls):
   """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
 
   class DPModelClass(cls):  # pylint: disable=missing-class-docstring
-    __doc__ = (
-        """DP subclass of `{base_model}`.
+    __doc__ = ("""DP subclass of `{base_model}`.
 
     This can be used as a differentially private replacement for
     {base_model}. This class implements DP-SGD using the standard
@@ -32,7 +32,7 @@ def make_dp_model_class(cls):
 
     This class also utilizes a faster gradient clipping algorithm if the
     following two conditions hold:
-    (i) the trainable layers of the model are keys in the `dict` input
+    (i) the trainable layers of the model are keys in the input
        `layer_registry`,
     (ii) the loss `tf.Tensor` for a given batch of examples is either a
        scalar or a 2D `tf.Tensor` that has only one column
@@ -56,6 +56,11 @@ def make_dp_model_class(cls):
     It is the caller's responsibility to make sure that the loss function
     does behave this way.
 
+    WARNING: This API does not have privacy guarantees for custom
+    layer-level losses created by the `layer.add_loss()` API. It does,
+    however, support layer regularization losses. All of these layer-level
+    losses are found in `model.losses`.
+
     When instantiating this class, you need to supply several
     DP-related arguments followed by the standard arguments for
     `{short_base_model}`.
@@ -83,8 +88,7 @@ def make_dp_model_class(cls):
       model.fit(train_data, train_labels, epochs=1, batch_size=32)
       ```
 
-    """
-    ).format(
+    """).format(
        base_model='tf.keras.' + cls.__name__,
        short_base_model=cls.__name__,
        dp_model_class='DP' + cls.__name__,
@@ -124,8 +128,10 @@ def make_dp_model_class(cls):
      # In particular, boolean values supplied to `use_xla` in the earlier API
      # will raise an error.
      if isinstance(num_microbatches, bool):
-        raise ValueError('Boolean value supplied for `num_microbatches`. '
-                         'Did you intend it for `use_xla`?')
+        raise ValueError(
+            'Boolean value supplied for `num_microbatches`. '
+            'Did you intend it for `use_xla`?'
+        )
      self._num_microbatches = num_microbatches
 
      # If all the trainable layers are in the input layer registry, we
@@ -144,7 +150,8 @@ def make_dp_model_class(cls):
 
      if use_xla:
        self.train_step = tf.function(
-            self.train_step, experimental_compile=True)
+            self.train_step, experimental_compile=True
+        )
 
    def _process_per_example_grads(self, grads):
      grads_flat = tf.nest.flatten(grads)
@@ -152,7 +159,7 @@ def make_dp_model_class(cls):
          tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
      ]
      global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
-      div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
+      div = tf.maximum(global_norm / self._l2_norm_clip, 1.0)
      clipped_flat = [g / div for g in grads_flat]
      return tf.nest.pack_sequence_as(grads, clipped_flat)
 
@@ -160,19 +167,23 @@ def make_dp_model_class(cls):
      summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
      noise_stddev = self._l2_norm_clip * self._noise_multiplier
      noise = tf.random.normal(
-          tf.shape(input=summed_grads), stddev=noise_stddev)
+          tf.shape(input=summed_grads), stddev=noise_stddev
+      )
      noised_grads = summed_grads + noise
-      return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype)
+      return noised_grads / tf.cast(
+          tf.shape(stacked_grads)[0], noised_grads.dtype
+      )
 
    def _compute_per_example_grads(self, data):
-      x, y = data
+      microbatched_x, microbatched_y = data
      with tf.GradientTape() as tape:
-        y_pred = self(x, training=True)
-        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
-      grads_list = tape.gradient(loss, self.trainable_variables)
+        microbatched_y_pred = self(microbatched_x, training=True)
+        # NOTE: Calling `self.loss()` neither logs the total loss nor does it
+        # include any regularization terms.
+        microbatched_loss = self.loss(microbatched_y, microbatched_y_pred)
+      grads_list = tape.gradient(microbatched_loss, self.trainable_variables)
      clipped_grads = self._process_per_example_grads(grads_list)
-      return y_pred, loss, clipped_grads
+      return microbatched_loss, clipped_grads
 
    def train_step(self, data):
      """DP-SGD version of base class method.
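Note on the microbatch handling in the hunks above and below: `_compute_per_example_grads` now receives one microbatch at a time, and the batch is grouped into microbatches with `lr.add_microbatch_axis`. A rough sketch of the assumed reshaping (the helper's exact behavior is defined in layer_registry.py; the shapes here are illustrative, not taken from this change):

    import tensorflow as tf

    # A batch of 8 labels grouped into 4 microbatches of 2 examples each; this
    # is the effect assumed of lr.add_microbatch_axis(y, 4).
    y = tf.reshape(tf.range(8.0), [8, 1])
    microbatched_y = tf.reshape(y, [4, 2, 1])

    # tf.vectorized_map then maps _compute_per_example_grads over the leading
    # (microbatch) axis, yielding one loss value and one set of clipped
    # gradients per microbatch.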
@@ -184,7 +195,7 @@ def make_dp_model_class(cls):
      condition is satisfied if the model subclasses the keras.Sequential or
      keras.engine.functional.Functional class).
 
-      If (i) and (ii) above do not hold, then clips and aggregates
+      If (i) and (ii) above do not hold, then this function clips and aggregates
      gradients at the microbatch level.
 
      Args:
@@ -193,6 +204,13 @@ def make_dp_model_class(cls):
      Returns:
        See the base class.
      """
+      output_metrics = {}
+      x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
+      batch_size = tf.shape(y)[0]
+      eff_microbatch_size = self._num_microbatches or batch_size
+      privatized_loss_name = 'privatized_loss'
+
+      # Branch based on gradient clipping algorithm.
      if self._enable_fast_peg_computation:
        logging.info(
            'Computing gradients using the fast per-example gradient '
@@ -200,8 +218,10 @@
        )
        # Computes the per-example gradient norms using a "fast" clipping
        # trick, and uses these norms to clip the per-example gradients.
-        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
-        y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
+        # NOTE: Reshaping of the input according to the effective number of
+        # microbatches is done here.
+        clipped_grads, y_pred, weighted_loss = (
+            clip_grads.compute_clipped_gradients_and_outputs(
            self,
            x,
            y,
@@ -209,51 +229,62 @@
            self._layer_registry,
            self._num_microbatches,
-        )
-        batch_size = self._num_microbatches or tf.shape(y)[0]
+            )
+        )
        grads = gradient_clipping_utils.add_aggregate_noise(
            self,
            clipped_grads,
-            batch_size,
+            eff_microbatch_size,
            self._l2_norm_clip,
            self._noise_multiplier,
        )
+        output_metrics[privatized_loss_name] = weighted_loss
      else:
        logging.info('Computing gradients using microbatching.')
        # Computes per-example clipped gradients directly. This is called
        # if at least one of the layers cannot use the "fast" gradient clipping
        # algorithm.
-        # TODO(wkong): check if the following is valid with sample weights.
-        _, y = data
-        batch_size = y.shape[0]
-
-        if self._num_microbatches is None:
-          self._num_microbatches = batch_size
-        if batch_size % self._num_microbatches != 0:
-          raise ValueError('Number of_microbatches must divide batch size.')
-
-        def reshape_fn(x):
-          new_shape = (
-              self._num_microbatches,
-              batch_size // self._num_microbatches,
-          ) + x.shape[1:]
-          return tf.reshape(x, new_shape)
-
-        data = tf.nest.map_structure(reshape_fn, data)
-
-        y_pred, _, per_eg_grads = tf.vectorized_map(
-            self._compute_per_example_grads, data
+        reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_microbatch_size)
+        microbatched_data = tf.nest.map_structure(reshape_fn, data)
+        microbatched_losses, clipped_grads = tf.vectorized_map(
+            self._compute_per_example_grads,
+            microbatched_data,
        )
-        y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
+        y_pred = self(x, training=True)
 
        grads = tf.nest.map_structure(
-            self._reduce_per_example_grads, per_eg_grads
+            self._reduce_per_example_grads, clipped_grads
        )
+        if self.loss.reduction == tf.keras.losses.Reduction.SUM:
+          microbatched_loss = tf.reduce_sum(microbatched_losses)
+        else:
+          microbatched_loss = tf.reduce_mean(microbatched_losses)
+        output_metrics[privatized_loss_name] = microbatched_loss
+
+      # Add the values and gradients contributed by regularization losses.
+      if self.losses:
+        logging.warning(
+            'Losses in `model.losses` must be input (batch) independent in '
+            'order to obtain the desired differential privacy guarantees.'
+        )
+        with tf.GradientTape() as tape:
+          summed_regularization_loss = tf.add_n(self.losses)
+        regularization_grads = tape.gradient(
+            summed_regularization_loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        grads = [a + b for (a, b) in zip(grads, regularization_grads)]
+        output_metrics[privatized_loss_name] += summed_regularization_loss
+
+      # Log the true loss.
+      self.compiled_loss(y, y_pred, regularization_losses=self.losses)
 
      # Forward the private gradients to the optimizer and return the results.
      self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
      self.compiled_metrics.update_state(y, y_pred)
-      return {m.name: m.result() for m in self.metrics}
+      for m in self.metrics:
+        output_metrics[m.name] = m.result()
+
+      return output_metrics
 
  return DPModelClass
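With these changes `train_step` builds its own metrics dictionary: the clip-weighted quantity is reported under 'privatized_loss', while the standard loss metric is updated through `self.compiled_loss(y, y_pred, ...)` and therefore reflects the true loss. An illustrative sketch of the resulting training output (metric names other than 'privatized_loss' depend on how the model was compiled; `model`, `x_batch`, and `y_batch` are assumed to be defined):

    history = model.fit(x_batch, y_batch, epochs=1, batch_size=2, verbose=0)
    print(sorted(history.history.keys()))
    # Expected to contain both the true loss and the new key, e.g.:
    # ['loss', 'privatized_loss']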
File: dp_keras_model_test.py

@@ -282,5 +282,93 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
      ],
  )
 
+  # Simple test to check that regularizer gradients are contributing to the
+  # final gradient.
+  @parameterized.named_parameters(
+      ('no_registry', None),
+      ('default_registry', layer_registry.make_default_layer_registry()),
+  )
+  def testRegularizationGradient(self, registry):
+    input_dim = 10
+    batch_size = 2
+    regularizer_multiplier = 0.025
+    inputs = tf.keras.layers.Input((input_dim,))
+    dense_lyr = tf.keras.layers.Dense(
+        1,
+        kernel_initializer='ones',
+        use_bias=False,
+        kernel_regularizer=tf.keras.regularizers.L2(regularizer_multiplier),
+    )
+    # Zero-out outputs to avoid contributions from the main loss function.
+    outputs = tf.multiply(dense_lyr(inputs), 0.0)
+    model = dp_keras_model.DPModel(
+        inputs=inputs,
+        outputs=outputs,
+        l2_norm_clip=1e9,
+        noise_multiplier=0.0,
+        layer_registry=registry,
+    )
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(),
+        optimizer=tf.keras.optimizers.SGD(1.0),
+        run_eagerly=True,
+    )
+    x_batch = tf.reshape(
+        tf.range(input_dim * batch_size, dtype=tf.float32),
+        [batch_size, input_dim],
+    )
+    y_batch = tf.zeros([batch_size, 1])
+    model.fit(x=x_batch, y=y_batch)
+
+    self.assertAllClose(
+        model.trainable_variables,
+        tf.multiply(
+            tf.ones_like(model.trainable_variables),
+            1.0 - 2.0 * regularizer_multiplier,
+        ),
+    )
+
+  # Simple test to check that custom input regularization does NOT contribute
+  # to the gradient.
+  @parameterized.named_parameters(
+      ('no_registry', None),
+      ('default_registry', layer_registry.make_default_layer_registry()),
+  )
+  def testCustomRegularizationZeroGradient(self, registry):
+    input_dim = 10
+    batch_size = 2
+    inputs = tf.keras.layers.Input((input_dim,))
+    dense_lyr = tf.keras.layers.Dense(
+        1,
+        kernel_initializer='ones',
+        use_bias=False,
+    )
+    # Zero-out outputs to avoid contributions from the main loss function.
+    outputs = tf.multiply(dense_lyr(inputs), 0.0)
+    model = dp_keras_model.DPModel(
+        inputs=inputs,
+        outputs=outputs,
+        l2_norm_clip=1e9,
+        noise_multiplier=0.0,
+        layer_registry=registry,
+    )
+    model.add_loss(tf.reduce_sum(inputs))
+    model.compile(
+        loss=tf.keras.losses.MeanSquaredError(),
+        optimizer=tf.keras.optimizers.SGD(1.0),
+        run_eagerly=True,
+    )
+    x_batch = tf.reshape(
+        tf.range(input_dim * batch_size, dtype=tf.float32),
+        [batch_size, input_dim],
+    )
+    y_batch = tf.zeros([batch_size, 1])
+    model.fit(x=x_batch, y=y_batch)
+
+    self.assertAllClose(
+        model.trainable_variables, tf.ones_like(model.trainable_variables)
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()
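A quick arithmetic check of the value asserted in testRegularizationGradient (not part of the change): the model output is multiplied by zero, so only the L2 kernel regularizer contributes to the gradient, and one SGD step with learning rate 1.0 moves each weight from 1.0 to 1.0 - 2.0 * regularizer_multiplier.

    # d/dw of the penalty lambda * w**2 at w = 1, then one SGD(1.0) step.
    regularizer_multiplier = 0.025
    w0 = 1.0
    grad = 2.0 * regularizer_multiplier * w0  # 0.05
    w1 = w0 - 1.0 * grad                      # 0.95 == 1.0 - 2.0 * regularizer_multiplier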