Ensures DPOptimizer objects can be serialized by TensorFlow.

Handles this by converting the global-state tensors to NumPy values before serialization. Adds tests that now cover saving.

PiperOrigin-RevId: 481656298
This commit is contained in:
A. Unique TensorFlower 2022-10-17 09:11:02 -07:00
parent c25cb4a41b
commit d5538fccbb
3 changed files with 77 additions and 3 deletions

View file

@ -101,6 +101,8 @@ class DPQuery(metaclass=abc.ABCMeta):
just an empty tuple for implementing classes that do not have any persistent
state.
This object must be processable via tf.nest.map_structure.
Returns:
The global state.
"""
@ -288,7 +290,8 @@ class SumAggregationDPQuery(DPQuery):
return tf.nest.map_structure(_zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`.
"""
return tf.nest.map_structure(_safe_add, sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):

View file

@ -378,8 +378,13 @@ def make_keras_generic_optimizer_class(
Python dictionary.
"""
config = super().get_config()
# The below is necessary to ensure that the global state can be serialized
# by JSON serializers inside of tensorflow saving.
global_state_as_numpy = tf.nest.map_structure(lambda x: x.numpy(),
self._global_state)
config.update({
'global_state': self._global_state._asdict(),
'global_state': global_state_as_numpy._asdict(),
'num_microbatches': self._num_microbatches,
})
return config

View file

@ -64,7 +64,7 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
def testKerasModelBaselineNoNoiseNoneMicrobatches(self):
def testKerasModelBaselineSaving(self):
"""Tests that DP optimizers work with tf.keras.Model."""
model = tf.keras.models.Sequential(layers=[
@ -87,7 +87,73 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = np.array([6.0]).astype(np.float32)
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.0, size=(1000, 1)).astype(np.float32)
model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
tempdir = self.create_tempdir()
model.save(tempdir, save_format='tf')
def testKerasModelBaselineAfterSavingLoading(self):
  """Checks that a DP-optimized Keras model can be fit after save + reload.

  Builds a one-unit linear model compiled with `DPKerasSGDOptimizer`
  (zero noise multiplier, `num_microbatches=None`), calls `predict` once,
  saves the model in the 'tf' format to a temporary directory, loads weights
  back from that same location and finally fits for one epoch. There are no
  explicit assertions: the test passes when none of these steps raises.
  """
  # Model: a single dense unit with all-zero initial kernel and bias.
  dense_layer = tf.keras.layers.Dense(
      1,
      activation='linear',
      name='dense',
      kernel_initializer='zeros',
      bias_initializer='zeros')
  model = tf.keras.models.Sequential(layers=[dense_layer])
  model.compile(
      dp_optimizer_keras.DPKerasSGDOptimizer(
          l2_norm_clip=100.0,
          noise_multiplier=0.0,
          num_microbatches=None,
          learning_rate=0.05),
      tf.keras.losses.MeanSquaredError(reduction='none'))

  # Synthetic regression data: labels = features @ weights + bias (+ a
  # zero-scale noise term, i.e. all zeros).
  weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
  bias = np.array([6.0]).astype(np.float32)
  features = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
  zero_noise = np.random.normal(
      scale=0.0, size=(1000, 1)).astype(np.float32)
  labels = np.matmul(features, weights) + bias + zero_noise

  # Run the model once before saving.
  model.predict(features, batch_size=8)

  # Save, reload the weights from the same path, then train.
  save_path = self.create_tempdir()
  model.save(save_path, save_format='tf')
  model.load_weights(save_path)
  model.fit(features, labels, batch_size=8, epochs=1, shuffle=False)
@parameterized.named_parameters(('1', 1), ('None', None))
def testKerasModelBaselineNoNoise(self, num_microbatches):
"""Tests that DP optimizers work with tf.keras.Model."""
model = tf.keras.models.Sequential(layers=[
tf.keras.layers.Dense(
1,
activation='linear',
name='dense',
kernel_initializer='zeros',
bias_initializer='zeros')
])
optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
l2_norm_clip=100.0,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
learning_rate=0.05)
loss = tf.keras.losses.MeanSquaredError(reduction='none')
model.compile(optimizer, loss)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = np.array([6.0]).astype(np.float32)
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.0, size=(1000, 1)).astype(np.float32)