forked from 626_privacy/tensorflow_privacy
Add DP versions of v1 FTRL optimizer.
PiperOrigin-RevId: 553186886
This commit is contained in:
parent
b7e9709ff7
commit
a32e6ae5d0
2 changed files with 63 additions and 20 deletions
|
@ -41,7 +41,7 @@ def make_optimizer_class(cls):
|
|||
'make_optimizer_class() does not interfere with overridden version.',
|
||||
cls.__name__)
|
||||
|
||||
class DPOptimizerClass(cls): # pylint: disable=empty-docstring
|
||||
class DPOptimizerClass(cls): # pylint: disable=missing-class-docstring
|
||||
__doc__ = ("""Differentially private subclass of `{base_class}`.
|
||||
|
||||
You can use this as a differentially private replacement for
|
||||
|
@ -278,7 +278,7 @@ def make_gaussian_optimizer_class(cls):
|
|||
A subclass of `cls` using DP-SGD with Gaussian averaging.
|
||||
"""
|
||||
|
||||
class DPGaussianOptimizerClass(make_optimizer_class(cls)): # pylint: disable=empty-docstring
|
||||
class DPGaussianOptimizerClass(make_optimizer_class(cls)): # pylint: disable=missing-class-docstring
|
||||
__doc__ = ("""DP subclass of `{}`.
|
||||
|
||||
You can use this as a differentially private replacement for
|
||||
|
@ -372,16 +372,19 @@ def make_gaussian_optimizer_class(cls):
|
|||
|
||||
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
|
||||
FtrlOptimizer = tf.compat.v1.train.FtrlOptimizer
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
RMSPropOptimizer = tf.compat.v1.train.RMSPropOptimizer
|
||||
|
||||
DPAdagradOptimizer = make_optimizer_class(AdagradOptimizer)
|
||||
DPAdamOptimizer = make_optimizer_class(AdamOptimizer)
|
||||
DPFtrlOptimizer = make_optimizer_class(FtrlOptimizer)
|
||||
DPGradientDescentOptimizer = make_optimizer_class(GradientDescentOptimizer)
|
||||
DPRMSPropOptimizer = make_optimizer_class(RMSPropOptimizer)
|
||||
|
||||
DPAdagradGaussianOptimizer = make_gaussian_optimizer_class(AdagradOptimizer)
|
||||
DPAdamGaussianOptimizer = make_gaussian_optimizer_class(AdamOptimizer)
|
||||
DPFtrlGaussianOptimizer = make_gaussian_optimizer_class(FtrlOptimizer)
|
||||
DPGradientDescentGaussianOptimizer = make_gaussian_optimizer_class(
|
||||
GradientDescentOptimizer)
|
||||
DPRMSPropGaussianOptimizer = make_gaussian_optimizer_class(RMSPropOptimizer)
|
||||
|
|
|
@ -57,22 +57,51 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
# Parameters for testing: optimizer, num_microbatches, expected answer.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
|
||||
[-2.5, -2.5]),
|
||||
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
|
||||
[-2.5, -2.5]),
|
||||
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
|
||||
[-2.5, -2.5]),
|
||||
(
|
||||
'DPGradientDescent 1',
|
||||
dp_optimizer.DPGradientDescentOptimizer,
|
||||
1,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
(
|
||||
'DPGradientDescent 2',
|
||||
dp_optimizer.DPGradientDescentOptimizer,
|
||||
2,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
(
|
||||
'DPGradientDescent 4',
|
||||
dp_optimizer.DPGradientDescentOptimizer,
|
||||
4,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-2.5, -2.5]),
|
||||
('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-2.5, -2.5]),
|
||||
('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
|
||||
('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]),
|
||||
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
|
||||
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]),
|
||||
('DPRMSPropOptimizer 1', dp_optimizer.DPRMSPropOptimizer, 1,
|
||||
[-2.5, -2.5]), ('DPRMSPropOptimizer 2', dp_optimizer.DPRMSPropOptimizer,
|
||||
2, [-2.5, -2.5]),
|
||||
('DPRMSPropOptimizer 4', dp_optimizer.DPRMSPropOptimizer, 4, [-2.5, -2.5])
|
||||
(
|
||||
'DPRMSPropOptimizer 1',
|
||||
dp_optimizer.DPRMSPropOptimizer,
|
||||
1,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
(
|
||||
'DPRMSPropOptimizer 2',
|
||||
dp_optimizer.DPRMSPropOptimizer,
|
||||
2,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
(
|
||||
'DPRMSPropOptimizer 4',
|
||||
dp_optimizer.DPRMSPropOptimizer,
|
||||
4,
|
||||
[-2.5, -2.5],
|
||||
),
|
||||
('DPFtrl 1', dp_optimizer.DPFtrlOptimizer, 1, [-2.5, -2.5]),
|
||||
('DPFtrl 2', dp_optimizer.DPFtrlOptimizer, 2, [-2.5, -2.5]),
|
||||
('DPFtrl 4', dp_optimizer.DPFtrlOptimizer, 4, [-2.5, -2.5]),
|
||||
)
|
||||
def testBaseline(self, cls, num_microbatches, expected_answer):
|
||||
with self.cached_session() as sess:
|
||||
|
@ -98,7 +127,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer),
|
||||
('DPFtrlOptimizer', dp_optimizer.DPFtrlOptimizer),
|
||||
)
|
||||
def testClippingNorm(self, cls):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([0.0, 0.0])
|
||||
|
@ -158,7 +189,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer, 1),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer, 1),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer, 1))
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer, 1),
|
||||
('DPFtrl', dp_optimizer.DPFtrlOptimizer, 1),
|
||||
)
|
||||
def testNoiseMultiplier(self, cls, num_microbatches):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([0.0])
|
||||
|
@ -212,10 +245,11 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
dp_sum_query, num_microbatches=1, learning_rate=1.0)
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
|
||||
return tf_estimator.EstimatorSpec(
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
return tf_estimator.EstimatorSpec( # pylint: disable=g-deprecated-tf-checker
|
||||
mode=mode, loss=scalar_loss, train_op=train_op
|
||||
)
|
||||
|
||||
linear_regressor = tf_estimator.Estimator(model_fn=linear_model_fn)
|
||||
linear_regressor = tf_estimator.Estimator(model_fn=linear_model_fn) # pylint: disable=g-deprecated-tf-checker
|
||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||
true_bias = 6.0
|
||||
train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)
|
||||
|
@ -240,7 +274,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer),
|
||||
('DPFtrl', dp_optimizer.DPFtrlOptimizer),
|
||||
)
|
||||
def testUnrollMicrobatches(self, cls):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
|
@ -270,7 +306,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPGradientDescent', dp_optimizer.DPGradientDescentGaussianOptimizer),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradGaussianOptimizer),
|
||||
('DPAdam', dp_optimizer.DPAdamGaussianOptimizer),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropGaussianOptimizer))
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropGaussianOptimizer),
|
||||
('DPFtrl', dp_optimizer.DPFtrlGaussianOptimizer),
|
||||
)
|
||||
def testDPGaussianOptimizerClass(self, cls):
|
||||
with self.cached_session() as sess:
|
||||
var0 = tf.Variable([0.0])
|
||||
|
@ -299,7 +337,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
|
||||
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
|
||||
('DPAdam', dp_optimizer.DPAdamOptimizer),
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
|
||||
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer),
|
||||
('DPFtrl', dp_optimizer.DPFtrlOptimizer),
|
||||
)
|
||||
def testAssertOnNoCallOfComputeGradients(self, cls):
|
||||
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=1.0)
|
||||
|
|
Loading…
Reference in a new issue