Added support for Keras optimizers and serialization.

PiperOrigin-RevId: 322603030

Author: A. Unique TensorFlower
Date:   2020-07-22 10:28:33 -07:00
Parent: 87c01eb2f5
Commit: 2ec0f36d1e

2 changed files with 81 additions and 8 deletions

dp_optimizer.py

@@ -28,9 +28,12 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
 def make_optimizer_class(cls):
   """Constructs a DP optimizer class from an existing one."""
   parent_code = tf.train.Optimizer.compute_gradients.__code__
-  child_code = cls.compute_gradients.__code__
+  has_compute_gradients = hasattr(cls, 'compute_gradients')
+  if has_compute_gradients:
+    child_code = cls.compute_gradients.__code__
   GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
-  if child_code is not parent_code:
+  if has_compute_gradients and child_code is not parent_code:
     logging.warning(
         'WARNING: Calling make_optimizer_class() on class %s that overrides '
         'method compute_gradients(). Check to ensure that '
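
Editor's note (not part of the diff): the new hasattr() guard is needed because tf.train optimizers expose a public compute_gradients() method while Keras OptimizerV2 classes do not, so reading cls.compute_gradients.__code__ unconditionally would raise AttributeError for Keras optimizers. A minimal sketch of that difference, assuming TF 1.15+ with the v2 Keras optimizers:

import tensorflow.compat.v1 as tf

# tf.train optimizers define a public compute_gradients(); Keras V2 optimizers
# do not (they only expose _compute_gradients), hence the hasattr() guard above.
print(hasattr(tf.train.GradientDescentOptimizer, 'compute_gradients'))  # expect True
print(hasattr(tf.keras.optimizers.SGD, 'compute_gradients'))            # expect False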
@@ -135,13 +138,23 @@ def make_optimizer_class(cls):
         def process_microbatch(i, sample_state):
           """Process one microbatch (record) with privacy helper."""
-          grads, _ = zip(
-              *super(DPOptimizerClass, self).compute_gradients(
-                  tf.reduce_mean(
-                      input_tensor=tf.gather(microbatches_losses,
-                                             [i])), var_list, gate_gradients,
-                  aggregation_method, colocate_gradients_with_ops, grad_loss))
+          self_super = super(DPOptimizerClass, self)
+
+          mean_loss = tf.reduce_mean(input_tensor=tf.gather(
+              microbatches_losses, [i]))
+
+          if hasattr(self_super, 'compute_gradients'):
+            # This case covers optimizers in tf.train.
+            compute_gradients_fn = self_super.compute_gradients
+          else:
+            # This case covers Keras optimizers from optimizers_v2.
+            compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access
+
+          grads, _ = zip(*compute_gradients_fn(
+              mean_loss, var_list, gate_gradients,
+              aggregation_method, colocate_gradients_with_ops, grad_loss))
+
           grads_list = list(grads)
           sample_state = self._dp_sum_query.accumulate_record(
               sample_params, sample_state, grads_list)
           return sample_state
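
Editor's note (usage sketch, not part of the diff): with the fallback above, the factory functions can wrap a Keras optimizer and the per-microbatch gradient call is routed to its _compute_gradients(). The class name DPKerasSGD and the hyperparameter values below are illustrative only:

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

# Wrap a Keras (OptimizerV2) optimizer; before this change only tf.train
# optimizers worked because compute_gradients() was assumed to exist.
DPKerasSGD = dp_optimizer.make_gaussian_optimizer_class(tf.keras.optimizers.SGD)
opt = DPKerasSGD(l2_norm_clip=1.0,
                 noise_multiplier=1.1,
                 num_microbatches=4,
                 learning_rate=0.1)  # extra kwargs are forwarded to the Keras SGD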
@@ -207,6 +220,11 @@ def make_gaussian_optimizer_class(cls):
         unroll_microbatches=False,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs):
+      self._l2_norm_clip = l2_norm_clip
+      self._noise_multiplier = noise_multiplier
+      self._num_microbatches = num_microbatches
+      self._base_optimizer_class = cls
+
       dp_sum_query = gaussian_query.GaussianSumQuery(
           l2_norm_clip, l2_norm_clip * noise_multiplier)
@@ -221,6 +239,25 @@ def make_gaussian_optimizer_class(cls):
           *args,
           **kwargs)
 
+    def get_config(self):
+      """Creates configuration for Keras serialization.
+
+      This method will be called when Keras creates model checkpoints
+      and is necessary so that deserialization can be performed.
+
+      Returns:
+        A dict object storing arguments to be passed to the __init__ method
+        upon deserialization.
+      """
+      config = self._base_optimizer_class.get_config(self)
+
+      config.update({
+          'l2_norm_clip': self._l2_norm_clip,
+          'noise_multiplier': self._noise_multiplier,
+          'num_microbatches': self._num_microbatches})
+
+      return config
+
     @property
     def ledger(self):
       return self._dp_sum_query.ledger
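
Editor's note (illustrative round trip, not part of the diff): get_config() returns the wrapped Keras optimizer's own config augmented with the three DP arguments, which is what lets Keras rebuild an equivalent optimizer at load time. The class name DPKerasAdam and hyperparameter values are illustrative:

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

DPKerasAdam = dp_optimizer.make_gaussian_optimizer_class(tf.keras.optimizers.Adam)
opt = DPKerasAdam(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=2)

config = opt.get_config()
# The Adam entries (learning_rate, beta_1, ...) coexist with the DP entries.
assert {'l2_norm_clip', 'noise_multiplier', 'num_microbatches'} <= set(config)

# Rebuilding from the stored arguments; Keras does the equivalent through
# from_config() when load_model() is given the class via custom_objects.
restored = DPKerasAdam(l2_norm_clip=config['l2_norm_clip'],
                       noise_multiplier=config['noise_multiplier'],
                       num_microbatches=config['num_microbatches'])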

dp_optimizer_test.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 from absl.testing import parameterized
 import mock
 import numpy as np
@@ -296,6 +297,41 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
           loss=self._loss(data0, var0), var_list=[var0, extra_variable])
       sess.run(minimize_op)
 
+  def _testWriteOutAndReload(self, optimizer_cls):
+    optimizer = optimizer_cls(l2_norm_clip=1.0,
+                              noise_multiplier=0.01,
+                              num_microbatches=1)
+
+    test_dir = self.get_temp_dir()
+    model_path = os.path.join(test_dir, 'model')
+
+    model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)),
+                                 tf.keras.layers.Dense(units=1,
+                                                       activation='softmax')])
+    model.compile(optimizer=optimizer,
+                  loss=tf.keras.losses.SparseCategoricalCrossentropy(
+                      from_logits=True))
+
+    tf.keras.models.save_model(model, filepath=model_path,
+                               include_optimizer=True)
+
+    optimizer_cls_str = optimizer_cls.__name__
+    tf.keras.models.load_model(model_path,
+                               custom_objects={
+                                   optimizer_cls_str: optimizer_cls})
+
+    return
+
+  def testWriteOutAndReloadAdam(self):
+    optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
+        tf.keras.optimizers.Adam)
+
+    self._testWriteOutAndReload(optimizer_class)
+
+  def testWriteOutAndReloadSGD(self):
+    optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
+        tf.keras.optimizers.SGD)
+
+    self._testWriteOutAndReload(optimizer_class)
+
 if __name__ == '__main__':
   tf.test.main()
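
Editor's note (user-facing sketch mirroring the test above): because the wrapped class is created dynamically by the factory, Keras cannot resolve it by name during deserialization, so it must be supplied explicitly when loading. The name DPKerasAdam and model_path are hypothetical:

import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

DPKerasAdam = dp_optimizer.make_gaussian_optimizer_class(tf.keras.optimizers.Adam)
model = tf.keras.models.load_model(
    model_path,  # hypothetical path to a model saved with include_optimizer=True
    custom_objects={DPKerasAdam.__name__: DPKerasAdam})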