Add test to ensure DP optimizers work with tf.estimator Estimators.

PiperOrigin-RevId: 228920704
This commit is contained in:
Steve Chien 2019-01-11 11:58:30 -08:00 committed by schien1729
parent 22c8b76c04
commit c30f6d776e

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors.
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,11 +24,6 @@ import tensorflow as tf
from privacy.optimizers import dp_optimizer
try:
xrange
except NameError:
xrange = range
def loss(val0, val1):
"""Loss function that is minimized at the mean of the input points."""
@ -117,7 +112,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
grads = []
for _ in xrange(1000):
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
@ -139,5 +134,45 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
'make_optimizer_class() does not interfere with overridden version.',
'SimpleOptimizer')
def testEstimator(self):
"""Tests that DP optimizers work with tf.estimator."""
def linear_model_fn(features, labels, mode):
preds = tf.keras.layers.Dense(
1, activation='linear', name='dense').apply(features['x'])
vector_loss = tf.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(vector_loss)
optimizer = dp_optimizer.DPGradientDescentOptimizer(
l2_norm_clip=1.0,
noise_multiplier=0.0,
num_microbatches=1,
learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = 6.0
train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.1, size=(200, 1)).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=20,
num_epochs=10,
shuffle=True)
linear_regressor.train(input_fn=train_input_fn, steps=100)
self.assertAllClose(
linear_regressor.get_variable_value('dense/kernel'),
true_weights,
atol=1.0)
if __name__ == '__main__':
tf.test.main()