General cleanup.
1. Rename PrivateQuery to DPQuery.
2. Move construction of the DPQuery outside of the optimizer.
3. Remove PrivateAverageQuery and PrivateSumQuery, and rename DPQuery's
   'get_query_result' method to 'get_noised_result'. Rename
   private_queries.py to dp_query.py.
4. Remove the thrice-replicated run_query function from the test classes and
   replace it with a single function in the new test_utils.py.
5. Add functions gaussian_sum_query_from_noise_multplier and
   gaussian_average_query_from_noise_multplier.

PiperOrigin-RevId: 230595991
parent 7e2d796bde
commit c8cb3c6b70
11 changed files with 189 additions and 185 deletions
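At a glance, the change moves noise configuration out of the optimizer: callers now construct the DPQuery themselves, or use the new *GaussianOptimizer convenience classes that keep the old keyword interface. A minimal sketch of the two calling conventions (parameter values here are illustrative, not taken from this commit):

import tensorflow as tf

from privacy.optimizers import dp_optimizer
from privacy.optimizers import gaussian_query

# New generic interface: build the DPQuery outside the optimizer. The sum
# noise stddev is l2_norm_clip * noise_multiplier, here 1.0 * 1.1.
dp_average_query = gaussian_query.GaussianAverageQuery(
    l2_norm_clip=1.0, sum_stddev=1.1, denominator=256)
opt = dp_optimizer.DPGradientDescentOptimizer(
    dp_average_query, num_microbatches=256, learning_rate=0.15)

# New convenience interface: the Gaussian subclass builds the same query
# internally from l2_norm_clip and noise_multiplier, as before this change.
opt = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0, noise_multiplier=1.1,
    num_microbatches=256, learning_rate=0.15)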
privacy/optimizers/dp_optimizer.py

@@ -37,22 +37,19 @@ def make_optimizer_class(cls):
     def __init__(
         self,
-        l2_norm_clip,
-        noise_multiplier,
+        dp_average_query,
         num_microbatches,
         unroll_microbatches=False,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs):
       super(DPOptimizerClass, self).__init__(*args, **kwargs)
-      stddev = l2_norm_clip * noise_multiplier
+      self._dp_average_query = dp_average_query
       self._num_microbatches = num_microbatches
-      self._private_query = gaussian_query.GaussianAverageQuery(
-          l2_norm_clip, stddev, num_microbatches)
+      self._global_state = self._dp_average_query.initial_global_state()
       # TODO(b/122613513): Set unroll_microbatches=True to avoid this bug.
       # Beware: When num_microbatches is large (>100), enabling this parameter
       # may cause an OOM error.
       self._unroll_microbatches = unroll_microbatches
-      self._global_state = self._private_query.initial_global_state()

     def compute_gradients(self,
                           loss,
@@ -68,7 +65,7 @@ def make_optimizer_class(cls):
       # sampling from the dataset without replacement.
       microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
       sample_params = (
-          self._private_query.derive_sample_params(self._global_state))
+          self._dp_average_query.derive_sample_params(self._global_state))

       def process_microbatch(i, sample_state):
         """Process one microbatch (record) with privacy helper."""
@@ -76,7 +73,7 @@ def make_optimizer_class(cls):
             tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
             aggregation_method, colocate_gradients_with_ops, grad_loss))
         grads_list = list(grads)
-        sample_state = self._private_query.accumulate_record(
+        sample_state = self._dp_average_query.accumulate_record(
             sample_params, sample_state, grads_list)
         return sample_state

@@ -84,7 +81,7 @@ def make_optimizer_class(cls):
         var_list = (
             tf.trainable_variables() + tf.get_collection(
                 tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
-      sample_state = self._private_query.initial_sample_state(
+      sample_state = self._dp_average_query.initial_sample_state(
          self._global_state, var_list)

       if self._unroll_microbatches:
@@ -100,15 +97,48 @@ def make_optimizer_class(cls):
         _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])

       final_grads, self._global_state = (
-          self._private_query.get_noised_average(sample_state,
-                                                 self._global_state))
+          self._dp_average_query.get_noised_result(
+              sample_state, self._global_state))

       return list(zip(final_grads, var_list))

   return DPOptimizerClass


+def make_gaussian_optimizer_class(cls):
+  """Constructs a DP optimizer with Gaussian averaging of updates."""
+
+  class DPGaussianOptimizerClass(make_optimizer_class(cls)):
+    """DP subclass of given class cls using Gaussian averaging."""
+
+    def __init__(
+        self,
+        l2_norm_clip,
+        noise_multiplier,
+        num_microbatches,
+        unroll_microbatches=False,
+        *args,  # pylint: disable=keyword-arg-before-vararg
+        **kwargs):
+      dp_average_query = gaussian_query.GaussianAverageQuery(
+          l2_norm_clip, l2_norm_clip * noise_multiplier, num_microbatches)
+      super(DPGaussianOptimizerClass, self).__init__(
+          dp_average_query,
+          num_microbatches,
+          unroll_microbatches,
+          *args,
+          **kwargs)
+
+  return DPGaussianOptimizerClass
+
+
 DPAdagradOptimizer = make_optimizer_class(tf.train.AdagradOptimizer)
 DPAdamOptimizer = make_optimizer_class(tf.train.AdamOptimizer)
 DPGradientDescentOptimizer = make_optimizer_class(
     tf.train.GradientDescentOptimizer)
+
+DPAdagradGaussianOptimizer = make_gaussian_optimizer_class(
+    tf.train.AdagradOptimizer)
+DPAdamGaussianOptimizer = make_gaussian_optimizer_class(tf.train.AdamOptimizer)
+DPGradientDescentGaussianOptimizer = make_gaussian_optimizer_class(
+    tf.train.GradientDescentOptimizer)
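The same two paths exist for wrapping any tf.train optimizer class, not just the three predefined ones. A sketch using MomentumOptimizer, an optimizer not wrapped in this diff and chosen only for illustration:

import tensorflow as tf

from privacy.optimizers import dp_optimizer
from privacy.optimizers import gaussian_query

# Generic factory: any DPQuery can drive the wrapped optimizer.
DPMomentumOptimizer = dp_optimizer.make_optimizer_class(
    tf.train.MomentumOptimizer)
query = gaussian_query.GaussianAverageQuery(
    l2_norm_clip=1.0, sum_stddev=1.0, denominator=16)
opt = DPMomentumOptimizer(
    query, num_microbatches=16, learning_rate=0.1, momentum=0.9)

# Gaussian factory: builds the GaussianAverageQuery internally with
# sum_stddev = l2_norm_clip * noise_multiplier.
DPMomentumGaussianOptimizer = dp_optimizer.make_gaussian_optimizer_class(
    tf.train.MomentumOptimizer)
opt = DPMomentumGaussianOptimizer(
    l2_norm_clip=1.0, noise_multiplier=1.0,
    num_microbatches=16, learning_rate=0.1, momentum=0.9)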
privacy/optimizers/dp_optimizer_test.py

@@ -23,6 +23,7 @@ import numpy as np
 import tensorflow as tf

 from privacy.optimizers import dp_optimizer
+from privacy.optimizers import gaussian_query


 def loss(val0, val1):
@@ -51,9 +52,11 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])

+      dp_average_query = gaussian_query.GaussianAverageQuery(
+          1.0e9, 0.0, num_microbatches)
+
       opt = cls(
-          l2_norm_clip=1.0e9,
-          noise_multiplier=0.0,
+          dp_average_query,
           num_microbatches=num_microbatches,
           learning_rate=2.0)

@@ -76,11 +79,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([0.0, 0.0])
       data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])

-      opt = cls(
-          l2_norm_clip=1.0,
-          noise_multiplier=0.0,
-          num_microbatches=1,
-          learning_rate=2.0)
+      dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
+
+      opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)

       self.evaluate(tf.global_variables_initializer())
       # Fetch params to validate initial values
@@ -100,11 +101,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([0.0])
       data0 = tf.Variable([[0.0]])

-      opt = cls(
-          l2_norm_clip=4.0,
-          noise_multiplier=2.0,
-          num_microbatches=1,
-          learning_rate=2.0)
+      dp_average_query = gaussian_query.GaussianAverageQuery(4.0, 8.0, 1)
+
+      opt = cls(dp_average_query, num_microbatches=1, learning_rate=2.0)

       self.evaluate(tf.global_variables_initializer())
       # Fetch params to validate initial values
@@ -143,9 +142,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):

       vector_loss = tf.squared_difference(labels, preds)
       scalar_loss = tf.reduce_mean(vector_loss)
+      dp_average_query = gaussian_query.GaussianAverageQuery(1.0, 0.0, 1)
       optimizer = dp_optimizer.DPGradientDescentOptimizer(
-          l2_norm_clip=1.0,
-          noise_multiplier=0.0,
+          dp_average_query,
           num_microbatches=1,
           learning_rate=1.0)
       global_step = tf.train.get_global_step()
@@ -183,9 +182,10 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       var0 = tf.Variable([1.0, 2.0])
       data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])

+      dp_average_query = (
+          gaussian_query.GaussianAverageQuery(1.0e9, 0.0, 4))
       opt = cls(
-          l2_norm_clip=1.0e9,
-          noise_multiplier=0.0,
+          dp_average_query,
           num_microbatches=4,
           learning_rate=2.0,
           unroll_microbatches=True)
@@ -200,6 +200,33 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
       grads_and_vars = sess.run(gradient_op)
       self.assertAllCloseAccordingToType([-2.5, -2.5], grads_and_vars[0][0])

+  @parameterized.named_parameters(
+      ('DPGradientDescent', dp_optimizer.DPGradientDescentGaussianOptimizer),
+      ('DPAdagrad', dp_optimizer.DPAdagradGaussianOptimizer),
+      ('DPAdam', dp_optimizer.DPAdamGaussianOptimizer))
+  def testDPGaussianOptimizerClass(self, cls):
+    with self.cached_session() as sess:
+      var0 = tf.Variable([0.0])
+      data0 = tf.Variable([[0.0]])
+
+      opt = cls(
+          l2_norm_clip=4.0,
+          noise_multiplier=2.0,
+          num_microbatches=1,
+          learning_rate=2.0)
+
+      self.evaluate(tf.global_variables_initializer())
+      # Fetch params to validate initial values
+      self.assertAllClose([0.0], self.evaluate(var0))
+
+      gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
+      grads = []
+      for _ in range(1000):
+        grads_and_vars = sess.run(gradient_op)
+        grads.append(grads_and_vars[0][0])
+
+      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
+      self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
+
 if __name__ == '__main__':
   tf.test.main()
privacy/optimizers/private_queries.py → privacy/optimizers/dp_query.py (renamed)

@@ -22,14 +22,14 @@ from __future__ import print_function
 import abc


-class PrivateQuery(object):
+class DPQuery(object):
   """Interface for differentially private query mechanisms."""

   __metaclass__ = abc.ABCMeta

   @abc.abstractmethod
   def initial_global_state(self):
-    """Returns the initial global state for the PrivateQuery."""
+    """Returns the initial global state for the DPQuery."""
     pass

   @abc.abstractmethod
@@ -72,7 +72,7 @@ class PrivateQuery(object):
     pass

   @abc.abstractmethod
-  def get_query_result(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """Gets query result after all records of sample have been accumulated.

     Args:
@@ -84,47 +84,3 @@ class PrivateQuery(object):
       query and "new_global_state" is the updated global state.
     """
     pass
-
-
-class PrivateSumQuery(PrivateQuery):
-  """Interface for differentially private mechanisms to compute a sum."""
-
-  @abc.abstractmethod
-  def get_noised_sum(self, sample_state, global_state):
-    """Gets estimate of sum after all records of sample have been accumulated.
-
-    Args:
-      sample_state: The sample state after all records have been accumulated.
-      global_state: The global state.
-
-    Returns:
-      A tuple (estimate, new_global_state) where "estimate" is the estimated
-      sum of the records and "new_global_state" is the updated global state.
-    """
-    pass
-
-  def get_query_result(self, sample_state, global_state):
-    """Delegates to get_noised_sum."""
-    return self.get_noised_sum(sample_state, global_state)
-
-
-class PrivateAverageQuery(PrivateQuery):
-  """Interface for differentially private mechanisms to compute an average."""
-
-  @abc.abstractmethod
-  def get_noised_average(self, sample_state, global_state):
-    """Gets average estimate after all records of sample have been accumulated.
-
-    Args:
-      sample_state: The sample state after all records have been accumulated.
-      global_state: The global state.
-
-    Returns:
-      A tuple (estimate, new_global_state) where "estimate" is the estimated
-      average of the records and "new_global_state" is the updated global state.
-    """
-    pass
-
-  def get_query_result(self, sample_state, global_state):
-    """Delegates to get_noised_average."""
-    return self.get_noised_average(sample_state, global_state)
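With the sum and average subclasses gone, a DPQuery implementation now overrides get_noised_result directly. A minimal sketch of a conforming implementation, modeled on the NoPrivacySumQuery further down (no clipping, no noise, no privacy guarantee; for illustration only, not part of this commit):

import tensorflow as tf

from privacy.optimizers import dp_query

nest = tf.contrib.framework.nest


class PlainSumQuery(dp_query.DPQuery):
  """Sums records with no clipping and no noise (illustrative only)."""

  def initial_global_state(self):
    return None  # This query keeps no global state.

  def derive_sample_params(self, global_state):
    return None  # No per-sample parameters needed.

  def initial_sample_state(self, global_state, tensors):
    # Start the accumulator at zero, matching the record structure.
    return nest.map_structure(tf.zeros_like, tensors)

  def accumulate_record(self, params, sample_state, record):
    return nest.map_structure(tf.add, sample_state, record)

  def get_noised_result(self, sample_state, global_state):
    # No noise is added; the "noised" result is the exact sum.
    return sample_state, global_state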
privacy/optimizers/gaussian_query.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Implements PrivateQuery interface for Gaussian average queries.
+"""Implements DPQuery interface for Gaussian average queries.
 """

 from __future__ import absolute_import
@@ -23,13 +23,13 @@ import collections

 import tensorflow as tf

-from privacy.optimizers import private_queries
+from privacy.optimizers import dp_query

 nest = tf.contrib.framework.nest


-class GaussianSumQuery(private_queries.PrivateSumQuery):
-  """Implements PrivateQuery interface for Gaussian sum queries.
+class GaussianSumQuery(dp_query.DPQuery):
+  """Implements DPQuery interface for Gaussian sum queries.

   Accumulates clipped vectors, then adds Gaussian noise to the sum.
   """
@@ -94,7 +94,7 @@ class GaussianSumQuery(private_queries.PrivateSumQuery):
     clipped = nest.pack_sequence_as(record, clipped_as_list)
     return nest.map_structure(tf.add, sample_state, clipped)

-  def get_noised_sum(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """Gets noised sum after all records of sample have been accumulated.

     Args:
@@ -111,10 +111,15 @@ class GaussianSumQuery(private_queries.PrivateSumQuery):
     return nest.map_structure(add_noise, sample_state), global_state


-class GaussianAverageQuery(private_queries.PrivateAverageQuery):
-  """Implements PrivateQuery interface for Gaussian average queries.
+class GaussianAverageQuery(dp_query.DPQuery):
+  """Implements DPQuery interface for Gaussian average queries.

   Accumulates clipped vectors, adds Gaussian noise, and normalizes.
+
+  Note that we use "fixed-denominator" estimation: the denominator should be
+  specified as the expected number of records per sample. Accumulating the
+  denominator separately would also be possible but would produce a higher
+  variance estimator.
   """

   # pylint: disable=invalid-name
@@ -177,7 +182,7 @@ class GaussianAverageQuery(private_queries.PrivateAverageQuery):
     """
     return self._numerator.accumulate_record(params, sample_state, record)

-  def get_noised_average(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """Gets noised average after all records of sample have been accumulated.

     Args:
@@ -188,7 +193,7 @@ class GaussianAverageQuery(private_queries.PrivateAverageQuery):
       A tuple (estimate, new_global_state) where "estimate" is the estimated
       average of the records and "new_global_state" is the updated global state.
     """
-    noised_sum, new_sum_global_state = self._numerator.get_noised_sum(
+    noised_sum, new_sum_global_state = self._numerator.get_noised_result(
         sample_state, global_state.sum_state)
     new_global_state = self._GlobalState(
         new_sum_global_state, global_state.denominator)
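Item 5 of the commit message also adds gaussian_sum_query_from_noise_multplier and gaussian_average_query_from_noise_multplier, whose bodies are not visible in this excerpt (and whose committed spelling may or may not include the missing "i"). Given that make_gaussian_optimizer_class derives the noise scale as l2_norm_clip * noise_multiplier, they presumably resemble the following sketch; both the names and signatures here are assumptions, not taken from the diff:

from privacy.optimizers import gaussian_query


def gaussian_sum_query_from_noise_multiplier(l2_norm_clip, noise_multiplier):
  # Assumed helper: stddev is derived from the clip norm and multiplier.
  return gaussian_query.GaussianSumQuery(
      l2_norm_clip, l2_norm_clip * noise_multiplier)


def gaussian_average_query_from_noise_multiplier(
    l2_norm_clip, noise_multiplier, denominator):
  # Assumed helper: same convention for the average query's sum_stddev.
  return gaussian_query.GaussianAverageQuery(
      l2_norm_clip, l2_norm_clip * noise_multiplier, denominator)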
privacy/optimizers/gaussian_query_test.py

@@ -23,25 +23,7 @@ import numpy as np
 import tensorflow as tf

 from privacy.optimizers import gaussian_query
+from privacy.optimizers import test_utils


-def _run_query(query, records):
-  """Executes query on the given set of records as a single sample.
-
-  Args:
-    query: A PrivateQuery to run.
-    records: An iterable containing records to pass to the query.
-
-  Returns:
-    The result of the query.
-  """
-  global_state = query.initial_global_state()
-  params = query.derive_sample_params(global_state)
-  sample_state = query.initial_sample_state(global_state, next(iter(records)))
-  for record in records:
-    sample_state = query.accumulate_record(params, sample_state, record)
-  result, _ = query.get_query_result(sample_state, global_state)
-  return result
-
-
 class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
@@ -53,7 +35,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):

       query = gaussian_query.GaussianSumQuery(
           l2_norm_clip=10.0, stddev=0.0)
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [1.0, 1.0]
       self.assertAllClose(result, expected)
@@ -65,7 +47,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):

       query = gaussian_query.GaussianSumQuery(
           l2_norm_clip=5.0, stddev=0.0)
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [1.0, 1.0]
       self.assertAllClose(result, expected)
@@ -77,7 +59,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):

       query = gaussian_query.GaussianSumQuery(
           l2_norm_clip=5.0, stddev=stddev)
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])

       noised_sums = []
       for _ in xrange(1000):
@@ -93,7 +75,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):

       query = gaussian_query.GaussianAverageQuery(
           l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected_average = [1.0, 1.0]
       self.assertAllClose(result, expected_average)
@@ -106,7 +88,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):

       query = gaussian_query.GaussianAverageQuery(
           l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])

       noised_averages = []
       for _ in range(1000):
@@ -123,7 +105,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
   def test_incompatible_records(self, record1, record2, error_type):
     query = gaussian_query.GaussianSumQuery(1.0, 0.0)
     with self.assertRaises(error_type):
-      _run_query(query, [record1, record2])
+      test_utils.run_query(query, [record1, record2])


 if __name__ == '__main__':
privacy/optimizers/nested_query.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Implements PrivateQuery interface for queries over nested structures.
+"""Implements DPQuery interface for queries over nested structures.
 """

 from __future__ import absolute_import
@@ -22,13 +22,13 @@ from __future__ import print_function

 import tensorflow as tf

-from privacy.optimizers import private_queries
+from privacy.optimizers import dp_query

 nest = tf.contrib.framework.nest


-class NestedQuery(private_queries.PrivateQuery):
-  """Implements PrivateQuery interface for structured queries.
+class NestedQuery(dp_query.DPQuery):
+  """Implements DPQuery interface for structured queries.

   NestedQuery evaluates arbitrary nested structures of queries. Records must be
   nested structures of tensors that are compatible (in type and arity) with the
@@ -100,7 +100,7 @@ class NestedQuery(private_queries.PrivateQuery):
     return self._map_to_queries(
         'accumulate_record', params, sample_state, record)

-  def get_query_result(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """Gets query result after all records of sample have been accumulated.

     Args:
@@ -114,7 +114,7 @@ class NestedQuery(private_queries.PrivateQuery):
       for the subqueries.
     """
     estimates_and_new_global_states = self._map_to_queries(
-        'get_query_result', sample_state, global_state)
+        'get_noised_result', sample_state, global_state)

     flat_estimates, flat_new_global_states = zip(
         *nest.flatten_up_to(self._queries, estimates_and_new_global_states))
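NestedQuery applies a structure of subqueries to identically structured records, with each leaf query clipping, accumulating, and noising its own leaf. A sketch mirroring the record values in nested_query_test.py below; the clip value of 10.0 is an assumption chosen so no clipping occurs:

import tensorflow as tf

from privacy.optimizers import gaussian_query
from privacy.optimizers import nested_query
from privacy.optimizers import test_utils

# One GaussianSumQuery per leaf of the record structure.
query = nested_query.NestedQuery(
    [gaussian_query.GaussianSumQuery(10.0, 0.0),
     [gaussian_query.GaussianSumQuery(10.0, 0.0),
      gaussian_query.GaussianSumQuery(10.0, 0.0)]])

record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]

# With a loose clip and no noise this reduces to a plain per-leaf sum,
# as in test_nested_gaussian_sum_no_clip_no_noise: 1+4 = 5, and so on.
query_result = test_utils.run_query(query, [record1, record2])

with tf.Session() as sess:
  print(sess.run(query_result))  # [5.0, [5.0, 5.0]]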
privacy/optimizers/nested_query_test.py

@@ -25,31 +25,13 @@ import tensorflow as tf

 from privacy.optimizers import gaussian_query
 from privacy.optimizers import nested_query
+from privacy.optimizers import test_utils

 nest = tf.contrib.framework.nest

 _basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)


-def _run_query(query, records):
-  """Executes query on the given set of records as a single sample.
-
-  Args:
-    query: A PrivateQuery to run.
-    records: An iterable containing records to pass to the query.
-
-  Returns:
-    The result of the query.
-  """
-  global_state = query.initial_global_state()
-  params = query.derive_sample_params(global_state)
-  sample_state = query.initial_sample_state(global_state, next(iter(records)))
-  for record in records:
-    sample_state = query.accumulate_record(params, sample_state, record)
-  result, _ = query.get_query_result(sample_state, global_state)
-  return result
-
-
 class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):

   def test_nested_gaussian_sum_no_clip_no_noise(self):
@@ -64,7 +46,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
       record1 = [1.0, [2.0, 3.0]]
       record2 = [4.0, [3.0, 2.0]]

-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [5.0, [5.0, 5.0]]
       self.assertAllClose(result, expected)
@@ -81,7 +63,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
       record1 = [1.0, [2.0, 3.0]]
       record2 = [4.0, [3.0, 2.0]]

-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [1.0, [1.0, 1.0]]
       self.assertAllClose(result, expected)
@@ -98,7 +80,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
       record1 = [1.0, [12.0, 9.0]]  # Clipped to [1.0, [4.0, 3.0]]
       record2 = [5.0, [1.0, 2.0]]   # Clipped to [4.0, [1.0, 2.0]]

-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [1.0, [1.0, 1.0]]
       self.assertAllClose(result, expected)
@@ -118,7 +100,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
       record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
       record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]

-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
       self.assertAllClose(result, expected)
@@ -137,7 +119,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
       record1 = (3.0, [2.0, 1.5])
       record2 = (0.0, [-1.0, -3.5])

-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])

       noised_averages = []
       for _ in range(1000):
@@ -157,7 +139,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
   def test_record_incompatible_with_query(
       self, queries, record, error_type):
     with self.assertRaises(error_type):
-      _run_query(nested_query.NestedQuery(queries), [record])
+      test_utils.run_query(nested_query.NestedQuery(queries), [record])


 if __name__ == '__main__':
privacy/optimizers/no_privacy_query.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Implements PrivateQuery interface for no privacy average queries."""
+"""Implements DPQuery interface for no privacy average queries."""

 from __future__ import absolute_import
 from __future__ import division
@@ -19,13 +19,13 @@ from __future__ import print_function

 import tensorflow as tf

-from privacy.optimizers import private_queries
+from privacy.optimizers import dp_query

 nest = tf.contrib.framework.nest


-class NoPrivacySumQuery(private_queries.PrivateSumQuery):
-  """Implements PrivateQuery interface for a sum query with no privacy.
+class NoPrivacySumQuery(dp_query.DPQuery):
+  """Implements DPQuery interface for a sum query with no privacy.

   Accumulates vectors without clipping or adding noise.
   """
@@ -53,13 +53,13 @@ class NoPrivacySumQuery(private_queries.PrivateSumQuery):

     return nest.map_structure(add_weighted, sample_state, record)

-  def get_noised_sum(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """See base class."""
     return sample_state, global_state


-class NoPrivacyAverageQuery(private_queries.PrivateAverageQuery):
-  """Implements PrivateQuery interface for an average query with no privacy.
+class NoPrivacyAverageQuery(dp_query.DPQuery):
+  """Implements DPQuery interface for an average query with no privacy.

   Accumulates vectors and normalizes by the total number of accumulated vectors.
   """
@@ -89,10 +89,10 @@ class NoPrivacyAverageQuery(private_queries.PrivateAverageQuery):
             params, sum_sample_state, record, weight),
         tf.add(denominator, weight))

-  def get_noised_average(self, sample_state, global_state):
+  def get_noised_result(self, sample_state, global_state):
     """See base class."""
     sum_sample_state, denominator = sample_state
-    exact_sum, new_global_state = self._numerator.get_noised_sum(
+    exact_sum, new_global_state = self._numerator.get_noised_result(
         sum_sample_state, global_state)

     def normalize(v):
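The no-privacy queries accept an optional per-record weight, and NoPrivacyAverageQuery accumulates the total weight as its denominator. A sketch of weighted averaging driven by the shared test helper; the record values here are illustrative, not from the tests below:

import tensorflow as tf

from privacy.optimizers import no_privacy_query
from privacy.optimizers import test_utils

records = [tf.constant([1.0, 2.0]), tf.constant([-1.0, 2.0])]

# Weighted mean: (1*r1 + 3*r2) / (1 + 3) = [-0.5, 2.0] for these records.
query = no_privacy_query.NoPrivacyAverageQuery()
query_result = test_utils.run_query(query, records, weights=[1, 3])

with tf.Session() as sess:
  print(sess.run(query_result))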
privacy/optimizers/no_privacy_query_test.py

@@ -22,31 +22,7 @@ from absl.testing import parameterized
 import tensorflow as tf

 from privacy.optimizers import no_privacy_query
+from privacy.optimizers import test_utils


-def _run_query(query, records, weights=None):
-  """Executes query on the given set of records as a single sample.
-
-  Args:
-    query: A PrivateQuery to run.
-    records: An iterable containing records to pass to the query.
-    weights: An optional iterable containing the weights of the records.
-
-  Returns:
-    The result of the query.
-  """
-  global_state = query.initial_global_state()
-  params = query.derive_sample_params(global_state)
-  sample_state = query.initial_sample_state(global_state, next(iter(records)))
-  if weights is None:
-    for record in records:
-      sample_state = query.accumulate_record(params, sample_state, record)
-  else:
-    for weight, record in zip(weights, records):
-      sample_state = query.accumulate_record(params, sample_state, record,
-                                             weight)
-  result, _ = query.get_query_result(sample_state, global_state)
-  return result
-
-
 class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
@@ -57,7 +33,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
       record2 = tf.constant([-1.0, 1.0])

       query = no_privacy_query.NoPrivacySumQuery()
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [1.0, 1.0]
       self.assertAllClose(result, expected)
@@ -71,7 +47,8 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
       weight2 = 2

       query = no_privacy_query.NoPrivacySumQuery()
-      query_result = _run_query(query, [record1, record2], [weight1, weight2])
+      query_result = test_utils.run_query(
+          query, [record1, record2], [weight1, weight2])
       result = sess.run(query_result)
       expected = [0.0, 2.0]
       self.assertAllClose(result, expected)
@@ -82,7 +59,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
       record2 = tf.constant([-1.0, 2.0])

       query = no_privacy_query.NoPrivacyAverageQuery()
-      query_result = _run_query(query, [record1, record2])
+      query_result = test_utils.run_query(query, [record1, record2])
       result = sess.run(query_result)
       expected = [2.0, 1.0]
       self.assertAllClose(result, expected)
@@ -96,7 +73,8 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
       weight2 = 3

       query = no_privacy_query.NoPrivacyAverageQuery()
-      query_result = _run_query(query, [record1, record2], [weight1, weight2])
+      query_result = test_utils.run_query(
+          query, [record1, record2], [weight1, weight2])
       result = sess.run(query_result)
       expected = [0.25, 0.75]
       self.assertAllClose(result, expected)
@@ -108,7 +86,7 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
   def test_incompatible_records(self, record1, record2, error_type):
     query = no_privacy_query.NoPrivacySumQuery()
     with self.assertRaises(error_type):
-      _run_query(query, [record1, record2])
+      test_utils.run_query(query, [record1, record2])


 if __name__ == '__main__':
privacy/optimizers/test_utils.py (new file, 46 lines)

@@ -0,0 +1,46 @@
+# Copyright 2019, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility methods for testing private queries.
+
+Utility methods for testing private queries.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def run_query(query, records, weights=None):
+  """Executes query on the given set of records as a single sample.
+
+  Args:
+    query: A PrivateQuery to run.
+    records: An iterable containing records to pass to the query.
+    weights: An optional iterable containing the weights of the records.
+
+  Returns:
+    The result of the query.
+  """
+  global_state = query.initial_global_state()
+  params = query.derive_sample_params(global_state)
+  sample_state = query.initial_sample_state(global_state, next(iter(records)))
+  if weights is None:
+    for record in records:
+      sample_state = query.accumulate_record(params, sample_state, record)
+  else:
+    for weight, record in zip(weights, records):
+      sample_state = query.accumulate_record(
+          params, sample_state, record, weight)
+  result, _ = query.get_noised_result(sample_state, global_state)
+  return result
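The consolidated helper is used the same way in all three test files above. A short usage sketch, with record values chosen (as an assumption, not from the diff) so the unclipped sum is easy to check:

import tensorflow as tf

from privacy.optimizers import gaussian_query
from privacy.optimizers import test_utils

record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])

# One sample of two records, L2-clipped at 10.0 (a no-op here), no noise.
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
query_result = test_utils.run_query(query, [record1, record2])

with tf.Session() as sess:
  print(sess.run(query_result))  # [1.0, 1.0]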
tutorials/mnist_dpsgd_tutorial.py

@@ -70,15 +70,13 @@ def cnn_model_fn(features, labels, mode):

   if FLAGS.dpsgd:
     # Use DP version of GradientDescentOptimizer. For illustration purposes,
-    # we do that here by calling make_optimizer_class() explicitly, though DP
+    # we do that here by calling optimizer_from_args() explicitly, though DP
     # versions of standard optimizers are available in dp_optimizer.
-    dp_optimizer_class = dp_optimizer.make_optimizer_class(
-        tf.train.GradientDescentOptimizer)
-    optimizer = dp_optimizer_class(
-        learning_rate=FLAGS.learning_rate,
-        noise_multiplier=FLAGS.noise_multiplier,
-        l2_norm_clip=FLAGS.l2_norm_clip,
-        num_microbatches=FLAGS.microbatches)
+    optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
+        l2_norm_clip=FLAGS.l2_norm_clip,
+        noise_multiplier=FLAGS.noise_multiplier,
+        num_microbatches=FLAGS.microbatches,
+        learning_rate=FLAGS.learning_rate)
     opt_loss = vector_loss
   else:
     optimizer = tf.train.GradientDescentOptimizer(