diff --git a/privacy/optimizers/no_privacy_query.py b/privacy/optimizers/no_privacy_query.py index f15c3c3..551956e 100644 --- a/privacy/optimizers/no_privacy_query.py +++ b/privacy/optimizers/no_privacy_query.py @@ -44,10 +44,14 @@ class NoPrivacySumQuery(private_queries.PrivateSumQuery): del global_state # unused. return nest.map_structure(tf.zeros_like, tensors) - def accumulate_record(self, params, sample_state, record): - """See base class.""" + def accumulate_record(self, params, sample_state, record, weight=1): + """See base class. Optional argument for weighted sum queries.""" del params # unused. - return nest.map_structure(tf.add, sample_state, record) + + def add_weighted(state_tensor, record_tensor): + return tf.add(state_tensor, weight * record_tensor) + + return nest.map_structure(add_weighted, sample_state, record) def get_noised_sum(self, sample_state, global_state): """See base class.""" @@ -77,11 +81,13 @@ class NoPrivacyAverageQuery(private_queries.PrivateAverageQuery): """See base class.""" return self._numerator.initial_sample_state(global_state, tensors), 0.0 - def accumulate_record(self, params, sample_state, record): - """See base class.""" + def accumulate_record(self, params, sample_state, record, weight=1): + """See base class. Optional argument for weighted average queries.""" sum_sample_state, denominator = sample_state - return self._numerator.accumulate_record(params, sum_sample_state, - record), tf.add(denominator, 1.0) + return ( + self._numerator.accumulate_record( + params, sum_sample_state, record, weight), + tf.add(denominator, weight)) def get_noised_average(self, sample_state, global_state): """See base class.""" diff --git a/privacy/optimizers/no_privacy_query_test.py b/privacy/optimizers/no_privacy_query_test.py index 5c39eba..136894e 100644 --- a/privacy/optimizers/no_privacy_query_test.py +++ b/privacy/optimizers/no_privacy_query_test.py @@ -23,18 +23,14 @@ import tensorflow as tf from privacy.optimizers import no_privacy_query -try: - xrange -except NameError: - xrange = range - -def _run_query(query, records): +def _run_query(query, records, weights=None): """Executes query on the given set of records as a single sample. Args: query: A PrivateQuery to run. records: An iterable containing records to pass to the query. + weights: An optional iterable containing the weights of the records. Returns: The result of the query. @@ -42,8 +38,13 @@ def _run_query(query, records): global_state = query.initial_global_state() params = query.derive_sample_params(global_state) sample_state = query.initial_sample_state(global_state, next(iter(records))) - for record in records: - sample_state = query.accumulate_record(params, sample_state, record) + if weights is None: + for record in records: + sample_state = query.accumulate_record(params, sample_state, record) + else: + for weight, record in zip(weights, records): + sample_state = query.accumulate_record(params, sample_state, record, + weight) result, _ = query.get_query_result(sample_state, global_state) return result @@ -61,6 +62,20 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase): expected = [1.0, 1.0] self.assertAllClose(result, expected) + def test_no_privacy_weighted_sum(self): + with self.cached_session() as sess: + record1 = tf.constant([2.0, 0.0]) + record2 = tf.constant([-1.0, 1.0]) + + weight1 = 1 + weight2 = 2 + + query = no_privacy_query.NoPrivacySumQuery() + query_result = _run_query(query, [record1, record2], [weight1, weight2]) + result = sess.run(query_result) + expected = [0.0, 2.0] + self.assertAllClose(result, expected) + def test_no_privacy_average(self): with self.cached_session() as sess: record1 = tf.constant([5.0, 0.0]) @@ -69,8 +84,22 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase): query = no_privacy_query.NoPrivacyAverageQuery() query_result = _run_query(query, [record1, record2]) result = sess.run(query_result) - expected_average = [2.0, 1.0] - self.assertAllClose(result, expected_average) + expected = [2.0, 1.0] + self.assertAllClose(result, expected) + + def test_no_privacy_weighted_average(self): + with self.cached_session() as sess: + record1 = tf.constant([4.0, 0.0]) + record2 = tf.constant([-1.0, 1.0]) + + weight1 = 1 + weight2 = 3 + + query = no_privacy_query.NoPrivacyAverageQuery() + query_result = _run_query(query, [record1, record2], [weight1, weight2]) + result = sess.run(query_result) + expected = [0.25, 0.75] + self.assertAllClose(result, expected) @parameterized.named_parameters( ('type_mismatch', [1.0], (1.0,), TypeError),