forked from 626_privacy/tensorflow_privacy
Added support for Keras optimizers and serialization.
PiperOrigin-RevId: 322603030
parent 87c01eb2f5
commit 2ec0f36d1e

2 changed files with 81 additions and 8 deletions
tensorflow_privacy/privacy/optimizers/dp_optimizer.py

@@ -28,9 +28,12 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
 def make_optimizer_class(cls):
   """Constructs a DP optimizer class from an existing one."""
   parent_code = tf.train.Optimizer.compute_gradients.__code__
-  child_code = cls.compute_gradients.__code__
+
+  has_compute_gradients = hasattr(cls, 'compute_gradients')
+  if has_compute_gradients:
+    child_code = cls.compute_gradients.__code__
   GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
-  if child_code is not parent_code:
+  if has_compute_gradients and child_code is not parent_code:
     logging.warning(
         'WARNING: Calling make_optimizer_class() on class %s that overrides '
         'method compute_gradients(). Check to ensure that '
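Note: the hasattr() guard above exists because the two optimizer families differ in their public surface: tf.train optimizers define compute_gradients(), while Keras optimizer_v2 classes only carry a private _compute_gradients(). A quick probe illustrating the distinction, as a sketch against the TF 1.x-era APIs this code targets (the commented results follow from the branch comments later in this diff, not from a fresh run):

import tensorflow as tf

# tf.train optimizers expose the public hook that make_optimizer_class() patches.
print(hasattr(tf.train.GradientDescentOptimizer, 'compute_gradients'))  # True

# Keras optimizer_v2 classes do not, hence the guard above.
print(hasattr(tf.keras.optimizers.SGD, 'compute_gradients'))   # False
print(hasattr(tf.keras.optimizers.SGD, '_compute_gradients'))  # True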
@@ -135,13 +138,23 @@ def make_optimizer_class(cls):
 
       def process_microbatch(i, sample_state):
         """Process one microbatch (record) with privacy helper."""
-        grads, _ = zip(
-            *super(DPOptimizerClass, self).compute_gradients(
-                tf.reduce_mean(
-                    input_tensor=tf.gather(microbatches_losses,
-                                           [i])), var_list, gate_gradients,
-                aggregation_method, colocate_gradients_with_ops, grad_loss))
+        self_super = super(DPOptimizerClass, self)
+
+        mean_loss = tf.reduce_mean(input_tensor=tf.gather(
+            microbatches_losses, [i]))
+
+        if hasattr(self_super, 'compute_gradients'):
+          # This case covers optimizers in tf.train.
+          compute_gradients_fn = self_super.compute_gradients
+        else:
+          # This case covers Keras optimizers from optimizers_v2.
+          compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access
+
+        grads, _ = zip(*compute_gradients_fn(
+            mean_loss, var_list, gate_gradients,
+            aggregation_method, colocate_gradients_with_ops, grad_loss))
         grads_list = list(grads)
+
         sample_state = self._dp_sum_query.accumulate_record(
             sample_params, sample_state, grads_list)
         return sample_state
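Note: either branch feeds one microbatch's gradients into accumulate_record, where the DP sum query clips the record to a bounded L2 norm and adds it to a running sum. A rough numpy sketch of that clip-and-accumulate step (illustrative only; the real GaussianSumQuery operates on tensors and adds Gaussian noise when the sum is finalized):

import numpy as np

def accumulate_record(l2_norm_clip, sample_state, record_grads):
  # Scale the record so its global L2 norm is at most l2_norm_clip,
  # then add it into the running per-variable sums.
  flat = np.concatenate([g.ravel() for g in record_grads])
  divisor = max(1.0, np.linalg.norm(flat) / l2_norm_clip)
  return [s + g / divisor for s, g in zip(sample_state, record_grads)]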
@@ -207,6 +220,11 @@ def make_gaussian_optimizer_class(cls):
         unroll_microbatches=False,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs):
+      self._l2_norm_clip = l2_norm_clip
+      self._noise_multiplier = noise_multiplier
+      self._num_microbatches = num_microbatches
+      self._base_optimizer_class = cls
+
       dp_sum_query = gaussian_query.GaussianSumQuery(
           l2_norm_clip, l2_norm_clip * noise_multiplier)
 
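Note: the constructor now stashes its arguments on self so that get_config() (added below) can report them at save time. The query's noise standard deviation is the clip norm times the noise multiplier; with the values the new tests use:

l2_norm_clip = 1.0
noise_multiplier = 0.01
# Standard deviation of the Gaussian noise added to the clipped gradient sum.
stddev = l2_norm_clip * noise_multiplier  # 1.0 * 0.01 = 0.01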
@@ -221,6 +239,25 @@ def make_gaussian_optimizer_class(cls):
           *args,
           **kwargs)
 
+    def get_config(self):
+      """Creates configuration for Keras serialization.
+
+      This method will be called when Keras creates model checkpoints
+      and is necessary so that deserialization can be performed.
+
+      Returns:
+        A dict object storing arguments to be passed to the __init__ method
+        upon deserialization.
+      """
+
+      config = self._base_optimizer_class.get_config(self)
+      config.update({
+          'l2_norm_clip': self._l2_norm_clip,
+          'noise_multiplier': self._noise_multiplier,
+          'num_microbatches': self._num_microbatches})
+
+      return config
+
     @property
     def ledger(self):
       return self._dp_sum_query.ledger
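Note: at load time Keras rebuilds the optimizer from this dict; for OptimizerV2 subclasses, from_config() amounts to calling the constructor with the config as keyword arguments. A minimal sketch of the round trip, assuming a class produced by make_gaussian_optimizer_class (names here are illustrative):

DPKerasSGD = dp_optimizer.make_gaussian_optimizer_class(tf.keras.optimizers.SGD)
opt = DPKerasSGD(l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)

config = opt.get_config()        # SGD's own fields plus the three DP fields
restored = DPKerasSGD(**config)  # effectively what Keras' from_config() does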
tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 from absl.testing import parameterized
 import mock
 import numpy as np
@@ -296,6 +297,41 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
           loss=self._loss(data0, var0), var_list=[var0, extra_variable])
       sess.run(minimize_op)
 
+  def _testWriteOutAndReload(self, optimizer_cls):
+    optimizer = optimizer_cls(l2_norm_clip=1.0,
+                              noise_multiplier=0.01,
+                              num_microbatches=1)
+
+    test_dir = self.get_temp_dir()
+    model_path = os.path.join(test_dir, 'model')
+
+    model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)),
+                                 tf.keras.layers.Dense(units=1,
+                                                       activation='softmax')])
+    model.compile(optimizer=optimizer,
+                  loss=tf.keras.losses.SparseCategoricalCrossentropy(
+                      from_logits=True))
+
+    tf.keras.models.save_model(model, filepath=model_path,
+                               include_optimizer=True)
+
+    optimizer_cls_str = optimizer_cls.__name__
+    tf.keras.models.load_model(model_path,
+                               custom_objects={
+                                   optimizer_cls_str: optimizer_cls})
+
+    return
+
+  def testWriteOutAndReloadAdam(self):
+    optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
+        tf.keras.optimizers.Adam)
+    self._testWriteOutAndReload(optimizer_class)
+
+  def testWriteOutAndReloadSGD(self):
+    optimizer_class = dp_optimizer.make_gaussian_optimizer_class(
+        tf.keras.optimizers.SGD)
+    self._testWriteOutAndReload(optimizer_class)
+
 
 if __name__ == '__main__':
   tf.test.main()
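Note: make_gaussian_optimizer_class() builds its classes dynamically, so Keras cannot resolve the class name on its own at load time; the custom_objects mapping in the test supplies that lookup. A caller outside the test would do the same (the path is hypothetical):

DPSGD = dp_optimizer.make_gaussian_optimizer_class(tf.keras.optimizers.SGD)
model = tf.keras.models.load_model(
    '/tmp/dp_model',  # hypothetical path
    custom_objects={DPSGD.__name__: DPSGD})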