Added support for Keras optimizers and serialization.
PiperOrigin-RevId: 322603030
This commit is contained in:
parent
87c01eb2f5
commit
2ec0f36d1e
2 changed files with 81 additions and 8 deletions
|
@ -28,9 +28,12 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
def make_optimizer_class(cls):
|
def make_optimizer_class(cls):
|
||||||
"""Constructs a DP optimizer class from an existing one."""
|
"""Constructs a DP optimizer class from an existing one."""
|
||||||
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
||||||
|
|
||||||
|
has_compute_gradients = hasattr(cls, 'compute_gradients')
|
||||||
|
if has_compute_gradients:
|
||||||
child_code = cls.compute_gradients.__code__
|
child_code = cls.compute_gradients.__code__
|
||||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||||
if child_code is not parent_code:
|
if has_compute_gradients and child_code is not parent_code:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
||||||
'method compute_gradients(). Check to ensure that '
|
'method compute_gradients(). Check to ensure that '
|
||||||
|
@ -135,13 +138,23 @@ def make_optimizer_class(cls):
|
||||||
|
|
||||||
def process_microbatch(i, sample_state):
|
def process_microbatch(i, sample_state):
|
||||||
"""Process one microbatch (record) with privacy helper."""
|
"""Process one microbatch (record) with privacy helper."""
|
||||||
grads, _ = zip(
|
self_super = super(DPOptimizerClass, self)
|
||||||
*super(DPOptimizerClass, self).compute_gradients(
|
|
||||||
tf.reduce_mean(
|
mean_loss = tf.reduce_mean(input_tensor=tf.gather(
|
||||||
input_tensor=tf.gather(microbatches_losses,
|
microbatches_losses, [i]))
|
||||||
[i])), var_list, gate_gradients,
|
|
||||||
|
if hasattr(self_super, 'compute_gradients'):
|
||||||
|
# This case covers optimizers in tf.train.
|
||||||
|
compute_gradients_fn = self_super.compute_gradients
|
||||||
|
else:
|
||||||
|
# This case covers Keras optimizers from optimizers_v2.
|
||||||
|
compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access
|
||||||
|
|
||||||
|
grads, _ = zip(*compute_gradients_fn(
|
||||||
|
mean_loss, var_list, gate_gradients,
|
||||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||||
grads_list = list(grads)
|
grads_list = list(grads)
|
||||||
|
|
||||||
sample_state = self._dp_sum_query.accumulate_record(
|
sample_state = self._dp_sum_query.accumulate_record(
|
||||||
sample_params, sample_state, grads_list)
|
sample_params, sample_state, grads_list)
|
||||||
return sample_state
|
return sample_state
|
||||||
|
@ -207,6 +220,11 @@ def make_gaussian_optimizer_class(cls):
|
||||||
unroll_microbatches=False,
|
unroll_microbatches=False,
|
||||||
*args, # pylint: disable=keyword-arg-before-vararg
|
*args, # pylint: disable=keyword-arg-before-vararg
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
self._l2_norm_clip = l2_norm_clip
|
||||||
|
self._noise_multiplier = noise_multiplier
|
||||||
|
self._num_microbatches = num_microbatches
|
||||||
|
self._base_optimizer_class = cls
|
||||||
|
|
||||||
dp_sum_query = gaussian_query.GaussianSumQuery(
|
dp_sum_query = gaussian_query.GaussianSumQuery(
|
||||||
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
||||||
|
|
||||||
|
@ -221,6 +239,25 @@ def make_gaussian_optimizer_class(cls):
|
||||||
*args,
|
*args,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
"""Creates configuration for Keras serialization.
|
||||||
|
|
||||||
|
This method will be called when Keras creates model checkpoints
|
||||||
|
and is necessary so that deserialization can be performed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict object storing arguments to be passed to the __init__ method
|
||||||
|
upon deserialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = self._base_optimizer_class.get_config(self)
|
||||||
|
config.update({
|
||||||
|
'l2_norm_clip': self._l2_norm_clip,
|
||||||
|
'noise_multiplier': self._noise_multiplier,
|
||||||
|
'num_microbatches': self._num_microbatches})
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ledger(self):
|
def ledger(self):
|
||||||
return self._dp_sum_query.ledger
|
return self._dp_sum_query.ledger
|
||||||
|
|
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import mock
|
import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -296,6 +297,41 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
loss=self._loss(data0, var0), var_list=[var0, extra_variable])
|
||||||
sess.run(minimize_op)
|
sess.run(minimize_op)
|
||||||
|
|
||||||
|
def _testWriteOutAndReload(self, optimizer_cls):
|
||||||
|
optimizer = optimizer_cls(l2_norm_clip=1.0,
|
||||||
|
noise_multiplier=0.01,
|
||||||
|
num_microbatches=1)
|
||||||
|
|
||||||
|
test_dir = self.get_temp_dir()
|
||||||
|
model_path = os.path.join(test_dir, 'model')
|
||||||
|
|
||||||
|
model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)),
|
||||||
|
tf.keras.layers.Dense(units=1,
|
||||||
|
activation='softmax')])
|
||||||
|
model.compile(optimizer=optimizer,
|
||||||
|
loss=tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
|
from_logits=True))
|
||||||
|
|
||||||
|
tf.keras.models.save_model(model, filepath=model_path,
|
||||||
|
include_optimizer=True)
|
||||||
|
|
||||||
|
optimizer_cls_str = optimizer_cls.__name__
|
||||||
|
tf.keras.models.load_model(model_path,
|
||||||
|
custom_objects={
|
||||||
|
optimizer_cls_str: optimizer_cls})
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def testWriteOutAndReloadAdam(self):
|
||||||
|
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||||
|
tf.keras.optimizers.Adam)
|
||||||
|
self._testWriteOutAndReload(optimizer_class)
|
||||||
|
|
||||||
|
def testWriteOutAndReloadSGD(self):
|
||||||
|
optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
|
||||||
|
tf.keras.optimizers.SGD)
|
||||||
|
self._testWriteOutAndReload(optimizer_class)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in a new issue