diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py index cf5519b..b970c4a 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py @@ -28,9 +28,12 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query def make_optimizer_class(cls): """Constructs a DP optimizer class from an existing one.""" parent_code = tf.train.Optimizer.compute_gradients.__code__ - child_code = cls.compute_gradients.__code__ + + has_compute_gradients = hasattr(cls, 'compute_gradients') + if has_compute_gradients: + child_code = cls.compute_gradients.__code__ GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name - if child_code is not parent_code: + if has_compute_gradients and child_code is not parent_code: logging.warning( 'WARNING: Calling make_optimizer_class() on class %s that overrides ' 'method compute_gradients(). Check to ensure that ' @@ -135,13 +138,23 @@ def make_optimizer_class(cls): def process_microbatch(i, sample_state): """Process one microbatch (record) with privacy helper.""" - grads, _ = zip( - *super(DPOptimizerClass, self).compute_gradients( - tf.reduce_mean( - input_tensor=tf.gather(microbatches_losses, - [i])), var_list, gate_gradients, - aggregation_method, colocate_gradients_with_ops, grad_loss)) + self_super = super(DPOptimizerClass, self) + + mean_loss = tf.reduce_mean(input_tensor=tf.gather( + microbatches_losses, [i])) + + if hasattr(self_super, 'compute_gradients'): + # This case covers optimizers in tf.train. + compute_gradients_fn = self_super.compute_gradients + else: + # This case covers Keras optimizers from optimizers_v2. + compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access + + grads, _ = zip(*compute_gradients_fn( + mean_loss, var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) grads_list = list(grads) + sample_state = self._dp_sum_query.accumulate_record( sample_params, sample_state, grads_list) return sample_state @@ -207,6 +220,11 @@ def make_gaussian_optimizer_class(cls): unroll_microbatches=False, *args, # pylint: disable=keyword-arg-before-vararg **kwargs): + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + self._num_microbatches = num_microbatches + self._base_optimizer_class = cls + dp_sum_query = gaussian_query.GaussianSumQuery( l2_norm_clip, l2_norm_clip * noise_multiplier) @@ -221,6 +239,25 @@ def make_gaussian_optimizer_class(cls): *args, **kwargs) + def get_config(self): + """Creates configuration for Keras serialization. + + This method will be called when Keras creates model checkpoints + and is necessary so that deserialization can be performed. + + Returns: + A dict object storing arguments to be passed to the __init__ method + upon deserialization. + """ + + config = self._base_optimizer_class.get_config(self) + config.update({ + 'l2_norm_clip': self._l2_norm_clip, + 'noise_multiplier': self._noise_multiplier, + 'num_microbatches': self._num_microbatches}) + + return config + @property def ledger(self): return self._dp_sum_query.ledger diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py index 909c674..5876b75 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from absl.testing import parameterized import mock import numpy as np @@ -296,6 +297,41 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): loss=self._loss(data0, var0), var_list=[var0, extra_variable]) sess.run(minimize_op) + def _testWriteOutAndReload(self, optimizer_cls): + optimizer = optimizer_cls(l2_norm_clip=1.0, + noise_multiplier=0.01, + num_microbatches=1) + + test_dir = self.get_temp_dir() + model_path = os.path.join(test_dir, 'model') + + model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)), + tf.keras.layers.Dense(units=1, + activation='softmax')]) + model.compile(optimizer=optimizer, + loss=tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True)) + + tf.keras.models.save_model(model, filepath=model_path, + include_optimizer=True) + + optimizer_cls_str = optimizer_cls.__name__ + tf.keras.models.load_model(model_path, + custom_objects={ + optimizer_cls_str: optimizer_cls}) + + return + + def testWriteOutAndReloadAdam(self): + optimizer_class = dp_optimizer.make_gaussian_optimizer_class( + tf.keras.optimizers.Adam) + self._testWriteOutAndReload(optimizer_class) + + def testWriteOutAndReloadSGD(self): + optimizer_class = dp_optimizer.make_gaussian_optimizer_class( + tf.keras.optimizers.SGD) + self._testWriteOutAndReload(optimizer_class) + if __name__ == '__main__': tf.test.main()