Remove PrivacyLedger which will soon be replaced by DpEvent and PrivacyAccountant.

PiperOrigin-RevId: 393147667
Galen Andrew 2021-08-26 09:59:45 -07:00 committed by A. Unique TensorFlower
parent 0e04e1baeb
commit d9236d5619
19 changed files with 172 additions and 776 deletions
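For orientation, the replacement accounting flow represents each mechanism application as a DP event and composes events with an accountant, instead of recording queries to a ledger. The sketch below is illustrative only and is written against the later `dp_accounting`-style API; module paths and class names at the time of this commit may differ.

```
# Illustrative sketch, assuming the dp_accounting-style DpEvent /
# PrivacyAccountant API that replaces the ledger; names may differ by version.
import dp_accounting

# One DP-SGD step: Poisson subsampling composed with the Gaussian mechanism.
event = dp_accounting.PoissonSampledDpEvent(
    sampling_probability=0.01,
    event=dp_accounting.GaussianDpEvent(noise_multiplier=1.1))

accountant = dp_accounting.rdp.RdpAccountant()
accountant.compose(event, count=10000)  # 10000 identical steps
epsilon = accountant.get_epsilon(target_delta=1e-5)
```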

View file

@ -31,10 +31,6 @@ else:
# Analysis
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
from tensorflow_privacy.privacy.analysis.privacy_ledger import GaussianSumQueryEntry
from tensorflow_privacy.privacy.analysis.privacy_ledger import PrivacyLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import QueryWithLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import SampleEntry
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogenous_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger

View file

@ -1,299 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PrivacyLedger class for keeping a record of private queries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
from tensorflow_privacy.privacy.dp_query import dp_query
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
def format_ledger(sample_array, query_array):
"""Converts array representation into a list of SampleEntries."""
samples = []
query_pos = 0
sample_pos = 0
for sample in sample_array:
population_size, selection_probability, num_queries = sample
queries = []
for _ in range(int(num_queries)):
query = query_array[query_pos]
assert int(query[0]) == sample_pos
queries.append(GaussianSumQueryEntry(*query[1:]))
query_pos += 1
samples.append(SampleEntry(population_size, selection_probability, queries))
sample_pos += 1
return samples
class PrivacyLedger(object):
"""Class for keeping a record of private queries.
The PrivacyLedger keeps a record of all queries executed over a given dataset
for the purpose of computing privacy guarantees. To use it, it must be
associated with a `DPQuery` object via a `QueryWithLedger`.
The current implementation works only with DPQueries that compose the Gaussian
sum mechanism with Poisson subsampling.
Example usage:
```
import tensorflow_privacy as tfp
dp_query = tfp.QueryWithLedger(
tfp.GaussianSumQuery(
l2_norm_clip=1.0, stddev=1.0),
population_size=10000,
selection_probability=0.01)
# Use dp_query here in training loop.
formatted_ledger = dp_query.ledger.get_formatted_ledger_eager()
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
list(range(5, 64)) + [128, 256, 512])
total_rdp = tfp.compute_rdp_from_ledger(formatted_ledger, orders)
epsilon = tfp.get_privacy_spent(orders, total_rdp, target_delta=1e-5)[0]
```
"""
def __init__(self,
population_size,
selection_probability):
"""Initializes the PrivacyLedger.
Args:
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch.
selection_probability: A floating point value (may be variable) specifying
the probability each record is included in a sample.
Raises:
ValueError: If `selection_probability` is 0.
"""
self._population_size = population_size
self._selection_probability = selection_probability
if tf.executing_eagerly():
if tf.equal(selection_probability, 0):
raise ValueError('Selection probability cannot be 0.')
init_capacity = tf.cast(tf.math.ceil(1 / selection_probability), tf.int32)
else:
if selection_probability == 0:
raise ValueError('Selection probability cannot be 0.')
init_capacity = np.int(np.ceil(1 / selection_probability))
# The query buffer stores rows corresponding to GaussianSumQueryEntries.
self._query_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'query')
self._sample_var = tf.Variable(
initial_value=tf.zeros([3]), trainable=False, name='sample')
# The sample buffer stores rows corresponding to SampleEntries.
self._sample_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'sample')
self._sample_count = tf.Variable(
initial_value=0.0, trainable=False, name='sample_count')
self._query_count = tf.Variable(
initial_value=0.0, trainable=False, name='query_count')
self._cs = tf.CriticalSection()
def record_sum_query(self, l2_norm_bound, noise_stddev):
"""Records that a query was issued.
Args:
l2_norm_bound: The maximum l2 norm of the tensor group in the query.
noise_stddev: The standard deviation of the noise applied to the sum.
Returns:
An operation recording the sum query to the ledger. This should be called
for every Gaussian sum query that is issued on a sample.
"""
def _do_record_query():
with tf.control_dependencies(
[tf.assign(self._query_count, self._query_count + 1)]):
return self._query_buffer.append(
[self._sample_count, l2_norm_bound, noise_stddev])
return self._cs.execute(_do_record_query)
def finalize_sample(self):
"""Finalizes sample and records sample ledger entry.
This should be called once per application of the mechanism on a sample,
after all sum queries have been recorded.
Returns:
An operation recording the complete mechanism (sampling and sum
estimation) to the ledger.
"""
with tf.control_dependencies([
tf.assign(self._sample_var, [
self._population_size, self._selection_probability,
self._query_count
])
]):
with tf.control_dependencies([
tf.assign(self._sample_count, self._sample_count + 1),
tf.assign(self._query_count, 0)
]):
return self._sample_buffer.append(self._sample_var)
def get_unformatted_ledger(self):
"""Returns the raw sample and query values."""
return self._sample_buffer.values, self._query_buffer.values
def get_formatted_ledger(self, sess):
"""Gets the formatted query ledger.
Args:
sess: The tensorflow session in which the ledger was created.
Returns:
The query ledger as a list of `SampleEntry` instances.
"""
sample_array = sess.run(self._sample_buffer.values)
query_array = sess.run(self._query_buffer.values)
return format_ledger(sample_array, query_array)
def get_formatted_ledger_eager(self):
"""Gets the formatted query ledger.
Returns:
The query ledger as a list of `SampleEntry` instances.
"""
sample_array = self._sample_buffer.values.numpy()
query_array = self._query_buffer.values.numpy()
return format_ledger(sample_array, query_array)
class QueryWithLedger(dp_query.DPQuery):
"""A class for DP queries that record events to a `PrivacyLedger`.
`QueryWithLedger` should be the top-level query in a structure of queries that
may include sum queries, nested queries, etc. It should simply wrap another
query and contain a reference to the ledger. Any contained queries (including
those contained in the leaves of a nested query) should also contain a
reference to the same ledger object.
Only composed Gaussian sum queries with Poisson subsampling are supported.
This includes `GaussianSumQuery`, `QuantileEstimatorQuery`, and
`QuantileAdaptiveClipSumQuery`, as well as `NestedQuery` or `NormalizedQuery`
objects that contain the previously mentioned query types.
"""
def __init__(self, query,
population_size=None, selection_probability=None,
ledger=None):
"""Initializes the `QueryWithLedger`.
Args:
query: The query whose events should be recorded to the ledger. Any
subqueries (including those in the leaves of a nested query) should also
contain a reference to the same ledger given here.
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch. May be
`None` if `ledger` is specified.
selection_probability: A floating point value (may be variable) specifying
the probability each record is included in a sample under Poisson
subsampling. May be `None` if `ledger` is specified.
ledger: A `PrivacyLedger` to use. Must be specified if either of
`population_size` or `selection_probability` is `None`.
"""
self._query = query
if population_size is not None and selection_probability is not None:
self.set_ledger(PrivacyLedger(population_size, selection_probability))
elif ledger is not None:
self.set_ledger(ledger)
else:
raise ValueError('One of (population_size, selection_probability) or '
'ledger must be specified.')
@property
def ledger(self):
"""Gets the ledger that all inner queries record to."""
return self._ledger
def set_ledger(self, ledger):
"""Sets a new ledger."""
self._ledger = ledger
self._query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._query.initial_global_state()
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return self._query.derive_sample_params(global_state)
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self._query.initial_sample_state(template)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._query.preprocess_record(params, record)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
return self._query.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return self._query.merge_sample_states(sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.
Besides noising and returning the result of the inner query, ensures that
the sample is recorded to the ledger.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state, storing long-term privacy bookkeeping.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state.
"""
# Ensure sample_state is fully aggregated before calling get_noised_result.
with tf.control_dependencies(tf.nest.flatten(sample_state)):
result, new_global_state = self._query.get_noised_result(
sample_state, global_state)
# Ensure inner queries have recorded before finalizing.
with tf.control_dependencies(tf.nest.flatten(result)):
finalize = self._ledger.finalize_sample()
# Ensure finalizing happens.
with tf.control_dependencies([finalize]):
return tf.nest.map_structure(tf.identity, result), new_global_state
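To complement the eager example in the class docstring above, here is a minimal graph-mode sketch of driving the removed classes; the training-graph wiring is elided and hypothetical.

```
# Minimal graph-mode sketch; relies only on the QueryWithLedger / PrivacyLedger
# API shown above, with the training-graph wiring elided.
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query

query = privacy_ledger.QueryWithLedger(
    gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0),
    population_size=10000,
    selection_probability=0.01)

# ... build a training graph that runs `query` on each sample ...

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # ... run training steps ...
  formatted_ledger = query.ledger.get_formatted_ledger(sess)  # list of SampleEntry
```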

View file

@ -1,133 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PrivacyLedger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
class PrivacyLedgerTest(tf.test.TestCase):
def test_fail_on_probability_zero(self):
with self.assertRaisesRegexp(ValueError,
'Selection probability cannot be 0.'):
privacy_ledger.PrivacyLedger(10, 0)
def test_basic(self):
ledger = privacy_ledger.PrivacyLedger(10, 0.1)
ledger.record_sum_query(5.0, 1.0)
ledger.record_sum_query(2.0, 0.5)
ledger.finalize_sample()
expected_queries = [[5.0, 1.0], [2.0, 0.5]]
formatted = ledger.get_formatted_ledger_eager()
sample = formatted[0]
self.assertAllClose(sample.population_size, 10.0)
self.assertAllClose(sample.selection_probability, 0.1)
self.assertAllClose(sorted(sample.queries), sorted(expected_queries))
def test_sum_query(self):
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries)
def test_nested_query(self):
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=2.0)
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=1.0)
query = nested_query.NestedQuery([query1, query2])
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
record1 = [1.0, [12.0, 9.0]]
record2 = [5.0, [1.0, 2.0]]
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[4.0, 2.0], [5.0, 1.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries))
if __name__ == '__main__':
tf.test.main()

View file

@ -47,7 +47,6 @@ import numpy as np
from scipy import special
import six
########################
# LOG-SPACE ARITHMETIC #
########################
@ -102,8 +101,8 @@ def _log_print(logx):
def _log_comb(n, k):
return (special.gammaln(n + 1) -
special.gammaln(k + 1) - special.gammaln(n - k + 1))
return (special.gammaln(n + 1) - special.gammaln(k + 1) -
special.gammaln(n - k + 1))
def _compute_log_a_int(q, sigma, alpha):
@ -215,17 +214,19 @@ def _compute_delta(orders, rdp, eps):
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
logdeltas = [] # work in log space to avoid overflows
for (a, r) in zip(orders_vec, rdp_vec):
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
if r < 0: raise ValueError("Renyi divergence must be >=0.")
if a < 1:
raise ValueError("Renyi divergence order must be >=1.")
if r < 0:
raise ValueError("Renyi divergence must be >=0.")
# For small alpha, we are better off with the bound via KL divergence:
# delta <= sqrt(1-exp(-KL)).
# Take a min of the two bounds.
logdelta = 0.5*math.log1p(-math.exp(-r))
logdelta = 0.5 * math.log1p(-math.exp(-r))
if a > 1.01:
# This bound is not numerically stable as alpha->1.
# Thus we have a min value for alpha.
# The bound is also not useful for small alpha, so doesn't matter.
rdp_bound = (a - 1) * (r - eps + math.log1p(-1/a)) - math.log(a)
rdp_bound = (a - 1) * (r - eps + math.log1p(-1 / a)) - math.log(a)
logdelta = min(logdelta, rdp_bound)
logdeltas.append(logdelta)
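Written out, the two per-order bounds combined in the loop above are, for order α, RDP value r, and target ε:

```
% Per-order bound used above, in log space (LaTeX transcription of the code):
\log\delta(\alpha) \le \min\left(
    \tfrac{1}{2}\log\left(1 - e^{-r}\right),\;
    (\alpha - 1)\left(r - \varepsilon + \log\left(1 - \tfrac{1}{\alpha}\right)\right) - \log\alpha
\right)
```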
@ -264,8 +265,10 @@ def _compute_eps(orders, rdp, delta):
# Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
eps_vec = []
for (a, r) in zip(orders_vec, rdp_vec):
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
if r < 0: raise ValueError("Renyi divergence must be >=0.")
if a < 1:
raise ValueError("Renyi divergence order must be >=1.")
if r < 0:
raise ValueError("Renyi divergence must be >=0.")
if delta**2 + math.expm1(-r) >= 0:
# In this case, we can simply bound via KL divergence:
@ -378,7 +381,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
Args:
q: The sampling rate.
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
to the l2-sensitivity of the function to which it is added.
to the l2-sensitivity of the function to which it is added.
steps: The number of steps.
orders: An array (or a scalar) of RDP orders.
@ -388,8 +391,8 @@ def compute_rdp(q, noise_multiplier, steps, orders):
if np.isscalar(orders):
rdp = _compute_rdp(q, noise_multiplier, orders)
else:
rdp = np.array([_compute_rdp(q, noise_multiplier, order)
for order in orders])
rdp = np.array(
[_compute_rdp(q, noise_multiplier, order) for order in orders])
return rdp * steps
@ -572,8 +575,8 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
target_eps: If not `None`, the epsilon for which we compute the
corresponding delta.
target_delta: If not `None`, the delta for which we compute the
corresponding epsilon. Exactly one of `target_eps` and `target_delta`
must be `None`.
corresponding epsilon. Exactly one of `target_eps` and `target_delta` must
be `None`.
Returns:
A tuple of epsilon, delta, and the optimal order.
@ -595,24 +598,3 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
else:
eps, opt_order = _compute_eps(orders, rdp, target_delta)
return eps, target_delta, opt_order
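For reference, the two public entry points touched above compose as follows; the parameter values are illustrative.

```
# Illustrative call pattern for the functions above; parameter values are
# made up for the example.
from tensorflow_privacy.privacy.analysis import rdp_accountant

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = rdp_accountant.compute_rdp(
    q=0.01, noise_multiplier=1.1, steps=1000, orders=orders)
eps, delta, opt_order = rdp_accountant.get_privacy_spent(
    orders, rdp, target_delta=1e-5)
```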
def compute_rdp_from_ledger(ledger, orders):
"""Computes RDP of Sampled Gaussian Mechanism from ledger.
Args:
ledger: A formatted privacy ledger.
orders: An array (or a scalar) of RDP orders.
Returns:
RDP at all orders. Can be `np.inf`.
"""
total_rdp = np.zeros_like(orders, dtype=float)
for sample in ledger:
# Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
# See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
effective_z = sum([
(q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5
total_rdp += compute_rdp(
sample.selection_probability, effective_z, 1, orders)
return total_rdp
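The effective noise multiplier computed in the loop above can also be written standalone; for a single-query sample it reduces to noise_stddev / l2_norm_bound, which is what the test below checks against a direct compute_rdp call.

```
# Standalone form of the effective-z computation above; the
# (l2_norm_bound, noise_stddev) pairs are illustrative.
queries = [(3.14159, 2.71828)]
effective_z = sum((stddev / bound) ** -2 for bound, stddev in queries) ** -0.5
# With one query this is just stddev / bound.
```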

View file

@ -31,7 +31,6 @@ from mpmath import quad
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis import rdp_accountant
@ -121,16 +120,47 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
[6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, np.inf],
rtol=1e-4)
params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
{'q': 1e-6, 'sigma': .1, 'order': 256},
{'q': 1e-5, 'sigma': .1, 'order': 256.1},
{'q': 1e-6, 'sigma': 1, 'order': 27},
{'q': 1e-4, 'sigma': 1., 'order': 1.5},
{'q': 1e-3, 'sigma': 1., 'order': 2},
{'q': .01, 'sigma': 10, 'order': 20},
{'q': .1, 'sigma': 100, 'order': 20.5},
{'q': .99, 'sigma': .1, 'order': 256},
{'q': .999, 'sigma': 100, 'order': 256.1})
params = ({
'q': 1e-7,
'sigma': .1,
'order': 1.01
}, {
'q': 1e-6,
'sigma': .1,
'order': 256
}, {
'q': 1e-5,
'sigma': .1,
'order': 256.1
}, {
'q': 1e-6,
'sigma': 1,
'order': 27
}, {
'q': 1e-4,
'sigma': 1.,
'order': 1.5
}, {
'q': 1e-3,
'sigma': 1.,
'order': 2
}, {
'q': .01,
'sigma': 10,
'order': 20
}, {
'q': .1,
'sigma': 100,
'order': 20.5
}, {
'q': .99,
'sigma': .1,
'order': 256
}, {
'q': .999,
'sigma': 100,
'order': 256.1
})
# pylint:disable=undefined-variable
@parameterized.parameters(p for p in params)
@ -152,7 +182,8 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(eps, 1.32783806176)
# Second test for Gaussian noise (with no subsampling):
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of orders.
orders = [0.001 * i for i in range(1000, 100000)
] # Pick fine set of orders.
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
# Scale is chosen to obtain exactly (1,1e-6)-DP.
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
@ -168,7 +199,7 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(delta, 1e-5)
# Second test for Gaussian noise (with no subsampling):
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of order.
orders = [0.001 * i for i in range(1000, 100000)] # Pick fine set of order.
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
# Scale is chosen to obtain exactly (1,1e-6)-DP.
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_eps=1)
@ -178,17 +209,13 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
16., 20., 24., 28., 32., 64., 256.)
rdp = rdp_accountant.compute_rdp(q=1e-4,
noise_multiplier=.4,
steps=40000,
orders=orders)
rdp = rdp_accountant.compute_rdp(
q=1e-4, noise_multiplier=.4, steps=40000, orders=orders)
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
rdp += rdp_accountant.compute_rdp(q=0.1,
noise_multiplier=2,
steps=100,
orders=orders)
rdp += rdp_accountant.compute_rdp(
q=0.1, noise_multiplier=2, steps=100, orders=orders)
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-5)
# These tests use the old RDP -> approx DP conversion
# self.assertAlmostEqual(eps, 8.509656, places=5)
@ -217,42 +244,25 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def test_get_privacy_spent_gaussian(self):
# Compare the optimal bound for Gaussian with the one derived from RDP.
# Also compare the RDP upper bound with the "standard" upper bound.
orders = [0.1*x for x in range(10, 505)]
eps_vec = [0.1*x for x in range(500)]
orders = [0.1 * x for x in range(10, 505)]
eps_vec = [0.1 * x for x in range(500)]
rdp = rdp_accountant.compute_rdp(1, 1, 1, orders)
for eps in eps_vec:
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp,
target_eps=eps)
_, delta, _ = rdp_accountant.get_privacy_spent(
orders, rdp, target_eps=eps)
# For comparison, we compute the optimal guarantee for Gaussian
# using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2).
delta0 = math.erfc((eps-.5)/math.sqrt(2))/2
delta0 = delta0 - math.exp(eps)*math.erfc((eps+.5)/math.sqrt(2))/2
self.assertLessEqual(delta0, delta+1e-300) # need tolerance 10^-300
delta0 = math.erfc((eps - .5) / math.sqrt(2)) / 2
delta0 = delta0 - math.exp(eps) * math.erfc((eps + .5) / math.sqrt(2)) / 2
self.assertLessEqual(delta0, delta + 1e-300) # need tolerance 10^-300
# Compute the "standard" upper bound, which should be an upper bound.
# Note, if orders is too sparse, this will NOT be an upper bound.
if eps >= 0.5:
delta1 = math.exp(-0.5*(eps-0.5)**2)
delta1 = math.exp(-0.5 * (eps - 0.5)**2)
else:
delta1 = 1
self.assertLessEqual(delta, delta1+1e-300)
def test_compute_rdp_from_ledger(self):
orders = range(2, 33)
q = 0.1
n = 1000
l2_norm_clip = 3.14159
noise_stddev = 2.71828
steps = 3
query_entry = privacy_ledger.GaussianSumQueryEntry(
l2_norm_clip, noise_stddev)
ledger = [privacy_ledger.SampleEntry(n, q, [query_entry])] * steps
z = noise_stddev / l2_norm_clip
rdp = rdp_accountant.compute_rdp(q, z, steps, orders)
rdp_from_ledger = rdp_accountant.compute_rdp_from_ledger(ledger, orders)
self.assertSequenceAlmostEqual(rdp, rdp_from_ledger)
self.assertLessEqual(delta, delta1 + 1e-300)
if __name__ == '__main__':

View file

@ -46,11 +46,6 @@ class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
self._l2_norm_bound = l2_norm_bound
self._stddev = stddev
def set_ledger(self, ledger):
del ledger # Unused.
raise NotImplementedError('Ledger has not yet been implemented for '
'DiscreteGaussianSumQuery!')
def initial_global_state(self):
return self._GlobalState(
tf.cast(self._l2_norm_bound, tf.float32),

View file

@ -46,11 +46,6 @@ class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
self._l2_norm_bound = l2_norm_bound
self._local_stddev = local_stddev
def set_ledger(self, ledger):
del ledger # Unused.
raise NotImplementedError('Ledger has not yet been implemented for '
'DistributedDiscreteGaussianSumQuery!')
def initial_global_state(self):
return self._GlobalState(
tf.cast(self._l2_norm_bound, tf.float32),

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An interface for differentially private query mechanisms.
The DPQuery class abstracts the differential privacy mechanism needed by DP-SGD.
@ -100,18 +99,6 @@ class DPQuery(object):
__metaclass__ = abc.ABCMeta
def set_ledger(self, ledger):
"""Supplies privacy ledger to which the query can record privacy events.
The ledger should be updated with each call to get_noised_result.
Args:
ledger: A `PrivacyLedger`.
"""
del ledger
raise TypeError(
'DPQuery type %s does not support set_ledger.' % type(self).__name__)
def initial_global_state(self):
"""Returns the initial global state for the DPQuery.
@ -155,7 +142,6 @@ class DPQuery(object):
as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has
properties `shape` and `dtype`.
Returns: An initial sample state.
"""
pass
@ -171,12 +157,12 @@ class DPQuery(object):
variables that are stored in self.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
params: The parameters for the sample. In standard DP-SGD training, the
clipping norm for the sample's microbatch gradients (i.e., a maximum
norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training, the
gradient computed for the examples in one microbatch, which may be the
gradient for just one example (for size 1 microbatches).
Returns:
A structure of tensors to be aggregated.
@ -185,8 +171,7 @@ class DPQuery(object):
return record
@abc.abstractmethod
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Accumulates a single preprocessed record into the sample state.
This method is intended to only do simple aggregation, typically just a sum.
@ -194,8 +179,8 @@ class DPQuery(object):
declaratively specify the type of aggregation required.
Args:
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
sample_state: The current sample state. In standard DP-SGD training, the
accumulated sum of previous clipped microbatch gradients.
preprocessed_record: The preprocessed record to accumulate.
Returns:
@ -211,22 +196,22 @@ class DPQuery(object):
functions run on a single device. Typically this will be a simple sum.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
record: The record to accumulate. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
params: The parameters for the sample. In standard DP-SGD training, the
clipping norm for the sample's microbatch gradients (i.e., a maximum
norm magnitude to which each gradient is clipped)
sample_state: The current sample state. In standard DP-SGD training, the
accumulated sum of previous clipped microbatch gradients.
record: The record to accumulate. In standard DP-SGD training, the
gradient computed for the examples in one microbatch, which may be the
gradient for just one example (for size 1 microbatches).
Returns:
The updated sample state. In standard DP-SGD training, the set of
previous microbatch gradients with the addition of the record argument.
"""
preprocessed_record = self.preprocess_record(params, record)
return self.accumulate_preprocessed_record(
sample_state, preprocessed_record)
return self.accumulate_preprocessed_record(sample_state,
preprocessed_record)
@abc.abstractmethod
def merge_sample_states(self, sample_state_1, sample_state_2):

View file

@ -47,10 +47,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
self._stddev = stddev
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._ledger = ledger
def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters."""
return self._GlobalState(

View file

@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for queries over nested structures.
"""
"""Implements DPQuery interface for queries over nested structures."""
from __future__ import absolute_import
from __future__ import division
@ -60,16 +58,13 @@ class NestedQuery(dp_query.DPQuery):
def _map_to_queries(self, fn, *inputs, **kwargs):
"""Maps DPQuery methods to the subqueries."""
def caller(query, *args):
return getattr(query, fn)(*args, **kwargs)
return tree.map_structure_up_to(self._queries, caller, self._queries,
*inputs)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._map_to_queries('set_ledger', ledger=ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._map_to_queries('initial_global_state')
@ -89,18 +84,15 @@ class NestedQuery(dp_query.DPQuery):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._map_to_queries('preprocess_record', params, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
return self._map_to_queries(
'accumulate_preprocessed_record',
sample_state,
preprocessed_record)
return self._map_to_queries('accumulate_preprocessed_record', sample_state,
preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return self._map_to_queries(
'merge_sample_states', sample_state_1, sample_state_2)
return self._map_to_queries('merge_sample_states', sample_state_1,
sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
@ -118,12 +110,12 @@ class NestedQuery(dp_query.DPQuery):
def add_metrics(tuple_path, subquery, subquery_global_state):
metrics.update({
'/'.join(str(s) for s in tuple_path + (name,)): metric
for name, metric
in subquery.derive_metrics(subquery_global_state).items()})
'/'.join(str(s) for s in tuple_path + (name,)): metric for name,
metric in subquery.derive_metrics(subquery_global_state).items()
})
tree.map_structure_with_path_up_to(
self._queries, add_metrics, self._queries, global_state)
tree.map_structure_with_path_up_to(self._queries, add_metrics,
self._queries, global_state)
return metrics
@ -137,12 +129,13 @@ class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
Args:
queries: A nested structure of queries that must all be
SumAggregationDPQueries.
Raises: TypeError if any of the subqueries are not SumAggregationDPQueries.
"""
def check(query):
if not isinstance(query, dp_query.SumAggregationDPQuery):
raise TypeError('All subqueries must be SumAggregationDPQueries.')
tree.map_structure(check, queries)
super(NestedSumQuery, self).__init__(queries)

View file

@ -17,8 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.dp_query import dp_query
@ -33,20 +31,11 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
def __init__(self):
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
warnings.warn(
'Attempt to use NoPrivacySumQuery with privacy ledger. Privacy '
'guarantees will be vacuous.')
self._ledger = ledger
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
if self._ledger:
dependencies = [
self._ledger.record_sum_query(float('inf'), 0.0)
]
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
else:
dependencies = []
@ -71,17 +60,10 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Initializes the NoPrivacyAverageQuery."""
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
warnings.warn(
'Attempt to use NoPrivacyAverageQuery with privacy ledger. Privacy '
'guarantees will be vacuous.')
self._ledger = ledger
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
tf.constant(0.0))
return (super(NoPrivacyAverageQuery,
self).initial_sample_state(template), tf.constant(0.0))
def preprocess_record(self, params, record, weight=1):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
@ -122,9 +104,7 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
sum_state, denominator = sample_state
if self._ledger:
dependencies = [
self._ledger.record_sum_query(float('inf'), 0.0)
]
dependencies = [self._ledger.record_sum_query(float('inf'), 0.0)]
else:
dependencies = []

View file

@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for normalized queries.
"""
"""Implements DPQuery interface for normalized queries."""
from __future__ import absolute_import
from __future__ import division
@ -38,8 +36,8 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['numerator_state', 'denominator'])
_GlobalState = collections.namedtuple('_GlobalState',
['numerator_state', 'denominator'])
def __init__(self, numerator_query, denominator):
"""Initializes the NormalizedQuery.
@ -55,15 +53,11 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._numerator.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
denominator = tf.cast(self._denominator, tf.float32)
return self._GlobalState(
self._numerator.initial_global_state(), denominator)
return self._GlobalState(self._numerator.initial_global_state(),
denominator)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -82,6 +76,7 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
sample_state, global_state.numerator_state)
def normalize(v):
return tf.truediv(v, global_state.denominator)

View file

@ -91,11 +91,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
assert isinstance(self._quantile_estimator_query,
dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._sum_query.set_ledger(ledger)
self._quantile_estimator_query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(

View file

@ -22,7 +22,6 @@ from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_sum_query
from tensorflow_privacy.privacy.dp_query import test_utils
@ -291,53 +290,6 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
if t > 40:
self.assertNear(actual_clip, 5.0, 0.5)
def test_ledger(self):
record1 = tf.constant([8.5])
record2 = tf.constant([-7.25])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=10.0,
noise_multiplier=1.0,
target_unclipped_quantile=0.0,
learning_rate=1.0,
clipped_count_stddev=0.0,
expected_num_records=2.0,
geometric_update=False)
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
_, global_state = test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 10.0], [0.5, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2], global_state)
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]]
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries_2)
if __name__ == '__main__':
tf.test.main()

View file

@ -101,10 +101,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
l2_norm_clip=0.5, stddev=below_estimate_stddev),
denominator=expected_num_records)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._below_estimate_query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(

View file

@ -21,7 +21,6 @@ from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
@ -166,8 +165,8 @@ def make_optimizer_class(cls):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
def normalize(v):
return v / tf.cast(self._num_microbatches, tf.float32)
@ -197,8 +196,8 @@ def make_optimizer_class(cls):
"""Process one microbatch (record) with privacy helper."""
self_super = super(DPOptimizerClass, self)
mean_loss = tf.reduce_mean(input_tensor=tf.gather(
microbatches_losses, [i]))
mean_loss = tf.reduce_mean(
input_tensor=tf.gather(microbatches_losses, [i]))
if hasattr(self_super, 'compute_gradients'):
# This case covers optimizers in tf.train.
@ -208,8 +207,8 @@ def make_optimizer_class(cls):
compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access
grads, _ = zip(*compute_gradients_fn(
mean_loss, var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
mean_loss, var_list, gate_gradients, aggregation_method,
colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
sample_state = self._dp_sum_query.accumulate_record(
@ -218,8 +217,8 @@ def make_optimizer_class(cls):
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
tf.trainable_variables() +
tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._dp_sum_query.initial_sample_state(var_list)
@ -237,8 +236,8 @@ def make_optimizer_class(cls):
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
def normalize(v):
try:
@ -307,9 +306,7 @@ def make_gaussian_optimizer_class(cls):
```
""").format(
'tf.compat.v1.train.' + cls.__name__,
cls.__name__,
cls.__name__,
'tf.compat.v1.train.' + cls.__name__, cls.__name__, cls.__name__,
'DP' + cls.__name__.replace('Optimizer', 'GaussianOptimizer'))
def __init__(
@ -317,7 +314,6 @@ def make_gaussian_optimizer_class(cls):
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
ledger=None,
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg
**kwargs):
@ -329,7 +325,6 @@ def make_gaussian_optimizer_class(cls):
num_microbatches: Number of microbatches into which each minibatch is
split. If `None`, will default to the size of the minibatch, and
per-example gradients will be computed.
ledger: Defaults to `None`. An instance of `tf_privacy.PrivacyLedger`.
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception.
@ -344,16 +339,9 @@ def make_gaussian_optimizer_class(cls):
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
if ledger:
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query,
ledger=ledger)
super(DPGaussianOptimizerClass, self).__init__(
dp_sum_query,
num_microbatches,
unroll_microbatches,
*args,
**kwargs)
super(DPGaussianOptimizerClass,
self).__init__(dp_sum_query, num_microbatches, unroll_microbatches,
*args, **kwargs)
def get_config(self):
"""Creates configuration for Keras serialization.
@ -370,7 +358,8 @@ def make_gaussian_optimizer_class(cls):
config.update({
'l2_norm_clip': self._l2_norm_clip,
'noise_multiplier': self._noise_multiplier,
'num_microbatches': self._num_microbatches})
'num_microbatches': self._num_microbatches
})
return config
@ -380,6 +369,7 @@ def make_gaussian_optimizer_class(cls):
return DPGaussianOptimizerClass
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
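With the `ledger` argument gone, a constructor call consistent with the updated Gaussian optimizer signature (mirroring the tutorial below; hyperparameter values are illustrative):

```
# Constructor call matching the updated signature above (no `ledger` argument);
# hyperparameter values are illustrative.
from tensorflow_privacy.privacy.optimizers import dp_optimizer

optimizer = dp_optimizer.DPAdamGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=1,
    learning_rate=0.001,
    unroll_microbatches=True)
```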

View file

@ -22,7 +22,6 @@ import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@ -56,13 +55,9 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
@ -85,7 +80,6 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
@ -109,7 +103,6 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)

View file

@ -24,7 +24,6 @@ import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@ -51,9 +50,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]),
('DPRMSPropOptimizer 1', dp_optimizer.DPRMSPropOptimizer, 1,
[-2.5, -2.5]),
('DPRMSPropOptimizer 2', dp_optimizer.DPRMSPropOptimizer, 2,
[-2.5, -2.5]),
[-2.5, -2.5]), ('DPRMSPropOptimizer 2', dp_optimizer.DPRMSPropOptimizer,
2, [-2.5, -2.5]),
('DPRMSPropOptimizer 4', dp_optimizer.DPRMSPropOptimizer, 4, [-2.5, -2.5])
)
def testBaseline(self, cls, num_microbatches, expected_answer):
@ -62,13 +60,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
@ -91,7 +85,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
@ -115,7 +108,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
@ -157,11 +149,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
vector_loss = tf.math.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
optimizer = dp_optimizer.DPGradientDescentOptimizer(
dp_sum_query,
num_microbatches=1,
learning_rate=1.0)
dp_sum_query, num_microbatches=1, learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
@ -201,8 +190,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches = 4
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
@ -283,8 +270,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string)
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6,
num_microbatches / 1e6)
opt = cls(
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
@ -298,27 +283,26 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
sess.run(minimize_op)
def _testWriteOutAndReload(self, optimizer_cls):
optimizer = optimizer_cls(l2_norm_clip=1.0,
noise_multiplier=0.01,
num_microbatches=1)
optimizer = optimizer_cls(
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
test_dir = self.get_temp_dir()
model_path = os.path.join(test_dir, 'model')
model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)),
tf.keras.layers.Dense(units=1,
activation='softmax')])
model.compile(optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True))
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1, 1)),
tf.keras.layers.Dense(units=1, activation='softmax')
])
model.compile(
optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
tf.keras.models.save_model(model, filepath=model_path,
include_optimizer=True)
tf.keras.models.save_model(
model, filepath=model_path, include_optimizer=True)
optimizer_cls_str = optimizer_cls.__name__
tf.keras.models.load_model(model_path,
custom_objects={
optimizer_cls_str: optimizer_cls})
tf.keras.models.load_model(
model_path, custom_objects={optimizer_cls_str: optimizer_cls})
return

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training a language model (recurrent neural network) with DP-SGD optimizer.
This tutorial uses a corpus of text from TensorFlow datasets unless a
@ -44,7 +43,6 @@ import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@ -92,27 +90,20 @@ def rnn_model_fn(features, labels, mode): # pylint: disable=unused-argument
if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd:
ledger = privacy_ledger.PrivacyLedger(
population_size=NB_TRAIN,
selection_probability=(FLAGS.batch_size / NB_TRAIN))
optimizer = dp_optimizer.DPAdamGaussianOptimizer(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
ledger=ledger,
learning_rate=FLAGS.learning_rate,
unroll_microbatches=True)
opt_loss = vector_loss
else:
optimizer = tf.train.AdamOptimizer(
learning_rate=FLAGS.learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
opt_loss = scalar_loss
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
train_op=train_op)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
# Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL:
@ -122,9 +113,8 @@ def rnn_model_fn(features, labels, mode): # pylint: disable=unused-argument
labels=tf.cast(x[:, 1:], dtype=tf.int32),
predictions=tf.argmax(input=logits, axis=2))
}
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
eval_metric_ops=eval_metric_ops)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
def load_data():
@ -132,13 +122,13 @@ def load_data():
if not FLAGS.data_dir:
print('FLAGS.data_dir containing train.txt and test.txt was not specified, '
'using a substitute dataset from the tensorflow_datasets module.')
train_dataset = tfds.load(name='lm1b/subwords8k',
split=tfds.Split.TRAIN,
batch_size=NB_TRAIN,
shuffle_files=True)
test_dataset = tfds.load(name='lm1b/subwords8k',
split=tfds.Split.TEST,
batch_size=10000)
train_dataset = tfds.load(
name='lm1b/subwords8k',
split=tfds.Split.TRAIN,
batch_size=NB_TRAIN,
shuffle_files=True)
test_dataset = tfds.load(
name='lm1b/subwords8k', split=tfds.Split.TEST, batch_size=10000)
train_data = next(iter(tfds.as_numpy(train_dataset)))
test_data = next(iter(tfds.as_numpy(test_dataset)))
train_data = train_data['text'].flatten()
@ -162,10 +152,11 @@ def compute_epsilon(steps):
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / NB_TRAIN
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
rdp = compute_rdp(
q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to 1e-5 because Penn TreeBank has 60000 training points.
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
@ -180,9 +171,8 @@ def main(unused_argv):
# Instantiate the tf.Estimator.
conf = tf.estimator.RunConfig(save_summary_steps=1000)
lm_classifier = tf.estimator.Estimator(model_fn=rnn_model_fn,
model_dir=FLAGS.model_dir,
config=conf)
lm_classifier = tf.estimator.Estimator(
model_fn=rnn_model_fn, model_dir=FLAGS.model_dir, config=conf)
# Create tf.Estimator input functions for the training and test data.
batch_len = FLAGS.batch_size * SEQ_LEN
@ -221,5 +211,6 @@ def main(unused_argv):
else:
print('Trained with vanilla non-private SGD optimizer')
if __name__ == '__main__':
app.run(main)