Restructure TF Privacy to be more in line with other repos in the TF ecosystem.

PiperOrigin-RevId: 274674077
This commit is contained in:
Steve Chien 2019-10-14 15:29:21 -07:00 committed by A. Unique TensorFlower
parent c0e05f6cad
commit 1ce8cd4032
47 changed files with 6849 additions and 23 deletions

View file

@ -0,0 +1,57 @@
# Copyright 2019, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Privacy library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
pass
else:
from tensorflow_privacy.privacy.analysis.privacy_ledger import GaussianSumQueryEntry
from tensorflow_privacy.privacy.analysis.privacy_ledger import PrivacyLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import QueryWithLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import SampleEntry
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianAverageQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacySumQuery
from tensorflow_privacy.privacy.dp_query.normalized_query import NormalizedQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
try:
from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel
from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber
except ImportError:
# module `bolt_on` not yet available in this version of TF Privacy
pass

View file

@ -0,0 +1,5 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
exports_files(["LICENSE"])

View file

@ -0,0 +1,13 @@
# Copyright 2019, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,97 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Command-line script for computing privacy of a model trained with DP-SGD.
The script applies the RDP accountant to estimate privacy budget of an iterated
Sampled Gaussian Mechanism. The mechanism's parameters are controlled by flags.
Example:
compute_dp_sgd_privacy
--N=60000 \
--batch_size=256 \
--noise_multiplier=1.12 \
--epochs=60 \
--delta=1e-5
The output states that DP-SGD with these parameters satisfies (2.92, 1e-5)-DP.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
from absl import app
from absl import flags
# Opting out of loading all sibling packages and their dependencies.
sys.skip_tf_privacy_import = True
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
FLAGS = flags.FLAGS
flags.DEFINE_integer('N', None, 'Total number of examples')
flags.DEFINE_integer('batch_size', None, 'Batch size')
flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD')
flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
flags.DEFINE_float('delta', 1e-6, 'Target delta')
flags.mark_flag_as_required('N')
flags.mark_flag_as_required('batch_size')
flags.mark_flag_as_required('noise_multiplier')
flags.mark_flag_as_required('epochs')
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
"""Compute and print results of DP-SGD analysis."""
# compute_rdp requires that sigma be the ratio of the standard deviation of
# the Gaussian noise to the l2-sensitivity of the function to which it is
# added. Hence, sigma here corresponds to the `noise_multiplier` parameter
# in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
rdp = compute_rdp(q, sigma, steps, orders)
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated'
' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ')
print('differential privacy with eps = {:.3g} and delta = {}.'.format(
eps, delta))
print('The optimal RDP order is {}.'.format(opt_order))
if opt_order == max(orders) or opt_order == min(orders):
print('The privacy estimate is likely to be improved by expanding '
'the set of orders.')
def main(argv):
del argv # argv is not used.
q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio.
if q > 1:
raise app.UsageError('N must be larger than the batch size.')
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
list(range(5, 64)) + [128, 256, 512])
steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size))
apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta)
if __name__ == '__main__':
app.run(main)

View file

@ -0,0 +1,257 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PrivacyLedger class for keeping a record of private queries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
from tensorflow_privacy.privacy.dp_query import dp_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
def format_ledger(sample_array, query_array):
"""Converts array representation into a list of SampleEntries."""
samples = []
query_pos = 0
sample_pos = 0
for sample in sample_array:
population_size, selection_probability, num_queries = sample
queries = []
for _ in range(int(num_queries)):
query = query_array[query_pos]
assert int(query[0]) == sample_pos
queries.append(GaussianSumQueryEntry(*query[1:]))
query_pos += 1
samples.append(SampleEntry(population_size, selection_probability, queries))
sample_pos += 1
return samples
class PrivacyLedger(object):
"""Class for keeping a record of private queries.
The PrivacyLedger keeps a record of all queries executed over a given dataset
for the purpose of computing privacy guarantees.
"""
def __init__(self,
population_size,
selection_probability):
"""Initialize the PrivacyLedger.
Args:
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch.
selection_probability: A float (may be variable) specifying the
probability each record is included in a sample.
Raises:
ValueError: If selection_probability is 0.
"""
self._population_size = population_size
self._selection_probability = selection_probability
if tf.executing_eagerly():
if tf.equal(selection_probability, 0):
raise ValueError('Selection probability cannot be 0.')
init_capacity = tf.cast(tf.ceil(1 / selection_probability), tf.int32)
else:
if selection_probability == 0:
raise ValueError('Selection probability cannot be 0.')
init_capacity = np.int(np.ceil(1 / selection_probability))
# The query buffer stores rows corresponding to GaussianSumQueryEntries.
self._query_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'query')
self._sample_var = tf.Variable(
initial_value=tf.zeros([3]), trainable=False, name='sample')
# The sample buffer stores rows corresponding to SampleEntries.
self._sample_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'sample')
self._sample_count = tf.Variable(
initial_value=0.0, trainable=False, name='sample_count')
self._query_count = tf.Variable(
initial_value=0.0, trainable=False, name='query_count')
try:
# Newer versions of TF
self._cs = tf.CriticalSection()
except AttributeError:
# Older versions of TF
self._cs = tf.contrib.framework.CriticalSection()
def record_sum_query(self, l2_norm_bound, noise_stddev):
"""Records that a query was issued.
Args:
l2_norm_bound: The maximum l2 norm of the tensor group in the query.
noise_stddev: The standard deviation of the noise applied to the sum.
Returns:
An operation recording the sum query to the ledger.
"""
def _do_record_query():
with tf.control_dependencies(
[tf.assign(self._query_count, self._query_count + 1)]):
return self._query_buffer.append(
[self._sample_count, l2_norm_bound, noise_stddev])
return self._cs.execute(_do_record_query)
def finalize_sample(self):
"""Finalizes sample and records sample ledger entry."""
with tf.control_dependencies([
tf.assign(self._sample_var, [
self._population_size, self._selection_probability,
self._query_count
])
]):
with tf.control_dependencies([
tf.assign(self._sample_count, self._sample_count + 1),
tf.assign(self._query_count, 0)
]):
return self._sample_buffer.append(self._sample_var)
def get_unformatted_ledger(self):
return self._sample_buffer.values, self._query_buffer.values
def get_formatted_ledger(self, sess):
"""Gets the formatted query ledger.
Args:
sess: The tensorflow session in which the ledger was created.
Returns:
The query ledger as a list of SampleEntries.
"""
sample_array = sess.run(self._sample_buffer.values)
query_array = sess.run(self._query_buffer.values)
return format_ledger(sample_array, query_array)
def get_formatted_ledger_eager(self):
"""Gets the formatted query ledger.
Returns:
The query ledger as a list of SampleEntries.
"""
sample_array = self._sample_buffer.values.numpy()
query_array = self._query_buffer.values.numpy()
return format_ledger(sample_array, query_array)
class QueryWithLedger(dp_query.DPQuery):
"""A class for DP queries that record events to a PrivacyLedger.
QueryWithLedger should be the top-level query in a structure of queries that
may include sum queries, nested queries, etc. It should simply wrap another
query and contain a reference to the ledger. Any contained queries (including
those contained in the leaves of a nested query) should also contain a
reference to the same ledger object.
For example usage, see privacy_ledger_test.py.
"""
def __init__(self, query,
population_size=None, selection_probability=None,
ledger=None):
"""Initializes the QueryWithLedger.
Args:
query: The query whose events should be recorded to the ledger. Any
subqueries (including those in the leaves of a nested query) should also
contain a reference to the same ledger given here.
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch. May be
None if `ledger` is specified.
selection_probability: A float (may be variable) specifying the
probability each record is included in a sample. May be None if `ledger`
is specified.
ledger: A PrivacyLedger to use. Must be specified if either of
`population_size` or `selection_probability` is None.
"""
self._query = query
if population_size is not None and selection_probability is not None:
self.set_ledger(PrivacyLedger(population_size, selection_probability))
elif ledger is not None:
self.set_ledger(ledger)
else:
raise ValueError('One of (population_size, selection_probability) or '
'ledger must be specified.')
@property
def ledger(self):
return self._ledger
def set_ledger(self, ledger):
self._ledger = ledger
self._query.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
return self._query.initial_global_state()
def derive_sample_params(self, global_state):
"""See base class."""
return self._query.derive_sample_params(global_state)
def initial_sample_state(self, template):
"""See base class."""
return self._query.initial_sample_state(template)
def preprocess_record(self, params, record):
"""See base class."""
return self._query.preprocess_record(params, record)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""See base class."""
return self._query.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._query.merge_sample_states(sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Ensures sample is recorded to the ledger and returns noised result."""
# Ensure sample_state is fully aggregated before calling get_noised_result.
with tf.control_dependencies(nest.flatten(sample_state)):
result, new_global_state = self._query.get_noised_result(
sample_state, global_state)
# Ensure inner queries have recorded before finalizing.
with tf.control_dependencies(nest.flatten(result)):
finalize = self._ledger.finalize_sample()
# Ensure finalizing happens.
with tf.control_dependencies([finalize]):
return nest.map_structure(tf.identity, result), new_global_state

View file

@ -0,0 +1,137 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PrivacyLedger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
class PrivacyLedgerTest(tf.test.TestCase):
def test_fail_on_probability_zero(self):
with self.assertRaisesRegexp(ValueError,
'Selection probability cannot be 0.'):
privacy_ledger.PrivacyLedger(10, 0)
def test_basic(self):
ledger = privacy_ledger.PrivacyLedger(10, 0.1)
ledger.record_sum_query(5.0, 1.0)
ledger.record_sum_query(2.0, 0.5)
ledger.finalize_sample()
expected_queries = [[5.0, 1.0], [2.0, 0.5]]
formatted = ledger.get_formatted_ledger_eager()
sample = formatted[0]
self.assertAllClose(sample.population_size, 10.0)
self.assertAllClose(sample.selection_probability, 0.1)
self.assertAllClose(sorted(sample.queries), sorted(expected_queries))
def test_sum_query(self):
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = privacy_ledger.QueryWithLedger(
query, population_size, selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries)
def test_nested_query(self):
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query1 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2])
query = privacy_ledger.QueryWithLedger(
query, population_size, selection_probability)
record1 = [1.0, [12.0, 9.0]]
record2 = [5.0, [1.0, 2.0]]
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[4.0, 2.0], [5.0, 1.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries))
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,318 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RDP analysis of the Sampled Gaussian Mechanism.
Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
T times.
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
(or eps) given RDP at multiple orders and
a target value for eps (or delta).
Example use:
Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:
max_order = 32
orders = range(2, max_order + 1)
rdp = np.zeros_like(orders, dtype=float)
for q, sigma, T in parameters:
rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(rdp, target_delta=delta)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import numpy as np
from scipy import special
import six
########################
# LOG-SPACE ARITHMETIC #
########################
def _log_add(logx, logy):
"""Add two numbers in the log space."""
a, b = min(logx, logy), max(logx, logy)
if a == -np.inf: # adding 0
return b
# Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
def _log_sub(logx, logy):
"""Subtract two numbers in the log space. Answer must be non-negative."""
if logx < logy:
raise ValueError("The result of subtraction must be non-negative.")
if logy == -np.inf: # subtracting 0
return logx
if logx == logy:
return -np.inf # 0 is represented as -np.inf in the log space.
try:
# Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
except OverflowError:
return logx
def _log_print(logx):
"""Pretty print."""
if logx < math.log(sys.float_info.max):
return "{}".format(math.exp(logx))
else:
return "exp({})".format(logx)
def _compute_log_a_int(q, sigma, alpha):
"""Compute log(A_alpha) for integer alpha. 0 < q < 1."""
assert isinstance(alpha, six.integer_types)
# Initialize with 0 in the log space.
log_a = -np.inf
for i in range(alpha + 1):
log_coef_i = (
math.log(special.binom(alpha, i)) + i * math.log(q) +
(alpha - i) * math.log(1 - q))
s = log_coef_i + (i * i - i) / (2 * (sigma**2))
log_a = _log_add(log_a, s)
return float(log_a)
def _compute_log_a_frac(q, sigma, alpha):
"""Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
# initialized to 0 in the log space:
log_a0, log_a1 = -np.inf, -np.inf
i = 0
z0 = sigma**2 * math.log(1 / q - 1) + .5
while True: # do ... until loop
coef = special.binom(alpha, i)
log_coef = math.log(abs(coef))
j = alpha - i
log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
if coef > 0:
log_a0 = _log_add(log_a0, log_s0)
log_a1 = _log_add(log_a1, log_s1)
else:
log_a0 = _log_sub(log_a0, log_s0)
log_a1 = _log_sub(log_a1, log_s1)
i += 1
if max(log_s0, log_s1) < -30:
break
return _log_add(log_a0, log_a1)
def _compute_log_a(q, sigma, alpha):
"""Compute log(A_alpha) for any positive finite alpha."""
if float(alpha).is_integer():
return _compute_log_a_int(q, sigma, int(alpha))
else:
return _compute_log_a_frac(q, sigma, alpha)
def _log_erfc(x):
"""Compute log(erfc(x)) with high accuracy for large x."""
try:
return math.log(2) + special.log_ndtr(-x * 2**.5)
except NameError:
# If log_ndtr is not available, approximate as follows:
r = special.erfc(x)
if r == 0.0:
# Using the Laurent series at infinity for the tail of the erfc function:
# erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
# To verify in Mathematica:
# Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
.625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
else:
return math.log(r)
def _compute_delta(orders, rdp, eps):
"""Compute delta given a list of RDP values and target epsilon.
Args:
orders: An array (or a scalar) of orders.
rdp: A list (or a scalar) of RDP guarantees.
eps: The target epsilon.
Returns:
Pair of (delta, optimal_order).
Raises:
ValueError: If input is malformed.
"""
orders_vec = np.atleast_1d(orders)
rdp_vec = np.atleast_1d(rdp)
if len(orders_vec) != len(rdp_vec):
raise ValueError("Input lists must have the same length.")
deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
idx_opt = np.argmin(deltas)
return min(deltas[idx_opt], 1.), orders_vec[idx_opt]
def _compute_eps(orders, rdp, delta):
"""Compute epsilon given a list of RDP values and target delta.
Args:
orders: An array (or a scalar) of orders.
rdp: A list (or a scalar) of RDP guarantees.
delta: The target delta.
Returns:
Pair of (eps, optimal_order).
Raises:
ValueError: If input is malformed.
"""
orders_vec = np.atleast_1d(orders)
rdp_vec = np.atleast_1d(rdp)
if len(orders_vec) != len(rdp_vec):
raise ValueError("Input lists must have the same length.")
eps = rdp_vec - math.log(delta) / (orders_vec - 1)
idx_opt = np.nanargmin(eps) # Ignore NaNs
return eps[idx_opt], orders_vec[idx_opt]
def _compute_rdp(q, sigma, alpha):
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
Args:
q: The sampling rate.
sigma: The std of the additive Gaussian noise.
alpha: The order at which RDP is computed.
Returns:
RDP at alpha, can be np.inf.
"""
if q == 0:
return 0
if q == 1.:
return alpha / (2 * sigma**2)
if np.isinf(alpha):
return np.inf
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
def compute_rdp(q, noise_multiplier, steps, orders):
"""Compute RDP of the Sampled Gaussian Mechanism.
Args:
q: The sampling rate.
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
to the l2-sensitivity of the function to which it is added.
steps: The number of steps.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders, can be np.inf.
"""
if np.isscalar(orders):
rdp = _compute_rdp(q, noise_multiplier, orders)
else:
rdp = np.array([_compute_rdp(q, noise_multiplier, order)
for order in orders])
return rdp * steps
def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
"""Compute delta (or eps) for given eps (or delta) from RDP values.
Args:
orders: An array (or a scalar) of RDP orders.
rdp: An array of RDP values. Must be of the same length as the orders list.
target_eps: If not None, the epsilon for which we compute the corresponding
delta.
target_delta: If not None, the delta for which we compute the corresponding
epsilon. Exactly one of target_eps and target_delta must be None.
Returns:
eps, delta, opt_order.
Raises:
ValueError: If target_eps and target_delta are messed up.
"""
if target_eps is None and target_delta is None:
raise ValueError(
"Exactly one out of eps and delta must be None. (Both are).")
if target_eps is not None and target_delta is not None:
raise ValueError(
"Exactly one out of eps and delta must be None. (None is).")
if target_eps is not None:
delta, opt_order = _compute_delta(orders, rdp, target_eps)
return target_eps, delta, opt_order
else:
eps, opt_order = _compute_eps(orders, rdp, target_delta)
return eps, target_delta, opt_order
def compute_rdp_from_ledger(ledger, orders):
"""Compute RDP of Sampled Gaussian Mechanism from ledger.
Args:
ledger: A formatted privacy ledger.
orders: An array (or a scalar) of RDP orders.
Returns:
RDP at all orders, can be np.inf.
"""
total_rdp = np.zeros_like(orders, dtype=float)
for sample in ledger:
# Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
# See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
effective_z = sum([
(q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5
total_rdp += compute_rdp(
sample.selection_probability, effective_z, 1, orders)
return total_rdp

View file

@ -0,0 +1,177 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_accountant.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl.testing import absltest
from absl.testing import parameterized
from mpmath import exp
from mpmath import inf
from mpmath import log
from mpmath import npdf
from mpmath import quad
import numpy as np
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis import rdp_accountant
class TestGaussianMoments(parameterized.TestCase):
#################################
# HELPER FUNCTIONS: #
# Exact computations using #
# multi-precision arithmetic. #
#################################
def _log_float_mp(self, x):
# Convert multi-precision input to float log space.
if x >= sys.float_info.min:
return float(log(x))
else:
return -np.inf
def _integral_mp(self, fn, bounds=(-inf, inf)):
integral, _ = quad(fn, bounds, error=True, maxdegree=8)
return integral
def _distributions_mp(self, sigma, q):
def _mu0(x):
return npdf(x, mu=0, sigma=sigma)
def _mu1(x):
return npdf(x, mu=1, sigma=sigma)
def _mu(x):
return (1 - q) * _mu0(x) + q * _mu1(x)
return _mu0, _mu # Closure!
def _mu1_over_mu0(self, x, sigma):
# Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x.
return exp((2 * x - 1) / (2 * sigma**2))
def _mu_over_mu0(self, x, q, sigma):
return (1 - q) + q * self._mu1_over_mu0(x, sigma)
def _compute_a_mp(self, sigma, q, alpha):
"""Compute A_alpha for arbitrary alpha by numerical integration."""
mu0, _ = self._distributions_mp(sigma, q)
a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha
a_alpha = self._integral_mp(a_alpha_fn)
return a_alpha
# TEST ROUTINES
def test_compute_rdp_no_data(self):
# q = 0
self.assertEqual(rdp_accountant.compute_rdp(0, 10, 1, 20), 0)
def test_compute_rdp_no_sampling(self):
# q = 1, RDP = alpha/2 * sigma^2
self.assertEqual(rdp_accountant.compute_rdp(1, 10, 1, 20), 0.1)
def test_compute_rdp_scalar(self):
rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5)
self.assertAlmostEqual(rdp_scalar, 0.07737, places=5)
def test_compute_rdp_sequence(self):
rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50,
[1.5, 2.5, 5, 50, 100, np.inf])
self.assertSequenceAlmostEqual(
rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf],
delta=1e-5)
params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
{'q': 1e-6, 'sigma': .1, 'order': 256},
{'q': 1e-5, 'sigma': .1, 'order': 256.1},
{'q': 1e-6, 'sigma': 1, 'order': 27},
{'q': 1e-4, 'sigma': 1., 'order': 1.5},
{'q': 1e-3, 'sigma': 1., 'order': 2},
{'q': .01, 'sigma': 10, 'order': 20},
{'q': .1, 'sigma': 100, 'order': 20.5},
{'q': .99, 'sigma': .1, 'order': 256},
{'q': .999, 'sigma': 100, 'order': 256.1})
# pylint:disable=undefined-variable
@parameterized.parameters(p for p in params)
def test_compute_log_a_equals_mp(self, q, sigma, order):
# Compare the cheap computation of log(A) with an expensive, multi-precision
# computation.
log_a = rdp_accountant._compute_log_a(q, sigma, order)
log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order))
np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
def test_get_privacy_spent_check_target_delta(self):
orders = range(2, 33)
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=1e-5)
self.assertAlmostEqual(eps, 1.258575, places=5)
self.assertEqual(opt_order, 20)
def test_get_privacy_spent_check_target_eps(self):
orders = range(2, 33)
rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders)
_, delta, opt_order = rdp_accountant.get_privacy_spent(
orders, rdp, target_eps=1.258575)
self.assertAlmostEqual(delta, 1e-5)
self.assertEqual(opt_order, 20)
def test_check_composition(self):
orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
16., 20., 24., 28., 32., 64., 256.)
rdp = rdp_accountant.compute_rdp(q=1e-4,
noise_multiplier=.4,
steps=40000,
orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
target_delta=1e-6)
rdp += rdp_accountant.compute_rdp(q=0.1,
noise_multiplier=2,
steps=100,
orders=orders)
eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp,
target_delta=1e-5)
self.assertAlmostEqual(eps, 8.509656, places=5)
self.assertEqual(opt_order, 2.5)
def test_compute_rdp_from_ledger(self):
orders = range(2, 33)
q = 0.1
n = 1000
l2_norm_clip = 3.14159
noise_stddev = 2.71828
steps = 3
query_entry = privacy_ledger.GaussianSumQueryEntry(
l2_norm_clip, noise_stddev)
ledger = [privacy_ledger.SampleEntry(n, q, [query_entry])] * steps
z = noise_stddev / l2_norm_clip
rdp = rdp_accountant.compute_rdp(q, z, steps, orders)
rdp_from_ledger = rdp_accountant.compute_rdp_from_ledger(ledger, orders)
self.assertSequenceAlmostEqual(rdp, rdp_from_ledger)
if __name__ == '__main__':
absltest.main()

View file

@ -0,0 +1,134 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A lightweight buffer for maintaining tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class TensorBuffer(object):
"""A lightweight buffer for maintaining lists.
The TensorBuffer accumulates tensors of the given shape into a tensor (whose
rank is one more than that of the given shape) via calls to `append`. The
current value of the accumulated tensor can be extracted via the property
`values`.
"""
def __init__(self, capacity, shape, dtype=tf.int32, name=None):
"""Initializes the TensorBuffer.
Args:
capacity: Initial capacity. Buffer will double in capacity each time it is
filled to capacity.
shape: The shape (as tuple or list) of the tensors to accumulate.
dtype: The type of the tensors.
name: A string name for the variable_scope used.
Raises:
ValueError: If the shape is empty (specifies scalar shape).
"""
shape = list(shape)
self._rank = len(shape)
self._name = name
self._dtype = dtype
if not self._rank:
raise ValueError('Shape cannot be scalar.')
shape = [capacity] + shape
with tf.variable_scope(self._name):
# We need to use a placeholder as the initial value to allow resizing.
self._buffer = tf.Variable(
initial_value=tf.placeholder_with_default(
tf.zeros(shape, dtype), shape=None),
trainable=False,
name='buffer',
use_resource=True)
self._current_size = tf.Variable(
initial_value=0, dtype=tf.int32, trainable=False, name='current_size')
self._capacity = tf.Variable(
initial_value=capacity,
dtype=tf.int32,
trainable=False,
name='capacity')
def append(self, value):
"""Appends a new tensor to the end of the buffer.
Args:
value: The tensor to append. Must match the shape specified in the
initializer.
Returns:
An op appending the new tensor to the end of the buffer.
"""
def _double_capacity():
"""Doubles the capacity of the current tensor buffer."""
padding = tf.zeros_like(self._buffer, self._buffer.dtype)
new_buffer = tf.concat([self._buffer, padding], axis=0)
if tf.executing_eagerly():
with tf.variable_scope(self._name, reuse=True):
self._buffer = tf.get_variable(
name='buffer',
dtype=self._dtype,
initializer=new_buffer,
trainable=False)
return self._buffer, tf.assign(self._capacity,
tf.multiply(self._capacity, 2))
else:
return tf.assign(
self._buffer, new_buffer,
validate_shape=False), tf.assign(self._capacity,
tf.multiply(self._capacity, 2))
update_buffer, update_capacity = tf.cond(
tf.equal(self._current_size, self._capacity),
_double_capacity, lambda: (self._buffer, self._capacity))
with tf.control_dependencies([update_buffer, update_capacity]):
with tf.control_dependencies([
tf.assert_less(
self._current_size,
self._capacity,
message='Appending past end of TensorBuffer.'),
tf.assert_equal(
tf.shape(value),
tf.shape(self._buffer)[1:],
message='Appending value of inconsistent shape.')
]):
with tf.control_dependencies(
[tf.assign(self._buffer[self._current_size, :], value)]):
return tf.assign_add(self._current_size, 1)
@property
def values(self):
"""Returns the accumulated tensor."""
begin_value = tf.zeros([self._rank + 1], dtype=tf.int32)
value_size = tf.concat([[self._current_size],
tf.constant(-1, tf.int32, [self._rank])], 0)
return tf.slice(self._buffer, begin_value, value_size)
@property
def current_size(self):
"""Returns the current number of tensors in the buffer."""
return self._current_size
@property
def capacity(self):
"""Returns the current capacity of the buffer."""
return self._capacity

View file

@ -0,0 +1,84 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor_buffer in eager mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
tf.enable_eager_execution()
class TensorBufferTest(tf.test.TestCase):
"""Tests for TensorBuffer in eager mode."""
def test_basic(self):
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
value1 = [[1, 2, 3], [4, 5, 6]]
my_buffer.append(value1)
self.assertAllEqual(my_buffer.values.numpy(), [value1])
value2 = [[4, 5, 6], [7, 8, 9]]
my_buffer.append(value2)
self.assertAllEqual(my_buffer.values.numpy(), [value1, value2])
def test_fail_on_scalar(self):
with self.assertRaisesRegexp(ValueError, 'Shape cannot be scalar.'):
tensor_buffer.TensorBuffer(1, ())
def test_fail_on_inconsistent_shape(self):
size, shape = 1, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
with self.assertRaisesRegexp(
tf.errors.InvalidArgumentError,
'Appending value of inconsistent shape.'):
my_buffer.append(tf.ones(shape=[3, 4], dtype=tf.int32))
def test_resize(self):
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
# Append three buffers. Third one should succeed after resizing.
value1 = [[1, 2, 3], [4, 5, 6]]
my_buffer.append(value1)
self.assertAllEqual(my_buffer.values.numpy(), [value1])
self.assertAllEqual(my_buffer.current_size.numpy(), 1)
self.assertAllEqual(my_buffer.capacity.numpy(), 2)
value2 = [[4, 5, 6], [7, 8, 9]]
my_buffer.append(value2)
self.assertAllEqual(my_buffer.values.numpy(), [value1, value2])
self.assertAllEqual(my_buffer.current_size.numpy(), 2)
self.assertAllEqual(my_buffer.capacity.numpy(), 2)
value3 = [[7, 8, 9], [10, 11, 12]]
my_buffer.append(value3)
self.assertAllEqual(my_buffer.values.numpy(), [value1, value2, value3])
self.assertAllEqual(my_buffer.current_size.numpy(), 3)
# Capacity should have doubled.
self.assertAllEqual(my_buffer.capacity.numpy(), 4)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,72 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor_buffer in graph mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
class TensorBufferTest(tf.test.TestCase):
"""Tests for TensorBuffer in graph mode."""
def test_noresize(self):
"""Test buffer does not resize if capacity is not exceeded."""
with self.cached_session() as sess:
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
value1 = [[1, 2, 3], [4, 5, 6]]
with tf.control_dependencies([my_buffer.append(value1)]):
value2 = [[7, 8, 9], [10, 11, 12]]
with tf.control_dependencies([my_buffer.append(value2)]):
values = my_buffer.values
current_size = my_buffer.current_size
capacity = my_buffer.capacity
self.evaluate(tf.global_variables_initializer())
v, cs, cap = sess.run([values, current_size, capacity])
self.assertAllEqual(v, [value1, value2])
self.assertEqual(cs, 2)
self.assertEqual(cap, 2)
def test_resize(self):
"""Test buffer resizes if capacity is exceeded."""
with self.cached_session() as sess:
size, shape = 2, [2, 3]
my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer')
value1 = [[1, 2, 3], [4, 5, 6]]
with tf.control_dependencies([my_buffer.append(value1)]):
value2 = [[7, 8, 9], [10, 11, 12]]
with tf.control_dependencies([my_buffer.append(value2)]):
value3 = [[13, 14, 15], [16, 17, 18]]
with tf.control_dependencies([my_buffer.append(value3)]):
values = my_buffer.values
current_size = my_buffer.current_size
capacity = my_buffer.capacity
self.evaluate(tf.global_variables_initializer())
v, cs, cap = sess.run([values, current_size, capacity])
self.assertAllEqual(v, [value1, value2, value3])
self.assertEqual(cs, 3)
self.assertEqual(cap, 4)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,67 @@
# BoltOn Subpackage
This package contains source code for the BoltOn method, a particular
differential-privacy (DP) technique that uses output perturbations and
leverages additional assumptions to provide a new way of approaching the
privacy guarantees.
## BoltOn Description
This method uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R, the radius of the hypothesis space,
after each batch. This value is configurable by the user.
3. Limits learning rate
4. Uses a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf
## Why BoltOn?
The major difference for the BoltOn method is that it injects noise post model
convergence, rather than noising gradients or weights during training. This
approach requires some additional constraints listed in the Description.
Should the use-case and model satisfy these constraints, this is another
approach that can be trained to maximize utility while maintaining the privacy.
The paper describes in detail the advantages and disadvantages of this approach
and its results compared to some other methods, namely noising at each iteration
and no noising.
## Tutorials
This package has a tutorial that can be found in the root tutorials directory,
under `bolton_tutorial.py`.
## Contribution
This package was initially contributed by Georgian Partners with the hope of
growing the tensorflow/privacy library. There are several rich use cases for
delta-epsilon privacy in machine learning, some of which can be explored here:
https://medium.com/apache-mxnet/epsilon-differential-privacy-for-machine-learning-using-mxnet-a4270fe3865e
https://arxiv.org/pdf/1811.04911.pdf
## Stability
As we are pegged on tensorflow2.0, this package may encounter stability
issues in the ongoing development of tensorflow2.0.
This sub-package is currently stable for 2.0.0a0, 2.0.0b0, and 2.0.0.b1 If you
would like to use this subpackage, please do use one of these versions as we
cannot guarantee it will work for all latest releases. If you do find issues,
feel free to raise an issue to the contributors listed below.
## Contacts
In addition to the maintainers of tensorflow/privacy listed in the root
README.md, please feel free to contact members of Georgian Partners. In
particular,
* Georgian Partners(@georgianpartners)
* Ji Chao Zhang(@Jichaogp)
* Christopher Choquette(@cchoquette)
## Copyright
Copyright 2019 - Google LLC

View file

@ -0,0 +1,29 @@
# Copyright 2019, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn Method for privacy."""
import sys
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
raise ImportError("Please upgrade your version "
"of tensorflow from: {0} to at least 2.0.0 to "
"use privacy/bolt_on".format(LooseVersion(tf.__version__)))
if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
pass
else:
from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top

View file

@ -0,0 +1,304 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions for BoltOn method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import losses
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.platform import tf_logging as logging
class StrongConvexMixin: # pylint: disable=old-style-class
"""Strong Convex Mixin base class.
Strong Convex Mixin base class for any loss function that will be used with
BoltOn model. Subclasses must be strongly convex and implement the
associated constants. They must also conform to the requirements of tf losses
(see super class).
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
R
"""
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def gamma(self):
"""Returns strongly convex parameter, gamma."""
raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def beta(self, class_weight):
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
raise NotImplementedError("Beta not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def lipchitz_constant(self, class_weight):
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns: L
"""
raise NotImplementedError("lipchitz constant not implemented for "
"StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return None
def max_class_weight(self, class_weight, dtype):
"""The maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
return tf.math.reduce_max(class_weight)
class StrongConvexHuber(losses.Loss, StrongConvexMixin):
"""Strong Convex version of Huber loss using l2 weight regularization."""
def __init__(self,
reg_lambda,
c_arg,
radius_constant,
delta,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""Constructor.
Args:
reg_lambda: Weight regularization constant
c_arg: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius
delta: delta value in huber loss. When to switch from quadratic to
absolute deviation.
reduction: reduction type to use. See super class
dtype: tf datatype to use for tensor conversions.
Returns:
Loss values per sample.
"""
if c_arg <= 0:
raise ValueError("c: {0}, should be >= 0".format(c_arg))
if reg_lambda <= 0:
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
if radius_constant <= 0:
raise ValueError("radius_constant: {0}, should be >= 0".format(
radius_constant
))
if delta <= 0:
raise ValueError("delta: {0}, should be >= 0".format(
delta
))
self.C = c_arg # pylint: disable=invalid-name
self.delta = delta
self.radius_constant = radius_constant
self.dtype = dtype
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexHuber, self).__init__(
name="strongconvexhuber",
reduction=reduction,
)
def call(self, y_true, y_pred):
"""Computes loss.
Args:
y_true: Ground truth values. One hot encoded using -1 and 1.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
h = self.delta
z = y_pred * y_true
one = tf.constant(1, dtype=self.dtype)
four = tf.constant(4, dtype=self.dtype)
if z > one + h: # pylint: disable=no-else-return
return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
elif tf.math.abs(one - z) <= h:
return one / (four * h) * tf.math.pow(one + h - z, 2)
return one - z
def radius(self):
"""See super class."""
return self.radius_constant / self.reg_lambda
def gamma(self):
"""See super class."""
return self.reg_lambda
def beta(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
delta = _ops.convert_to_tensor_v2(self.delta,
dtype=self.dtype
)
return self.C * max_class_weight / (delta *
tf.constant(2, dtype=self.dtype)) + \
self.reg_lambda
def lipchitz_constant(self, class_weight):
"""See super class."""
# if class_weight is provided,
# it should be a vector of the same size of number of classes
max_class_weight = self.max_class_weight(class_weight, self.dtype)
lc = self.C * max_class_weight + \
self.reg_lambda * self.radius()
return lc
def kernel_regularizer(self):
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
L2 regularization is required for this loss function to be strongly convex.
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda/2)
class StrongConvexBinaryCrossentropy(
losses.BinaryCrossentropy,
StrongConvexMixin
):
"""Strongly Convex BinaryCrossentropy loss using l2 weight regularization."""
def __init__(self,
reg_lambda,
c_arg,
radius_constant,
from_logits=True,
label_smoothing=0,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""StrongConvexBinaryCrossentropy class.
Args:
reg_lambda: Weight regularization constant
c_arg: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius
from_logits: True if the input are unscaled logits. False if they are
already scaled.
label_smoothing: amount of smoothing to perform on labels
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). Note, the
impact of this parameter's effect on privacy is not known and thus the
default should be used.
reduction: reduction type to use. See super class
dtype: tf datatype to use for tensor conversions.
"""
if label_smoothing != 0:
logging.warning("The impact of label smoothing on privacy is unknown. "
"Use label smoothing at your own risk as it may not "
"guarantee privacy.")
if reg_lambda <= 0:
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
if c_arg <= 0:
raise ValueError("c: {0}, should be >= 0".format(c_arg))
if radius_constant <= 0:
raise ValueError("radius_constant: {0}, should be >= 0".format(
radius_constant
))
self.dtype = dtype
self.C = c_arg # pylint: disable=invalid-name
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexBinaryCrossentropy, self).__init__(
reduction=reduction,
name="strongconvexbinarycrossentropy",
from_logits=from_logits,
label_smoothing=label_smoothing,
)
self.radius_constant = radius_constant
def call(self, y_true, y_pred):
"""Computes loss.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
loss = loss * self.C
return loss
def radius(self):
"""See super class."""
return self.radius_constant / self.reg_lambda
def gamma(self):
"""See super class."""
return self.reg_lambda
def beta(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda
def lipchitz_constant(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda * self.radius()
def kernel_regularizer(self):
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
L2 regularization is required for this loss function to be strongly convex.
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda/2)

View file

@ -0,0 +1,431 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from contextlib import contextmanager # pylint: disable=g-importing-member
from io import StringIO # pylint: disable=g-importing-member
import sys
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
@contextmanager
def captured_output():
"""Capture std_out and std_err within context."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err
class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin."""
@parameterized.named_parameters([
{'testcase_name': 'beta not implemented',
'fn': 'beta',
'args': [1]},
{'testcase_name': 'gamma not implemented',
'fn': 'gamma',
'args': []},
{'testcase_name': 'lipchitz not implemented',
'fn': 'lipchitz_constant',
'args': [1]},
{'testcase_name': 'radius not implemented',
'fn': 'radius',
'args': []},
])
def test_not_implemented(self, fn, args):
"""Test that the given fn's are not implemented on the mixin.
Args:
fn: fn on Mixin to test
args: arguments to fn of Mixin
"""
with self.assertRaises(NotImplementedError):
loss = StrongConvexMixin()
getattr(loss, fn, None)(*args)
@parameterized.named_parameters([
{'testcase_name': 'radius not implemented',
'fn': 'kernel_regularizer',
'args': []},
])
def test_return_none(self, fn, args):
"""Test that fn of Mixin returns None.
Args:
fn: fn of Mixin to test
args: arguments to fn of Mixin
"""
loss = StrongConvexMixin()
ret = getattr(loss, fn, None)(*args)
self.assertEqual(ret, None)
class BinaryCrossesntropyTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'C': -1,
'radius_constant': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'C': 1,
'radius_constant': -1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive',
'logits': [10000],
'y_true': [1],
'result': 0,
},
{'testcase_name': 'positive gradient negative logits',
'logits': [-10000],
'y_true': [1],
'result': 10000,
},
{'testcase_name': 'positivee gradient positive logits',
'logits': [10000],
'y_true': [0],
'result': 10000,
},
{'testcase_name': 'both negative',
'logits': [-10000],
'y_true': [0],
'result': 0
},
])
def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value.
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
loss = loss(y_true, logits)
self.assertEqual(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.constant(2, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1],
'args': [],
'result': tf.constant(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1],
'args': [1],
'result': tf.constant(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1],
'args': [],
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexBinaryCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
@parameterized.named_parameters([
{'testcase_name': 'label_smoothing',
'init_args': [1, 1, 1, True, 0.1],
'fn': None,
'args': None,
'print_res': 'The impact of label smoothing on privacy is unknown.'
},
])
def test_prints(self, init_args, fn, args, print_res):
"""Test logger warning from StrongConvexBinaryCrossentropy.
Args:
init_args: arguments to init the object with.
fn: function to test
args: arguments to above function
print_res: print result that should have been printed.
"""
with captured_output() as (out, err): # pylint: disable=unused-variable
loss = StrongConvexBinaryCrossentropy(*init_args)
if fn is not None:
getattr(loss, fn, lambda *arguments: print('error'))(*args)
self.assertRegexMatch(err.getvalue().strip(), [print_res])
class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1,
'delta': 1,
},
])
def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments.
Args:
reg_lambda: initialization value for reg_lambda arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
# test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
self.assertIsInstance(loss, StrongConvexHuber)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'c': -1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'c': 1,
'radius_constant': -1,
'delta': 1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'c': 1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative delta',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1,
'delta': -1
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
StrongConvexHuber(reg_lambda, c, radius_constant, delta)
# test the bounds and test varied delta's
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
'logits': 2.1,
'y_true': 1,
'delta': 1,
'result': 0,
},
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
'logits': 1.9,
'y_true': 1,
'delta': 1,
'result': 0.01*0.25,
},
{'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary',
'logits': 0.1,
'y_true': 1,
'delta': 1,
'result': 1.9**2 * 0.25,
},
{'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary',
'logits': -0.1,
'y_true': 1,
'delta': 1,
'result': 1.1,
},
{'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary',
'logits': 3.1,
'y_true': 1,
'delta': 2,
'result': 0,
},
{'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary',
'logits': 2.9,
'y_true': 1,
'delta': 2,
'result': 0.01*0.125,
},
{'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary',
'logits': 1.1,
'y_true': 1,
'delta': 2,
'result': 1.9**2 * 0.125,
},
{'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary',
'logits': -1.1,
'y_true': 1,
'delta': 2,
'result': 2.1,
},
{'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary',
'logits': -2.1,
'y_true': -1,
'delta': 1,
'result': 0,
},
])
def test_calculation(self, logits, y_true, delta, result):
"""Test the call method to ensure it returns the correct value.
Args:
logits: unscaled output of model
y_true: label
delta: delta value for StrongConvexHuber loss.
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexHuber(0.00001, 1, 1, delta)
loss = loss(y_true, logits)
self.assertAllClose(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.Variable(1.5, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1, 1],
'args': [],
'result': tf.Variable(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1, 1],
'args': [1],
'result': tf.Variable(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1, 1],
'args': [],
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexHuber(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,303 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn model for Bolt-on method of differentially private ML."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import Model
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn
class BoltOnModel(Model): # pylint: disable=abstract-method
"""BoltOn episilon-delta differential privacy model.
The privacy guarantees are dependent on the noise that is sampled. Please
see the paper linked below for more details.
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et al.
"""
def __init__(self,
n_outputs,
seed=1,
dtype=tf.float32):
"""Private constructor.
Args:
n_outputs: number of output classes to predict.
seed: random seed to use
dtype: data type to use for tensors
"""
super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
if n_outputs <= 0:
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
n_outputs
))
self.n_outputs = n_outputs
self.seed = seed
self._layers_instantiated = False
self._dtype = dtype
def call(self, inputs): # pylint: disable=arguments-differ
"""Forward pass of network.
Args:
inputs: inputs to neural network
Returns:
Output logits for the given inputs.
"""
return self.output_layer(inputs)
def compile(self,
optimizer,
loss,
kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in BoltOn method is SGD.
Args:
optimizer: The optimizer to use. This will be automatically wrapped
with the BoltOn Optimizer.
loss: The loss function to use. Must be a StrongConvex loss (extend the
StrongConvexMixin).
kernel_initializer: The kernel initializer to use for the single layer.
**kwargs: kwargs to keras Model.compile. See super.
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore '
'extend the StrongConvexMixin.')
if not self._layers_instantiated: # compile may be called multiple times
# for instance, if the input/outputs are not defined until fit.
self.output_layer = tf.keras.layers.Dense(
self.n_outputs,
kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_initializer(),
)
self._layers_instantiated = True
if not isinstance(optimizer, BoltOn):
optimizer = optimizers.get(optimizer)
optimizer = BoltOn(optimizer, loss)
super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
def fit(self,
x=None,
y=None,
batch_size=None,
class_weight=None,
n_samples=None,
epsilon=2,
noise_distribution='laplace',
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Reroutes to super fit with BoltOn delta-epsilon privacy requirements.
Note, inputs must be normalized s.t. ||x|| < 1.
Requirements are as follows:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
See super implementation for more details.
Args:
x: Inputs to fit on, see super.
y: Labels to fit on, see super.
batch_size: The batch size to use for training, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
n_samples: the number of individual samples in x.
epsilon: privacy parameter, which trades off between utility an privacy.
See the bolt-on paper for more description.
noise_distribution: the distribution to pull noise from.
steps_per_epoch:
**kwargs: kwargs to keras Model.fit. See super.
Returns:
Output from super fit method.
"""
if class_weight is None:
class_weight_ = self.calculate_class_weights(class_weight)
else:
class_weight_ = class_weight
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, '__len__'):
data_size = len(x)
else:
data_size = None
batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch,
x)
if batch_size_ is None:
batch_size_ = 32
# inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit()
if batch_size_ is None:
raise ValueError('batch_size: {0} is an '
'invalid value'.format(batch_size_))
if data_size is None:
raise ValueError('Could not infer the number of samples. Please pass '
'this in using n_samples.')
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight_,
data_size,
batch_size_) as _:
out = super(BoltOnModel, self).fit(x=x,
y=y,
batch_size=batch_size,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
**kwargs)
return out
def fit_generator(self,
generator,
class_weight=None,
noise_distribution='laplace',
epsilon=2,
n_samples=None,
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Fit with a generator.
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
generator: Inputs generator following Tensorflow guidelines, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
BoltOn paper for more description.
n_samples: number of individual samples in x
steps_per_epoch: Number of steps per training epoch, see super.
**kwargs: **kwargs
Returns:
Output from super fit_generator method.
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
if n_samples is not None:
data_size = n_samples
elif hasattr(generator, 'shape'):
data_size = generator.shape[0]
elif hasattr(generator, '__len__'):
data_size = len(generator)
else:
raise ValueError('The number of samples could not be determined. '
'Please make sure that if you are using a generator'
'to call this method directly with n_samples kwarg '
'passed.')
batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch,
generator)
if batch_size is None:
batch_size = 32
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight,
data_size,
batch_size) as _:
out = super(BoltOnModel, self).fit_generator(
generator,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
**kwargs)
return out
def calculate_class_weights(self,
class_weights=None,
class_counts=None,
num_classes=None):
"""Calculates class weighting to be used in training.
Args:
class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then an array of
the number of samples for each class
num_classes: If class_weights is not None, then the number of
classes.
Returns:
class_weights as 1D tensor, to be passed to model's fit method.
"""
# Value checking
class_keys = ['balanced']
is_string = False
if isinstance(class_weights, str):
is_string = True
if class_weights not in class_keys:
raise ValueError('Detected string class_weights with '
'value: {0}, which is not one of {1}.'
'Please select a valid class_weight type'
'or pass an array'.format(class_weights,
class_keys))
if class_counts is None:
raise ValueError('Class counts must be provided if using '
'class_weights=%s' % class_weights)
class_counts_shape = tf.Variable(class_counts,
trainable=False,
dtype=self._dtype).shape
if len(class_counts_shape) != 1:
raise ValueError('class counts must be a 1D array.'
'Detected: {0}'.format(class_counts_shape))
if num_classes is None:
raise ValueError('num_classes must be provided if using '
'class_weights=%s' % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError('You must pass a value for num_classes if '
'creating an array of class_weights')
# performing class weight calculation
if class_weights is None:
class_weights = 1
elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts)
weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
class_counts),
self._dtype)
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
tf.Variable(weighted_counts, dtype=self._dtype)
else:
class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1:
raise ValueError('Detected class_weights shape: {0} instead of '
'1D array'.format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
'Detected array length: {0} instead of: {1}'.format(
class_weights.shape[0],
num_classes))
return class_weights

View file

@ -0,0 +1,548 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from tensorflow_privacy.privacy.bolt_on import models
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = c_arg # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self):
"""Returns strongly convex parameter, gamma."""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns:
L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight):
"""the maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None:
return 1
raise ValueError('')
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing BoltOn model."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
def compute_gradients(self):
return 0
def get_config(self):
return {}
def _create_slots(self, var):
pass
def _resource_apply_dense(self, grad, handle):
return grad
def _resource_apply_sparse(self, grad, handle, indices):
return grad
class InitTests(keras_parameterized.TestCase):
"""Tests for keras model initialization."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'n_outputs': 1,
},
{'testcase_name': 'many outputs',
'n_outputs': 100,
},
])
def test_init_params(self, n_outputs):
"""Test initialization of BoltOnModel.
Args:
n_outputs: number of output neurons
"""
# test valid domains for each variable
clf = models.BoltOnModel(n_outputs)
self.assertIsInstance(clf, models.BoltOnModel)
@parameterized.named_parameters([
{'testcase_name': 'invalid n_outputs',
'n_outputs': -1,
},
])
def test_bad_init_params(self, n_outputs):
"""test bad initializations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
"""
# test invalid domains for each variable, especially noise
with self.assertRaises(ValueError):
models.BoltOnModel(n_outputs)
@parameterized.named_parameters([
{'testcase_name': 'string compile',
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'adam',
},
{'testcase_name': 'test compile',
'n_outputs': 100,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
},
])
def test_compile(self, n_outputs, loss, optimizer):
"""Test compilation of BoltOnModel.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instantiated TestOptimizer instance
"""
# test compilation of valid tf.optimizer and tf.loss
with self.cached_session():
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([
{'testcase_name': 'Not strong loss',
'n_outputs': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
])
def test_bad_compile(self, n_outputs, loss, optimizer):
"""test bad compilations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instantiated TestOptimizer instance
"""
# test compilaton of invalid tf.optimizer and non instantiated loss.
with self.cached_session():
with self.assertRaises((ValueError, AttributeError)):
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
"""Creates a categorically encoded dataset.
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
Will have evenly split samples across each output class.
Each output class will be a different point in the input space.
Args:
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
batch_size: The desired batch_size
generator: False for array, True for generator
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
"""
x_stack = []
y_stack = []
for i_class in range(n_classes):
x_stack.append(
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
)
y_stack.append(
tf.constant(i_class, tf.float32, (n_samples, n_classes))
)
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
if generator:
dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set)
)
dataset = dataset.batch(batch_size=batch_size)
return dataset
return x_set, y_set
def _do_fit(n_samples,
input_dim,
n_outputs,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
distribution='laplace'):
"""Instantiate necessary components for fitting and perform a model fit.
Args:
n_samples: number of samples in dataset
input_dim: the sample dimensionality
n_outputs: number of output neurons
epsilon: privacy parameter
generator: True to create a generator, False to use an iterator
batch_size: batch_size to use
reset_n_samples: True to set _samples to None prior to fitting.
False does nothing
optimizer: instance of TestOptimizer
loss: instance of TestLoss
distribution: distribution to get noise from.
Returns:
BoltOnModel instsance
"""
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
n_samples,
input_dim,
n_outputs,
batch_size,
generator=generator
)
y = None
# x = x.batch(batch_size)
x = x.shuffle(n_samples//2)
batch_size = None
if reset_n_samples:
n_samples = None
clf.fit_generator(x,
n_samples=n_samples,
noise_distribution=distribution,
epsilon=epsilon)
else:
x, y = _cat_dataset(
n_samples,
input_dim,
n_outputs,
batch_size,
generator=generator)
if reset_n_samples:
n_samples = None
clf.fit(x,
y,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=distribution,
epsilon=epsilon)
return clf
class FitTests(keras_parameterized.TestCase):
"""Test cases for keras model fitting."""
# @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
},
])
def test_fit(self, generator, reset_n_samples):
"""Tests fitting of BoltOnModel.
Args:
generator: True for generator test, False for iterator test.
reset_n_samples: True to reset the n_samples to None, False does nothing
"""
loss = TestLoss(1, 1, 1)
optimizer = BoltOn(TestOptimizer(), loss)
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
clf = _do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'generator fit',
'generator': True,
},
])
def test_fit_gen(self, generator):
"""Tests the fit_generator method of BoltOnModel.
Args:
generator: True to test with a generator dataset
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
batch_size = 1
n_samples = 10
clf = models.BoltOnModel(n_classes)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
batch_size,
generator=generator
)
x = x.batch(batch_size)
x = x.shuffle(n_samples // 2)
clf.fit_generator(x, n_samples=n_samples)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'iterator no n_samples',
'generator': True,
'reset_n_samples': True,
'distribution': 'laplace'
},
{'testcase_name': 'invalid distribution',
'generator': True,
'reset_n_samples': True,
'distribution': 'not_valid'
},
])
def test_bad_fit(self, generator, reset_n_samples, distribution):
"""Tests fitting with invalid parameters, which should raise an error.
Args:
generator: True to test with generator, False is iterator
reset_n_samples: True to reset the n_samples param to None prior to
passing it to fit
distribution: distribution to get noise from.
"""
with self.assertRaises(ValueError):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
_do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
distribution
)
@parameterized.named_parameters([
{'testcase_name': 'None class_weights',
'class_weights': None,
'class_counts': None,
'num_classes': None,
'result': 1},
{'testcase_name': 'class weights array',
'class_weights': [1, 1],
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'class weights balanced',
'class_weights': 'balanced',
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
])
def test_class_calculate(self,
class_weights,
class_counts,
num_classes,
result):
"""Tests the BOltonModel calculate_class_weights method.
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
result: expected result
"""
clf = models.BoltOnModel(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes)
if hasattr(expected, 'numpy'):
expected = expected.numpy()
self.assertAllEqual(
expected,
result
)
@parameterized.named_parameters([
{'testcase_name': 'class weight not valid str',
'class_weights': 'not_valid',
'class_counts': 1,
'num_classes': 1,
'err_msg': 'Detected string class_weights with value: not_valid'},
{'testcase_name': 'no class counts',
'class_weights': 'balanced',
'class_counts': None,
'num_classes': 1,
'err_msg': 'Class counts must be provided if '
'using class_weights=balanced'},
{'testcase_name': 'no num classes',
'class_weights': 'balanced',
'class_counts': [1],
'num_classes': None,
'err_msg': 'num_classes must be provided if '
'using class_weights=balanced'},
{'testcase_name': 'class counts not array',
'class_weights': 'balanced',
'class_counts': 1,
'num_classes': None,
'err_msg': 'class counts must be a 1D array.'},
{'testcase_name': 'class counts array, no num classes',
'class_weights': [1],
'class_counts': None,
'num_classes': None,
'err_msg': 'You must pass a value for num_classes if '
'creating an array of class_weights'},
{'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]],
'class_counts': None,
'num_classes': 2,
'err_msg': 'Detected class_weights shape'},
{'testcase_name': 'class counts array, wrong number classes',
'class_weights': [1, 1, 1],
'class_counts': None,
'num_classes': 2,
'err_msg': 'Detected array length:'},
])
def test_class_errors(self,
class_weights,
class_counts,
num_classes,
err_msg):
"""Tests the BOltonModel calculate_class_weights method.
This test passes invalid params which should raise the expected errors.
Args:
class_weights: the class_weights to use.
class_counts: count of number of samples for each class.
num_classes: number of outputs neurons.
err_msg: The expected error message.
"""
clf = models.BoltOnModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights,
class_counts,
num_classes)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,388 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn Optimizer for Bolt-on method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import math_ops
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
_accepted_distributions = ['laplace'] # implemented distributions for noising
class GammaBetaDecreasingStep(
optimizer_v2.learning_rate_schedule.LearningRateSchedule):
"""Computes LR as minimum of 1/beta and 1/(gamma * step) at each step.
This is a required step for privacy guarantees.
"""
def __init__(self):
self.is_init = False
self.beta = None
self.gamma = None
def __call__(self, step):
"""Computes and returns the learning rate.
Args:
step: the current iteration number
Returns:
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
the BoltOn privacy requirements.
"""
if not self.is_init:
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
'This is performed automatically by using the '
'{1} as a context manager, '
'as desired'.format(self.__class__.__name__,
BoltOn.__class__.__name__
)
)
dtype = self.beta.dtype
one = tf.constant(1, dtype)
return tf.math.minimum(tf.math.reduce_min(one/self.beta),
one/(self.gamma*math_ops.cast(step, dtype))
)
def get_config(self):
"""Return config to setup the learning rate scheduler."""
return {'beta': self.beta, 'gamma': self.gamma}
def initialize(self, beta, gamma):
"""Setups scheduler with beta and gamma values from the loss function.
Meant to be used with .fit as the loss params may depend on values passed to
fit.
Args:
beta: Smoothness value. See StrongConvexMixin
gamma: Strong Convexity parameter. See StrongConvexMixin.
"""
self.is_init = True
self.beta = beta
self.gamma = gamma
def de_initialize(self):
"""De initialize post fit, as another fit call may use other parameters."""
self.is_init = False
self.beta = None
self.gamma = None
class BoltOn(optimizer_v2.OptimizerV2):
"""Wrap another tf optimizer with BoltOn privacy protocol.
BoltOn optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "BoltOn" enables the bolt-on model to control the learning rate
based on the strongly convex loss.
To use the BoltOn method, you must:
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
2. use it as a context manager around your .fit method internals.
This can be accomplished by the following:
optimizer = tf.optimizers.SGD()
loss = privacy.bolt_on.losses.StrongConvexBinaryCrossentropy()
bolton = BoltOn(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
method.
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, # pylint: disable=super-init-not-called
optimizer,
loss,
dtype=tf.float32,
):
"""Constructor.
Args:
optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped).
loss: StrongConvexLoss function that the model is being compiled with.
dtype: dtype
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore '
'extend the StrongConvexMixin.')
self._private_attributes = [
'_internal_optimizer',
'dtype',
'noise_distribution',
'epsilon',
'loss',
'class_weights',
'input_dim',
'n_samples',
'layers',
'batch_size',
'_is_init',
]
self._internal_optimizer = optimizer
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
# rate scheduler, as required for privacy guarantees. This will still need
# to get values from the loss function near the time that .fit is called
# on the model (when this optimizer will be called as a context manager)
self.dtype = dtype
self.loss = loss
self._is_init = False
def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.get_config()
def project_weights_to_r(self, force=False):
"""Normalize the weights to the R-ball.
Args:
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Raises:
Exception: If not called from inside this optimizer context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
'context.')
radius = self.loss.radius()
for layer in self.layers:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / radius)
else:
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0,
lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop
lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop
)
def get_noise(self, input_dim, output_dim):
"""Sample noise to be added to weights for privacy guarantee.
Args:
input_dim: the input dimensionality for the weights
output_dim: the output dimensionality for the weights
Returns:
Noise in shape of layer's weights to be added to the weights.
Raises:
Exception: If not called from inside this optimizer's context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
'context.')
loss = self.loss
distribution = self.noise_distribution.lower()
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (output_dim)
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weights)) / \
(loss.gamma() * self.n_samples * self.batch_size)
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
mean=0,
seed=1,
stddev=1.0,
dtype=self.dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim
gamma = tf.random.gamma([output_dim],
alpha,
beta=1 / beta,
seed=1,
dtype=self.dtype
)
return unit_vector * gamma
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.from_config(*args, **kwargs)
def __getattr__(self, name):
"""Get attr.
return _internal_optimizer off self instance, and everything else
from the _internal_optimizer instance.
Args:
name: Name of attribute to get from this or aggregate optimizer.
Returns:
attribute from BoltOn if specified to come from self, else
from _internal_optimizer.
"""
if name == '_private_attributes' or name in self._private_attributes:
return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer')
try:
return object.__getattribute__(optim, name)
except AttributeError:
raise AttributeError(
"Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name)
)
def __setattr__(self, key, value):
"""Set attribute to self instance if its the internal optimizer.
Reroute everything else to the _internal_optimizer.
Args:
key: attribute name
value: attribute value
"""
if key == '_private_attributes':
object.__setattr__(self, key, value)
elif key in self._private_attributes:
object.__setattr__(self, key, value)
else:
setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.get_updates(loss, params)
self.project_weights_to_r()
return out
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.project_weights_to_r()
return out
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.minimize(*args, **kwargs)
self.project_weights_to_r()
return out
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self):
"""Context manager call at the beginning of with statement.
Returns:
self, to be used in context manager
"""
self._is_init = True
return self
def __call__(self,
noise_distribution,
epsilon,
layers,
class_weights,
n_samples,
batch_size):
"""Accepts required values for bolton method from context entry point.
Stores them on the optimizer for use throughout fitting.
Args:
noise_distribution: the noise distribution to pick.
see _accepted_distributions and get_noise for possible values.
epsilon: privacy parameter. Lower gives more privacy but less utility.
layers: list of Keras/Tensorflow layers. Can be found as model.layers
class_weights: class_weights used, which may either be a scalar or 1D
tensor with dim == n_classes.
n_samples: number of rows/individual samples in the training set
batch_size: batch size used.
Returns:
self, to be used by the __enter__ method for context.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.learning_rate.initialize(self.loss.beta(class_weights),
self.loss.gamma())
self.epsilon = tf.constant(epsilon, dtype=self.dtype)
self.class_weights = tf.constant(class_weights, dtype=self.dtype)
self.n_samples = tf.constant(n_samples, dtype=self.dtype)
self.layers = layers
self.batch_size = tf.constant(batch_size, dtype=self.dtype)
return self
def __exit__(self, *args):
"""Exit call from with statement.
Used to:
1.reset the model and fit parameters passed to the optimizer
to enable the BoltOn Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
Args:
*args: encompasses the type, value, and traceback values which are unused.
"""
self.project_weights_to_r(True)
for layer in self.layers:
input_dim = layer.kernel.shape[0]
output_dim = layer.units
noise = self.get_noise(input_dim,
output_dim,
)
layer.kernel = tf.math.add(layer.kernel, noise)
self.noise_distribution = None
self.learning_rate.de_initialize()
self.epsilon = -1
self.batch_size = -1
self.class_weights = None
self.n_samples = None
self.input_dim = None
self.layers = None
self._is_init = False

View file

@ -0,0 +1,579 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python import ops as _ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.initializers import constant
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.platform import test
from tensorflow_privacy.privacy.bolt_on import optimizers as opt
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin
class TestModel(Model): # pylint: disable=abstract-method
"""BoltOn episilon-delta model.
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
"""Constructor.
Args:
n_outputs: number of output neurons
input_shape:
init_value:
"""
super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_outputs = n_outputs
self.layer_input_shape = input_shape
self.output_layer = tf.keras.layers.Dense(
self.n_outputs,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer=constant(init_value),
)
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = c_arg # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
a tensor
"""
return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)
def gamma(self):
"""Returns strongly convex parameter, gamma."""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns:
constant L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight, dtype=tf.float32):
"""the maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None:
return 1
raise NotImplementedError('')
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the BoltOn optimizer."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
self.not_private = 'test'
self.iterations = tf.constant(1, dtype=tf.float32)
self._iterations = tf.constant(1, dtype=tf.float32)
def _compute_gradients(self, loss, var_list, grad_loss=None):
return 'test'
def get_config(self):
return 'test'
def from_config(self, config, custom_objects=None):
return 'test'
def _create_slots(self):
return 'test'
def _resource_apply_dense(self, grad, handle):
return 'test'
def _resource_apply_sparse(self, grad, handle, indices):
return 'test'
def get_updates(self, loss, params):
return 'test'
def apply_gradients(self, grads_and_vars, name=None):
return 'test'
def minimize(self, loss, var_list, grad_loss=None, name=None):
return 'test'
def get_gradients(self, loss, params):
return 'test'
def limit_learning_rate(self):
return 'test'
class BoltonOptimizerTest(keras_parameterized.TestCase):
"""BoltOn Optimizer tests."""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'getattr',
'fn': '__getattr__',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
{'testcase_name': 'project_weights_to_r',
'fn': 'project_weights_to_r',
'args': ['dtype'],
'result': None,
'test_attr': ''},
])
def test_fn(self, fn, args, result, test_attr):
"""test that a fn of BoltOn optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of BoltOn to check against result with.
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
res = getattr(bolton, fn, None)(*args)
if test_attr is not None:
res = getattr(bolton, test_attr, None)
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
res = res.numpy()
result = result.numpy()
self.assertEqual(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': '1 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
{'testcase_name': '2 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (2,),
'n_out': 1,
'result': [[0.707107], [0.707107]]},
{'testcase_name': '1 value project to r=2',
'r': 2,
'init_value': 3,
'shape': (1,),
'n_out': 1,
'result': [[2]]},
{'testcase_name': 'no project',
'r': 2,
'init_value': 1,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
])
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of BoltOn optimizer is working as expected.
Args:
r: Radius value for StrongConvex loss function.
shape: input_dimensionality
n_out: output dimensionality
init_value: the initial value for 'constant' kernel initializer
result: the expected output after projection.
"""
tf.random.set_seed(1)
def project_fn(r):
loss = TestLoss(1, 1, r)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(n_out, shape, init_value)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
bolton.project_weights_to_r()
return _ops.convert_to_tensor_v2(bolton.layers[0].kernel, tf.float32)
res = project_fn(r)
self.assertAllClose(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'normal values',
'epsilon': 2,
'noise': 'laplace',
'class_weights': 1},
])
def test_context_manager(self, noise, epsilon, class_weights):
"""Tests the context manager functionality of the optimizer.
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
class_weights: class_weights to use
"""
@tf.function
def test_run():
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, class_weights, 1, 1) as _:
pass
return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32)
epsilon = test_run()
self.assertEqual(epsilon.numpy(), -1)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'epsilon': 1,
'noise': 'not_valid',
'err_msg': 'Detected noise distribution: not_valid not one of:'},
{'testcase_name': 'invalid epsilon',
'epsilon': -1,
'noise': 'laplace',
'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
])
def test_context_domains(self, noise, epsilon, err_msg):
"""Tests the context domains.
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
err_msg: the expected error message
"""
@tf.function
def test_run(noise, epsilon):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, 1, 1, 1) as _:
pass
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
test_run(noise, epsilon)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1],
'err_msg': 'This method must be called from within the '
'optimizer\'s context'},
])
def test_not_in_context(self, fn, args, err_msg):
"""Tests that the expected functions raise errors when not in context.
Args:
fn: the function to test
args: the arguments for said function
err_msg: expected error message
"""
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
getattr(bolton, fn)(*args)
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
test_run(fn, args)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_updates',
'fn': 'get_updates',
'args': [0, 0]},
{'testcase_name': 'fn: get_config',
'fn': 'get_config',
'args': []},
{'testcase_name': 'fn: from_config',
'fn': 'from_config',
'args': [0]},
{'testcase_name': 'fn: _resource_apply_dense',
'fn': '_resource_apply_dense',
'args': [1, 1]},
{'testcase_name': 'fn: _resource_apply_sparse',
'fn': '_resource_apply_sparse',
'args': [1, 1, 1]},
{'testcase_name': 'fn: apply_gradients',
'fn': 'apply_gradients',
'args': [1]},
{'testcase_name': 'fn: minimize',
'fn': 'minimize',
'args': [1, 1]},
{'testcase_name': 'fn: _compute_gradients',
'fn': '_compute_gradients',
'args': [1, 1]},
{'testcase_name': 'fn: get_gradients',
'fn': 'get_gradients',
'args': [1, 1]},
])
def test_rerouted_function(self, fn, args):
"""Tests rerouted function.
Tests that a method of the internal optimizer is correctly routed from
the BoltOn instance to the internal optimizer instance (TestOptimizer,
here).
Args:
fn: fn to test
args: arguments to that fn
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
bolton = opt.BoltOn(optimizer, loss)
model = TestModel(3)
model.compile(optimizer, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
self.assertEqual(
getattr(bolton, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([
{'testcase_name': 'fn: project_weights_to_r',
'fn': 'project_weights_to_r',
'args': []},
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1]},
])
def test_not_reroute_fn(self, fn, args):
"""Test function is not rerouted.
Test that a fn that should not be rerouted to the internal optimizer is
in fact not rerouted.
Args:
fn: fn to test
args: arguments to that fn
"""
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.noise_distribution = 'laplace'
bolton.epsilon = 1
bolton.layers = model.layers
bolton.class_weights = 1
bolton.n_samples = 1
bolton.batch_size = 1
bolton.n_outputs = 1
res = getattr(bolton, fn, lambda: 'test')(*args)
if res != 'test':
res = 1
else:
res = 0
return _ops.convert_to_tensor_v2(res, dtype=tf.float32)
self.assertNotEqual(test_run(fn, args), 0)
@parameterized.named_parameters([
{'testcase_name': 'attr: _iterations',
'attr': '_iterations'}
])
def test_reroute_attr(self, attr):
"""Test a function is rerouted.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer.
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.BoltOn(internal_optimizer, loss)
self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr))
@parameterized.named_parameters([
{'testcase_name': 'attr does not exist',
'attr': '_not_valid'}
])
def test_attribute_error(self, attr):
"""Test rerouting of attributes.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.BoltOn(internal_optimizer, loss)
with self.assertRaises(AttributeError):
getattr(optimizer, attr)
class SchedulerTest(keras_parameterized.TestCase):
"""GammaBeta Scheduler tests."""
@parameterized.named_parameters([
{'testcase_name': 'not in context',
'err_msg': 'Please initialize the GammaBetaDecreasingStep Learning Rate'
' Scheduler'
}
])
def test_bad_call(self, err_msg):
"""Test attribute of internal opt correctly rerouted to the internal opt.
Args:
err_msg: The expected error message from the scheduler bad call.
"""
scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
scheduler(1)
@parameterized.named_parameters([
{'testcase_name': 'step 1',
'step': 1,
'res': 0.5},
{'testcase_name': 'step 2',
'step': 2,
'res': 0.5},
{'testcase_name': 'step 3',
'step': 3,
'res': 0.333333333},
])
def test_call(self, step, res):
"""Test call.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer
Args:
step: step number to 'GammaBetaDecreasingStep' 'Scheduler'.
res: expected result from call to 'GammaBetaDecreasingStep' 'Scheduler'.
"""
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
scheduler = opt.GammaBetaDecreasingStep()
scheduler.initialize(beta, gamma)
step = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
lr = scheduler(step)
self.assertAllClose(lr.numpy(), res)
if __name__ == '__main__':
test.main()
unittest.main()

View file

@ -0,0 +1,140 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "dp_query",
srcs = ["dp_query.py"],
deps = [
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_library(
name = "gaussian_query",
srcs = ["gaussian_query.py"],
deps = [
":dp_query",
":normalized_query",
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_test(
name = "gaussian_query_test",
size = "small",
srcs = ["gaussian_query_test.py"],
python_version = "PY2",
deps = [
":gaussian_query",
":test_utils",
"//third_party/py/absl/testing:parameterized",
"//third_party/py/numpy",
"//third_party/py/six",
"//third_party/py/tensorflow",
],
)
py_library(
name = "no_privacy_query",
srcs = ["no_privacy_query.py"],
deps = [
":dp_query",
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_test(
name = "no_privacy_query_test",
size = "small",
srcs = ["no_privacy_query_test.py"],
python_version = "PY2",
deps = [
":no_privacy_query",
":test_utils",
"//third_party/py/absl/testing:parameterized",
"//third_party/py/tensorflow",
],
)
py_library(
name = "normalized_query",
srcs = ["normalized_query.py"],
deps = [
":dp_query",
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_test(
name = "normalized_query_test",
size = "small",
srcs = ["normalized_query_test.py"],
python_version = "PY2",
deps = [
":gaussian_query",
":normalized_query",
":test_utils",
"//third_party/py/tensorflow",
],
)
py_library(
name = "nested_query",
srcs = ["nested_query.py"],
deps = [
":dp_query",
"//third_party/py/distutils",
"//third_party/py/tensorflow",
],
)
py_test(
name = "nested_query_test",
size = "small",
srcs = ["nested_query_test.py"],
python_version = "PY2",
deps = [
":gaussian_query",
":nested_query",
":test_utils",
"//third_party/py/absl/testing:parameterized",
"//third_party/py/distutils",
"//third_party/py/numpy",
"//third_party/py/tensorflow",
],
)
py_library(
name = "quantile_adaptive_clip_sum_query",
srcs = ["quantile_adaptive_clip_sum_query.py"],
deps = [
":dp_query",
":gaussian_query",
":normalized_query",
"//third_party/py/tensorflow",
],
)
py_test(
name = "quantile_adaptive_clip_sum_query_test",
srcs = ["quantile_adaptive_clip_sum_query_test.py"],
python_version = "PY2",
deps = [
":quantile_adaptive_clip_sum_query",
":test_utils",
"//third_party/py/numpy",
"//third_party/py/tensorflow",
"//third_party/py/tensorflow_privacy/privacy/analysis:privacy_ledger",
],
)
py_library(
name = "test_utils",
srcs = ["test_utils.py"],
deps = [],
)

View file

@ -0,0 +1,225 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An interface for differentially private query mechanisms.
The DPQuery class abstracts the differential privacy mechanism needed by DP-SGD.
The nomenclature is not specific to machine learning, but rather comes from
the differential privacy literature. Therefore, instead of talking about
examples, minibatches, and gradients, the code talks about records, samples and
queries. For more detail, please see the paper here:
https://arxiv.org/pdf/1812.06210.pdf
A common usage paradigm for this class is centralized DP-SGD training on a
fixed set of training examples, which we call "standard DP-SGD training."
In such training, SGD applies as usual by computing gradient updates from a set
of training examples that form a minibatch. However, each minibatch is broken
up into disjoint "microbatches." The gradient of each microbatch is computed
and clipped to a maximum norm, with the "records" for all such clipped gradients
forming a "sample" that constitutes the entire minibatch. Subsequently, that
sample can be "queried" to get an averaged, noised gradient update that can be
applied to model parameters.
In order to prevent inaccurate accounting of privacy parameters, the only
means of inspecting the gradients and updates of SGD training is via the use
of the below interfaces, and through the accumulation and querying of a
"sample state" abstraction. Thus, accessing data is indirect on purpose.
The DPQuery class also allows the use of a global state that may change between
samples. In the common situation where the privacy mechanism remains unchanged
throughout the entire training process, the global state is usually None.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class DPQuery(object):
"""Interface for differentially private query mechanisms."""
__metaclass__ = abc.ABCMeta
def set_ledger(self, ledger):
"""Supplies privacy ledger to which the query can record privacy events.
Args:
ledger: A `PrivacyLedger`.
"""
del ledger
raise TypeError(
'DPQuery type %s does not support set_ledger.' % type(self).__name__)
def initial_global_state(self):
"""Returns the initial global state for the DPQuery."""
return ()
def derive_sample_params(self, global_state):
"""Given the global state, derives parameters to use for the next sample.
Args:
global_state: The current global state.
Returns:
Parameters to use to process records in the next sample.
"""
del global_state # unused.
return ()
@abc.abstractmethod
def initial_sample_state(self, template):
"""Returns an initial state to use for the next sample.
Args:
template: A nested structure of tensors, TensorSpecs, or numpy arrays used
as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has
properties `shape` and `dtype`.
Returns: An initial sample state.
"""
pass
def preprocess_record(self, params, record):
"""Preprocesses a single record.
This preprocessing is applied to one client's record, e.g. selecting vectors
and clipping them to a fixed L2 norm. This method can be executed in a
separate TF session, or even on a different machine, so it should not depend
on any TF inputs other than those provided as input arguments. In
particular, implementations should avoid accessing any TF tensors or
variables that are stored in self.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
Returns:
A structure of tensors to be aggregated.
"""
del params # unused.
return record
@abc.abstractmethod
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""Accumulates a single preprocessed record into the sample state.
This method is intended to only do simple aggregation, typically just a sum.
In the future, we might remove this method and replace it with a way to
declaratively specify the type of aggregation required.
Args:
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
preprocessed_record: The preprocessed record to accumulate.
Returns:
The updated sample state.
"""
pass
def accumulate_record(self, params, sample_state, record):
"""Accumulates a single record into the sample state.
This is a helper method that simply delegates to `preprocess_record` and
`accumulate_preprocessed_record` for the common case when both of those
functions run on a single device.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
record: The record to accumulate. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
Returns:
The updated sample state. In standard DP-SGD training, the set of
previous mcrobatch gradients with the addition of the record argument.
"""
preprocessed_record = self.preprocess_record(params, record)
return self.accumulate_preprocessed_record(
sample_state, preprocessed_record)
@abc.abstractmethod
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Merges two sample states into a single state.
Args:
sample_state_1: The first sample state to merge.
sample_state_2: The second sample state to merge.
Returns:
The merged sample state.
"""
pass
@abc.abstractmethod
def get_noised_result(self, sample_state, global_state):
"""Gets query result after all records of sample have been accumulated.
Args:
sample_state: The sample state after all records have been accumulated.
In standard DP-SGD training, the accumulated sum of clipped microbatch
gradients (in the special case of microbatches of size 1, the clipped
per-example gradients).
global_state: The global state, storing long-term privacy bookkeeping.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state. In standard
DP-SGD training, the result is a gradient update comprising a noised
average of the clipped gradients in the sample state---with the noise and
averaging performed in a manner that guarantees differential privacy.
"""
pass
def zeros_like(arg):
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
try:
arg = tf.convert_to_tensor(arg)
except TypeError:
pass
return tf.zeros(arg.shape, arg.dtype)
class SumAggregationDPQuery(DPQuery):
"""Base class for DPQueries that aggregate via sum."""
def initial_sample_state(self, template):
return nest.map_structure(zeros_like, template)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
return nest.map_structure(tf.add, sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
return nest.map_structure(tf.add, sample_state_1, sample_state_2)

View file

@ -0,0 +1,145 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for Gaussian average queries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import normalized_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for Gaussian sum queries.
Accumulates clipped vectors, then adds Gaussian noise to the sum.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['l2_norm_clip', 'stddev'])
def __init__(self, l2_norm_clip, stddev):
"""Initializes the GaussianSumQuery.
Args:
l2_norm_clip: The clipping norm to apply to the global norm of each
record.
stddev: The stddev of the noise added to the sum.
"""
self._l2_norm_clip = l2_norm_clip
self._stddev = stddev
self._ledger = None
def set_ledger(self, ledger):
self._ledger = ledger
def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters."""
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
tf.cast(stddev, tf.float32))
def initial_global_state(self):
return self.make_global_state(self._l2_norm_clip, self._stddev)
def derive_sample_params(self, global_state):
return global_state.l2_norm_clip
def initial_sample_state(self, template):
return nest.map_structure(
dp_query.zeros_like, template)
def preprocess_record_impl(self, params, record):
"""Clips the l2 norm, returning the clipped record and the l2 norm.
Args:
params: The parameters for the sample.
record: The record to be processed.
Returns:
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
the structure of preprocessed tensors, and l2_norm is the total l2 norm
before clipping.
"""
l2_norm_clip = params
record_as_list = nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
return nest.pack_sequence_as(record, clipped_as_list), norm
def preprocess_record(self, params, record):
preprocessed_record, _ = self.preprocess_record_impl(params, record)
return preprocessed_record
def get_noised_result(self, sample_state, global_state):
"""See base class."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
def add_noise(v):
return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
else:
random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
def add_noise(v):
return v + random_normal(tf.shape(v))
if self._ledger:
dependencies = [
self._ledger.record_sum_query(
global_state.l2_norm_clip, global_state.stddev)
]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return nest.map_structure(add_noise, sample_state), global_state
class GaussianAverageQuery(normalized_query.NormalizedQuery):
"""Implements DPQuery interface for Gaussian average queries.
Accumulates clipped vectors, adds Gaussian noise, and normalizes.
Note that we use "fixed-denominator" estimation: the denominator should be
specified as the expected number of records per sample. Accumulating the
denominator separately would also be possible but would be produce a higher
variance estimator.
"""
def __init__(self,
l2_norm_clip,
sum_stddev,
denominator):
"""Initializes the GaussianAverageQuery.
Args:
l2_norm_clip: The clipping norm to apply to the global norm of each
record.
sum_stddev: The stddev of the noise added to the sum (before
normalization).
denominator: The normalization constant (applied after noise is added to
the sum).
"""
super(GaussianAverageQuery, self).__init__(
numerator_query=GaussianSumQuery(l2_norm_clip, sum_stddev),
denominator=denominator)

View file

@ -0,0 +1,161 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GaussianAverageQuery."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from six.moves import xrange
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import test_utils
class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_gaussian_sum_no_clip_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_gaussian_sum_with_clip_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped.
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=5.0, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_gaussian_sum_with_changing_clip_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped.
l2_norm_clip = tf.Variable(5.0)
l2_norm_clip_placeholder = tf.placeholder(tf.float32)
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=l2_norm_clip, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
self.evaluate(tf.global_variables_initializer())
result = sess.run(query_result)
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0})
result = sess.run(query_result)
expected = [0.0, 0.0]
self.assertAllClose(result, expected)
def test_gaussian_sum_with_noise(self):
with self.cached_session() as sess:
record1, record2 = 2.71828, 3.14159
stddev = 1.0
query = gaussian_query.GaussianSumQuery(
l2_norm_clip=5.0, stddev=stddev)
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_sums = []
for _ in xrange(1000):
noised_sums.append(sess.run(query_result))
result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, stddev, 0.1)
def test_gaussian_sum_merge(self):
records1 = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]
records2 = [tf.constant([3.0, 5.0]), tf.constant([-1.0, 4.0])]
def get_sample_state(records):
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(records[0])
for record in records:
sample_state = query.accumulate_record(params, sample_state, record)
return sample_state
sample_state_1 = get_sample_state(records1)
sample_state_2 = get_sample_state(records2)
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
sample_state_1,
sample_state_2)
with self.cached_session() as sess:
result = sess.run(merged)
expected = [3.0, 10.0]
self.assertAllClose(result, expected)
def test_gaussian_average_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected_average = [1.0, 1.0]
self.assertAllClose(result, expected_average)
def test_gaussian_average_with_noise(self):
with self.cached_session() as sess:
record1, record2 = 2.71828, 3.14159
sum_stddev = 1.0
denominator = 2.0
query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages = []
for _ in range(1000):
noised_averages.append(sess.run(query_result))
result_stddev = np.std(noised_averages)
avg_stddev = sum_stddev / denominator
self.assertNear(result_stddev, avg_stddev, 0.1)
@parameterized.named_parameters(
('type_mismatch', [1.0], (1.0,), TypeError),
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
('too_few_on_right', [1.0, 1.0], [1.0], ValueError))
def test_incompatible_records(self, record1, record2, error_type):
query = gaussian_query.GaussianSumQuery(1.0, 0.0)
with self.assertRaises(error_type):
test_utils.run_query(query, [record1, record2])
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,116 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for queries over nested structures.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class NestedQuery(dp_query.DPQuery):
"""Implements DPQuery interface for structured queries.
NestedQuery evaluates arbitrary nested structures of queries. Records must be
nested structures of tensors that are compatible (in type and arity) with the
query structure, but are allowed to have deeper structure within each leaf of
the query structure. For example, the nested query [q1, q2] is compatible with
the record [t1, t2] or [t1, (t2, t3)], but not with (t1, t2), [t1] or
[t1, t2, t3]. The entire substructure of each record corresponding to a leaf
node of the query structure is routed to the corresponding query. If the same
tensor should be consumed by multiple sub-queries, it can be replicated in the
record, for example [t1, t1].
NestedQuery is intended to allow privacy mechanisms for groups as described in
[McMahan & Andrew, 2018: "A General Approach to Adding Differential Privacy to
Iterative Training Procedures" (https://arxiv.org/abs/1812.06210)].
"""
def __init__(self, queries):
"""Initializes the NestedQuery.
Args:
queries: A nested structure of queries.
"""
self._queries = queries
def _map_to_queries(self, fn, *inputs, **kwargs):
def caller(query, *args):
return getattr(query, fn)(*args, **kwargs)
return nest.map_structure_up_to(
self._queries, caller, self._queries, *inputs)
def set_ledger(self, ledger):
self._map_to_queries('set_ledger', ledger=ledger)
def initial_global_state(self):
"""See base class."""
return self._map_to_queries('initial_global_state')
def derive_sample_params(self, global_state):
"""See base class."""
return self._map_to_queries('derive_sample_params', global_state)
def initial_sample_state(self, template):
"""See base class."""
return self._map_to_queries('initial_sample_state', template)
def preprocess_record(self, params, record):
"""See base class."""
return self._map_to_queries('preprocess_record', params, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""See base class."""
return self._map_to_queries(
'accumulate_preprocessed_record',
sample_state,
preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
return self._map_to_queries(
'merge_sample_states', sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Gets query result after all records of sample have been accumulated.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state.
Returns:
A tuple (result, new_global_state) where "result" is a structure matching
the query structure containing the results of the subqueries and
"new_global_state" is a structure containing the updated global states
for the subqueries.
"""
estimates_and_new_global_states = self._map_to_queries(
'get_noised_result', sample_state, global_state)
flat_estimates, flat_new_global_states = zip(
*nest.flatten_up_to(self._queries, estimates_and_new_global_states))
return (
nest.pack_sequence_as(self._queries, flat_estimates),
nest.pack_sequence_as(self._queries, flat_new_global_states))

View file

@ -0,0 +1,148 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for NestedQuery."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from distutils.version import LooseVersion
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import test_utils
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_nested_gaussian_sum_no_clip_no_noise(self):
with self.cached_session() as sess:
query1 = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query2 = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedQuery([query1, query2])
record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [5.0, [5.0, 5.0]]
self.assertAllClose(result, expected)
def test_nested_gaussian_average_no_clip_no_noise(self):
with self.cached_session() as sess:
query1 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2])
record1 = [1.0, [2.0, 3.0]]
record2 = [4.0, [3.0, 2.0]]
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, [1.0, 1.0]]
self.assertAllClose(result, expected)
def test_nested_gaussian_average_with_clip_no_noise(self):
with self.cached_session() as sess:
query1 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
query = nested_query.NestedQuery([query1, query2])
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, [1.0, 1.0]]
self.assertAllClose(result, expected)
def test_complex_nested_query(self):
with self.cached_session() as sess:
query_ab = gaussian_query.GaussianSumQuery(
l2_norm_clip=1.0, stddev=0.0)
query_c = gaussian_query.GaussianAverageQuery(
l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0)
query_d = gaussian_query.GaussianSumQuery(
l2_norm_clip=10.0, stddev=0.0)
query = nested_query.NestedQuery(
[query_ab, {'c': query_c, 'd': [query_d]}])
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
self.assertAllClose(result, expected)
def test_nested_query_with_noise(self):
with self.cached_session() as sess:
sum_stddev = 2.71828
denominator = 3.14159
query1 = gaussian_query.GaussianSumQuery(
l2_norm_clip=1.5, stddev=sum_stddev)
query2 = gaussian_query.GaussianAverageQuery(
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
query = nested_query.NestedQuery((query1, query2))
record1 = (3.0, [2.0, 1.5])
record2 = (0.0, [-1.0, -3.5])
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages = []
for _ in range(1000):
noised_averages.append(nest.flatten(sess.run(query_result)))
result_stddev = np.std(noised_averages, 0)
avg_stddev = sum_stddev / denominator
expected_stddev = [sum_stddev, avg_stddev, avg_stddev]
self.assertArrayNear(result_stddev, expected_stddev, 0.1)
@parameterized.named_parameters(
('type_mismatch', [_basic_query], (1.0,), TypeError),
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
def test_record_incompatible_with_query(
self, queries, record, error_type):
with self.assertRaises(error_type):
test_utils.run_query(nested_query.NestedQuery(queries), [record])
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,70 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for no privacy average queries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for a sum query with no privacy.
Accumulates vectors without clipping or adding noise.
"""
def get_noised_result(self, sample_state, global_state):
"""See base class."""
return sample_state, global_state
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery interface for an average query with no privacy.
Accumulates vectors and normalizes by the total number of accumulated vectors.
"""
def initial_sample_state(self, template):
"""See base class."""
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
tf.constant(0.0))
def preprocess_record(self, params, record, weight=1):
"""Multiplies record by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return (weighted_record, tf.cast(weight, tf.float32))
def accumulate_record(self, params, sample_state, record, weight=1):
"""Accumulates record, multiplying by weight."""
weighted_record = nest.map_structure(lambda t: weight * t, record)
return self.accumulate_preprocessed_record(
sample_state, (weighted_record, tf.cast(weight, tf.float32)))
def get_noised_result(self, sample_state, global_state):
"""See base class."""
sum_state, denominator = sample_state
return (
nest.map_structure(lambda t: t / denominator, sum_state),
global_state)

View file

@ -0,0 +1,77 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for NoPrivacyAverageQuery."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import no_privacy_query
from tensorflow_privacy.privacy.dp_query import test_utils
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_sum(self):
with self.cached_session() as sess:
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
query = no_privacy_query.NoPrivacySumQuery()
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_no_privacy_average(self):
with self.cached_session() as sess:
record1 = tf.constant([5.0, 0.0])
record2 = tf.constant([-1.0, 2.0])
query = no_privacy_query.NoPrivacyAverageQuery()
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [2.0, 1.0]
self.assertAllClose(result, expected)
def test_no_privacy_weighted_average(self):
with self.cached_session() as sess:
record1 = tf.constant([4.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
weights = [1, 3]
query = no_privacy_query.NoPrivacyAverageQuery()
query_result, _ = test_utils.run_query(
query, [record1, record2], weights=weights)
result = sess.run(query_result)
expected = [0.25, 0.75]
self.assertAllClose(result, expected)
@parameterized.named_parameters(
('type_mismatch', [1.0], (1.0,), TypeError),
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
('too_few_on_right', [1.0, 1.0], [1.0], ValueError))
def test_incompatible_records(self, record1, record2, error_type):
query = no_privacy_query.NoPrivacySumQuery()
with self.assertRaises(error_type):
test_utils.run_query(query, [record1, record2])
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,97 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for normalized queries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class NormalizedQuery(dp_query.DPQuery):
"""DPQuery for queries with a DPQuery numerator and fixed denominator."""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['numerator_state', 'denominator'])
def __init__(self, numerator_query, denominator):
"""Initializer for NormalizedQuery.
Args:
numerator_query: A DPQuery for the numerator.
denominator: A value for the denominator. May be None if it will be
supplied via the set_denominator function before get_noised_result is
called.
"""
self._numerator = numerator_query
self._denominator = denominator
def set_ledger(self, ledger):
"""See base class."""
self._numerator.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
if self._denominator is not None:
denominator = tf.cast(self._denominator, tf.float32)
else:
denominator = None
return self._GlobalState(
self._numerator.initial_global_state(), denominator)
def derive_sample_params(self, global_state):
"""See base class."""
return self._numerator.derive_sample_params(global_state.numerator_state)
def initial_sample_state(self, template):
"""See base class."""
# NormalizedQuery has no sample state beyond the numerator state.
return self._numerator.initial_sample_state(template)
def preprocess_record(self, params, record):
return self._numerator.preprocess_record(params, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
"""See base class."""
return self._numerator.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def get_noised_result(self, sample_state, global_state):
"""See base class."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
sample_state, global_state.numerator_state)
def normalize(v):
return tf.truediv(v, global_state.denominator)
return (nest.map_structure(normalize, noised_sum),
self._GlobalState(new_sum_global_state, global_state.denominator))
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._numerator.merge_sample_states(sample_state_1, sample_state_2)

View file

@ -0,0 +1,47 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GaussianAverageQuery."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from tensorflow_privacy.privacy.dp_query import test_utils
class NormalizedQueryTest(tf.test.TestCase):
def test_normalization(self):
with self.cached_session() as sess:
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped.
sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip=5.0, stddev=0.0)
query = normalized_query.NormalizedQuery(
numerator_query=sum_query, denominator=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [0.5, 0.5]
self.assertAllClose(result, expected)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,288 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for adaptive clip queries.
Instead of a fixed clipping norm specified in advance, the clipping norm is
dynamically adjusted to match a target fraction of clipped updates per sample,
where the actual fraction of clipped updates is itself estimated in a
differentially private manner. For details see Thakkar et al., "Differentially
Private Learning with Adaptive Clipping" [http://arxiv.org/abs/1905.03871].
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import normalized_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
"""DPQuery for sum queries with adaptive clipping.
Clipping norm is tuned adaptively to converge to a value such that a specified
quantile of updates are clipped.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', [
'l2_norm_clip',
'noise_multiplier',
'target_unclipped_quantile',
'learning_rate',
'sum_state',
'clipped_fraction_state'])
# pylint: disable=invalid-name
_SampleState = collections.namedtuple(
'_SampleState', ['sum_state', 'clipped_fraction_state'])
# pylint: disable=invalid-name
_SampleParams = collections.namedtuple(
'_SampleParams', ['sum_params', 'clipped_fraction_params'])
def __init__(
self,
initial_l2_norm_clip,
noise_multiplier,
target_unclipped_quantile,
learning_rate,
clipped_count_stddev,
expected_num_records):
"""Initializes the QuantileAdaptiveClipSumQuery.
Args:
initial_l2_norm_clip: The initial value of clipping norm.
noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of
the noise added to the output of the sum query.
target_unclipped_quantile: The desired quantile of updates which should be
unclipped. I.e., a value of 0.8 means a value of l2_norm_clip should be
found for which approximately 20% of updates are clipped each round.
learning_rate: The learning rate for the clipping norm adaptation. A
rate of r means that the clipping norm will change by a maximum of r at
each step. This maximum is attained when |clip - target| is 1.0.
clipped_count_stddev: The stddev of the noise added to the clipped_count.
Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
should be about 0.5 for reasonable privacy.
expected_num_records: The expected number of records per round, used to
estimate the clipped count quantile.
"""
self._initial_l2_norm_clip = initial_l2_norm_clip
self._noise_multiplier = noise_multiplier
self._target_unclipped_quantile = target_unclipped_quantile
self._learning_rate = learning_rate
# Initialize sum query's global state with None, to be set later.
self._sum_query = gaussian_query.GaussianSumQuery(None, None)
# self._clipped_fraction_query is a DPQuery used to estimate the fraction of
# records that are clipped. It accumulates an indicator 0/1 of whether each
# record is clipped, and normalizes by the expected number of records. In
# practice, we accumulate clipped counts shifted by -0.5 so they are
# centered at zero. This makes the sensitivity of the clipped count query
# 0.5 instead of 1.0, since the maximum that a single record could affect
# the count is 0.5. Note that although the l2_norm_clip of the clipped
# fraction query is 0.5, no clipping will ever actually occur because the
# value of each record is always +/-0.5.
self._clipped_fraction_query = gaussian_query.GaussianAverageQuery(
l2_norm_clip=0.5,
sum_stddev=clipped_count_stddev,
denominator=expected_num_records)
def set_ledger(self, ledger):
"""See base class."""
self._sum_query.set_ledger(ledger)
self._clipped_fraction_query.set_ledger(ledger)
def initial_global_state(self):
"""See base class."""
initial_l2_norm_clip = tf.cast(self._initial_l2_norm_clip, tf.float32)
noise_multiplier = tf.cast(self._noise_multiplier, tf.float32)
target_unclipped_quantile = tf.cast(self._target_unclipped_quantile,
tf.float32)
learning_rate = tf.cast(self._learning_rate, tf.float32)
sum_stddev = initial_l2_norm_clip * noise_multiplier
sum_query_global_state = self._sum_query.make_global_state(
l2_norm_clip=initial_l2_norm_clip,
stddev=sum_stddev)
return self._GlobalState(
initial_l2_norm_clip,
noise_multiplier,
target_unclipped_quantile,
learning_rate,
sum_query_global_state,
self._clipped_fraction_query.initial_global_state())
def derive_sample_params(self, global_state):
"""See base class."""
# Assign values to variables that inner sum query uses.
sum_params = self._sum_query.derive_sample_params(global_state.sum_state)
clipped_fraction_params = self._clipped_fraction_query.derive_sample_params(
global_state.clipped_fraction_state)
return self._SampleParams(sum_params, clipped_fraction_params)
def initial_sample_state(self, template):
"""See base class."""
sum_state = self._sum_query.initial_sample_state(template)
clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
tf.constant(0.0))
return self._SampleState(sum_state, clipped_fraction_state)
def preprocess_record(self, params, record):
preprocessed_sum_record, global_norm = (
self._sum_query.preprocess_record_impl(params.sum_params, record))
# Note we are relying on the internals of GaussianSumQuery here. If we want
# to open this up to other kinds of inner queries we'd have to do this in a
# more general way.
l2_norm_clip = params.sum_params
# We accumulate clipped counts shifted by 0.5 so they are centered at zero.
# This makes the sensitivity of the clipped count query 0.5 instead of 1.0.
was_clipped = tf.cast(global_norm >= l2_norm_clip, tf.float32) - 0.5
preprocessed_clipped_fraction_record = (
self._clipped_fraction_query.preprocess_record(
params.clipped_fraction_params, was_clipped))
return preprocessed_sum_record, preprocessed_clipped_fraction_record
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record, weight=1):
"""See base class."""
preprocessed_sum_record, preprocessed_clipped_fraction_record = preprocessed_record
sum_state = self._sum_query.accumulate_preprocessed_record(
sample_state.sum_state, preprocessed_sum_record)
clipped_fraction_state = self._clipped_fraction_query.accumulate_preprocessed_record(
sample_state.clipped_fraction_state,
preprocessed_clipped_fraction_record)
return self._SampleState(sum_state, clipped_fraction_state)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""See base class."""
return self._SampleState(
self._sum_query.merge_sample_states(
sample_state_1.sum_state,
sample_state_2.sum_state),
self._clipped_fraction_query.merge_sample_states(
sample_state_1.clipped_fraction_state,
sample_state_2.clipped_fraction_state))
def get_noised_result(self, sample_state, global_state):
"""See base class."""
gs = global_state
noised_vectors, sum_state = self._sum_query.get_noised_result(
sample_state.sum_state, gs.sum_state)
del sum_state # Unused. To be set explicitly later.
clipped_fraction_result, new_clipped_fraction_state = (
self._clipped_fraction_query.get_noised_result(
sample_state.clipped_fraction_state,
gs.clipped_fraction_state))
# Unshift clipped percentile by 0.5. (See comment in accumulate_record.)
clipped_quantile = clipped_fraction_result + 0.5
unclipped_quantile = 1.0 - clipped_quantile
# Protect against out-of-range estimates.
unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile))
# Loss function is convex, with derivative in [-1, 1], and minimized when
# the true quantile matches the target.
loss_grad = unclipped_quantile - global_state.target_unclipped_quantile
new_l2_norm_clip = gs.l2_norm_clip - global_state.learning_rate * loss_grad
new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip)
new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
new_sum_query_global_state = self._sum_query.make_global_state(
l2_norm_clip=new_l2_norm_clip,
stddev=new_sum_stddev)
new_global_state = global_state._replace(
l2_norm_clip=new_l2_norm_clip,
sum_state=new_sum_query_global_state,
clipped_fraction_state=new_clipped_fraction_state)
return noised_vectors, new_global_state
class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery):
"""DPQuery for average queries with adaptive clipping.
Clipping norm is tuned adaptively to converge to a value such that a specified
quantile of updates are clipped.
Note that we use "fixed-denominator" estimation: the denominator should be
specified as the expected number of records per sample. Accumulating the
denominator separately would also be possible but would be produce a higher
variance estimator.
"""
def __init__(
self,
initial_l2_norm_clip,
noise_multiplier,
denominator,
target_unclipped_quantile,
learning_rate,
clipped_count_stddev,
expected_num_records):
"""Initializes the AdaptiveClipAverageQuery.
Args:
initial_l2_norm_clip: The initial value of clipping norm.
noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of
the noise.
denominator: The normalization constant (applied after noise is added to
the sum).
target_unclipped_quantile: The desired quantile of updates which should be
clipped.
learning_rate: The learning rate for the clipping norm adaptation. A
rate of r means that the clipping norm will change by a maximum of r at
each step. The maximum is attained when |clip - target| is 1.0.
clipped_count_stddev: The stddev of the noise added to the clipped_count.
Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
should be about 0.5 for reasonable privacy.
expected_num_records: The expected number of records, used to estimate the
clipped count quantile.
"""
numerator_query = QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip,
noise_multiplier,
target_unclipped_quantile,
learning_rate,
clipped_count_stddev,
expected_num_records)
super(QuantileAdaptiveClipAverageQuery, self).__init__(
numerator_query=numerator_query,
denominator=denominator)

View file

@ -0,0 +1,296 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for QuantileAdaptiveClipSumQuery."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_sum_query
from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
def test_sum_no_clip_no_noise(self):
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=10.0,
noise_multiplier=0.0,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = query_result.numpy()
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_sum_with_clip_no_noise(self):
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
record2 = tf.constant([4.0, -3.0]) # Not clipped.
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=5.0,
noise_multiplier=0.0,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = query_result.numpy()
expected = [1.0, 1.0]
self.assertAllClose(result, expected)
def test_sum_with_noise(self):
record1, record2 = 2.71828, 3.14159
stddev = 1.0
clip = 5.0
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=clip,
noise_multiplier=stddev / clip,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
noised_sums = []
for _ in xrange(1000):
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_sums.append(query_result.numpy())
result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, stddev, 0.1)
def test_average_no_noise(self):
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=3.0,
noise_multiplier=0.0,
denominator=2.0,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = query_result.numpy()
expected_average = [1.0, 1.0]
self.assertAllClose(result, expected_average)
def test_average_with_noise(self):
record1, record2 = 2.71828, 3.14159
sum_stddev = 1.0
denominator = 2.0
clip = 3.0
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery(
initial_l2_norm_clip=clip,
noise_multiplier=sum_stddev / clip,
denominator=denominator,
target_unclipped_quantile=1.0,
learning_rate=0.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
noised_averages = []
for _ in range(1000):
query_result, _ = test_utils.run_query(query, [record1, record2])
noised_averages.append(query_result.numpy())
result_stddev = np.std(noised_averages)
avg_stddev = sum_stddev / denominator
self.assertNear(result_stddev, avg_stddev, 0.1)
def test_adaptation_target_zero(self):
record1 = tf.constant([8.5])
record2 = tf.constant([-7.25])
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=10.0,
noise_multiplier=0.0,
target_unclipped_quantile=0.0,
learning_rate=1.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip
self.assertAllClose(initial_clip, 10.0)
# On the first two iterations, nothing is clipped, so the clip goes down
# by 1.0 (the learning rate). When the clip reaches 8.0, one record is
# clipped, so the clip goes down by only 0.5. After two more iterations,
# both records are clipped, and the clip norm stays there (at 7.0).
expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0]
expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
def test_adaptation_target_one(self):
record1 = tf.constant([-1.5])
record2 = tf.constant([2.75])
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=0.0,
noise_multiplier=0.0,
target_unclipped_quantile=1.0,
learning_rate=1.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
global_state = query.initial_global_state()
initial_clip = global_state.l2_norm_clip
self.assertAllClose(initial_clip, 0.0)
# On the first two iterations, both are clipped, so the clip goes up
# by 1.0 (the learning rate). When the clip reaches 2.0, only one record is
# clipped, so the clip goes up by only 0.5. After two more iterations,
# both records are clipped, and the clip norm stays there (at 3.0).
expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
for expected_sum, expected_clip in zip(expected_sums, expected_clips):
actual_sum, global_state = test_utils.run_query(
query, [record1, record2], global_state)
actual_clip = global_state.l2_norm_clip
self.assertAllClose(actual_clip.numpy(), expected_clip)
self.assertAllClose(actual_sum.numpy(), (expected_sum,))
def test_adaptation_linspace(self):
# 100 records equally spaced from 0 to 10 in 0.1 increments.
# Test that with a decaying learning rate we converge to the correct
# median with error at most 0.1.
records = [tf.constant(x) for x in np.linspace(
0.0, 10.0, num=21, dtype=np.float32)]
learning_rate = tf.Variable(1.0)
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=0.0,
noise_multiplier=0.0,
target_unclipped_quantile=0.5,
learning_rate=learning_rate,
clipped_count_stddev=0.0,
expected_num_records=2.0)
global_state = query.initial_global_state()
for t in range(50):
tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
_, global_state = test_utils.run_query(query, records, global_state)
actual_clip = global_state.l2_norm_clip
if t > 40:
self.assertNear(actual_clip, 5.0, 0.25)
def test_adaptation_all_equal(self):
# 100 equal records. Test that with a decaying learning rate we converge to
# that record and bounce around it.
records = [tf.constant(5.0)] * 20
learning_rate = tf.Variable(1.0)
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=0.0,
noise_multiplier=0.0,
target_unclipped_quantile=0.5,
learning_rate=learning_rate,
clipped_count_stddev=0.0,
expected_num_records=2.0)
global_state = query.initial_global_state()
for t in range(50):
tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
_, global_state = test_utils.run_query(query, records, global_state)
actual_clip = global_state.l2_norm_clip
if t > 40:
self.assertNear(actual_clip, 5.0, 0.25)
def test_ledger(self):
record1 = tf.constant([8.5])
record2 = tf.constant([-7.25])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=10.0,
noise_multiplier=1.0,
target_unclipped_quantile=0.0,
learning_rate=1.0,
clipped_count_stddev=0.0,
expected_num_records=2.0)
query = privacy_ledger.QueryWithLedger(
query, population_size, selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
_, global_state = test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 10.0], [0.5, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2], global_state)
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]]
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries_2)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,49 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods for testing private queries.
Utility methods for testing private queries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def run_query(query, records, global_state=None, weights=None):
"""Executes query on the given set of records as a single sample.
Args:
query: A PrivateQuery to run.
records: An iterable containing records to pass to the query.
global_state: The current global state. If None, an initial global state is
generated.
weights: An optional iterable containing the weights of the records.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state.
"""
if not global_state:
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(next(iter(records)))
if weights is None:
for record in records:
sample_state = query.accumulate_record(params, sample_state, record)
else:
for weight, record in zip(weights, records):
sample_state = query.accumulate_record(
params, sample_state, record, weight)
return query.get_noised_result(sample_state, global_state)

View file

@ -0,0 +1,239 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Differentially private optimizers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
else:
nest = tf.nest
def make_optimizer_class(cls):
"""Constructs a DP optimizer class from an existing one."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
parent_code = tf.train.Optimizer.compute_gradients.__code__
child_code = cls.compute_gradients.__code__
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
else:
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
GATE_OP = None # pylint: disable=invalid-name
if child_code is not parent_code:
tf.logging.warning(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls."""
def __init__(
self,
dp_sum_query,
num_microbatches=None,
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initialize the DPOptimizerClass.
Args:
dp_sum_query: DPQuery object, specifying differential privacy
mechanism to use.
num_microbatches: How many microbatches into which the minibatch is
split. If None, will default to the size of the minibatch, and
per-example gradients will be computed.
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a tf.while_loop. Can be used if using a tf.while_loop
raises an exception.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._dp_sum_query = dp_sum_query
self._num_microbatches = num_microbatches
self._global_state = self._dp_sum_query.initial_global_state()
# TODO(b/122613513): Set unroll_microbatches=True to avoid this bug.
# Beware: When num_microbatches is large (>100), enabling this parameter
# may cause an OOM error.
self._unroll_microbatches = unroll_microbatches
def compute_gradients(self,
loss,
var_list,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None,
gradient_tape=None):
if callable(loss):
# TF is running in Eager mode, check we received a vanilla tape.
if not gradient_tape:
raise ValueError('When in Eager mode, a tape needs to be passed.')
vector_loss = loss()
if self._num_microbatches is None:
self._num_microbatches = tf.shape(vector_loss)[0]
sample_state = self._dp_sum_query.initial_sample_state(var_list)
microbatches_losses = tf.reshape(vector_loss,
[self._num_microbatches, -1])
sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state))
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i]))
grads = gradient_tape.gradient(microbatch_loss, var_list)
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, grads)
return sample_state
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
def normalize(v):
return v / tf.cast(self._num_microbatches, tf.float32)
final_grads = nest.map_structure(normalize, grad_sums)
grads_and_vars = list(zip(final_grads, var_list))
return grads_and_vars
else:
# TF is running in graph mode, check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be
# sampling from the dataset without replacement.
if self._num_microbatches is None:
self._num_microbatches = tf.shape(loss)[0]
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state))
def process_microbatch(i, sample_state):
"""Process one microbatch (record) with privacy helper."""
grads, _ = zip(*super(cls, self).compute_gradients(
tf.reduce_mean(tf.gather(microbatches_losses,
[i])), var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
grads_list = [
g if g is not None else tf.zeros_like(v)
for (g, v) in zip(list(grads), var_list)
]
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, grads_list)
return sample_state
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._dp_sum_query.initial_sample_state(var_list)
if self._unroll_microbatches:
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
else:
# Use of while_loop here requires that sample_state be a nested
# structure of tensors. In general, we would prefer to allow it to be
# an arbitrary opaque type.
cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
idx = tf.constant(0)
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
def normalize(v):
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
final_grads = nest.map_structure(normalize, grad_sums)
return list(zip(final_grads, var_list))
return DPOptimizerClass
def make_gaussian_optimizer_class(cls):
"""Constructs a DP optimizer with Gaussian averaging of updates."""
class DPGaussianOptimizerClass(make_optimizer_class(cls)):
"""DP subclass of given class cls using Gaussian averaging."""
def __init__(
self,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
ledger=None,
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg
**kwargs):
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
if ledger:
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query,
ledger=ledger)
super(DPGaussianOptimizerClass, self).__init__(
dp_sum_query,
num_microbatches,
unroll_microbatches,
*args,
**kwargs)
@property
def ledger(self):
return self._dp_sum_query.ledger
return DPGaussianOptimizerClass
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
AdagradOptimizer = tf.optimizers.Adagrad
AdamOptimizer = tf.optimizers.Adam
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
DPAdagradOptimizer = make_optimizer_class(AdagradOptimizer)
DPAdamOptimizer = make_optimizer_class(AdamOptimizer)
DPGradientDescentOptimizer = make_optimizer_class(GradientDescentOptimizer)
DPAdagradGaussianOptimizer = make_gaussian_optimizer_class(AdagradOptimizer)
DPAdamGaussianOptimizer = make_gaussian_optimizer_class(AdamOptimizer)
DPGradientDescentGaussianOptimizer = make_gaussian_optimizer_class(
GradientDescentOptimizer)

View file

@ -0,0 +1,130 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
tf.enable_eager_execution()
super(DPOptimizerEagerTest, self).setUp()
def _loss_fn(self, val0, val1):
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
@parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
[-2.5, -2.5]),
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
[-2.5, -2.5]),
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
[-2.5, -2.5]),
('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-2.5, -2.5]),
('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-2.5, -2.5]),
('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]),
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
def testBaseline(self, cls, num_microbatches, expected_answer):
with tf.GradientTape(persistent=True) as gradient_tape:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
# Expected gradient is sum of differences divided by number of
# microbatches.
grads_and_vars = opt.compute_gradients(
lambda: self._loss_fn(var0, data0), [var0],
gradient_tape=gradient_tape)
self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testClippingNorm(self, cls):
with tf.GradientTape(persistent=True) as gradient_tape:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
# Expected gradient is sum of differences.
grads_and_vars = opt.compute_gradients(
lambda: self._loss_fn(var0, data0), [var0],
gradient_tape=gradient_tape)
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testNoiseMultiplier(self, cls):
with tf.GradientTape(persistent=True) as gradient_tape:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
grads = []
for _ in range(1000):
grads_and_vars = opt.compute_gradients(
lambda: self._loss_fn(var0, data0), [var0],
gradient_tape=gradient_tape)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,241 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import mock
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
def _loss(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
# Parameters for testing: optimizer, num_microbatches, expected answer.
@parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
[-2.5, -2.5]),
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2,
[-2.5, -2.5]),
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4,
[-2.5, -2.5]),
('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-2.5, -2.5]),
('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-2.5, -2.5]),
('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]),
('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]),
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]))
def testBaseline(self, cls, num_microbatches, expected_answer):
with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
# Expected gradient is sum of differences divided by number of
# microbatches.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testClippingNorm(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
# Expected gradient is sum of differences.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testNoiseMultiplier(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
@mock.patch.object(tf, 'logging')
def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer):
def compute_gradients(self):
return 0
dp_optimizer.make_optimizer_class(SimpleOptimizer)
mock_logging.warning.assert_called_once_with(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
'SimpleOptimizer')
def testEstimator(self):
"""Tests that DP optimizers work with tf.estimator."""
def linear_model_fn(features, labels, mode):
preds = tf.keras.layers.Dense(
1, activation='linear', name='dense').apply(features['x'])
vector_loss = tf.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(vector_loss)
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
optimizer = dp_optimizer.DPGradientDescentOptimizer(
dp_sum_query,
num_microbatches=1,
learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = 6.0
train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.1, size=(200, 1)).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=20,
num_epochs=10,
shuffle=True)
linear_regressor.train(input_fn=train_input_fn, steps=100)
self.assertAllClose(
linear_regressor.get_variable_value('dense/kernel'),
true_weights,
atol=1.0)
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer))
def testUnrollMicrobatches(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
num_microbatches = 4
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0,
unroll_microbatches=True)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
# Expected gradient is sum of differences divided by number of
# microbatches.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType([-2.5, -2.5], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentGaussianOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradGaussianOptimizer),
('DPAdam', dp_optimizer.DPAdamGaussianOptimizer))
def testDPGaussianOptimizerClass(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(
l2_norm_clip=4.0,
noise_multiplier=2.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,153 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vectorized differentially private optimizers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
nest = tf.contrib.framework.nest
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
parent_code = tf.train.Optimizer.compute_gradients.__code__
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
else:
nest = tf.nest
AdagradOptimizer = tf.optimizers.Adagrad
AdamOptimizer = tf.optimizers.Adam
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
GATE_OP = None # pylint: disable=invalid-name
def make_vectorized_optimizer_class(cls):
"""Constructs a vectorized DP optimizer class from an existing one."""
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
child_code = cls.compute_gradients.__code__
else:
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
if child_code is not parent_code:
tf.logging.warning(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
cls.__name__)
class DPOptimizerClass(cls):
"""Differentially private subclass of given class cls."""
def __init__(
self,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initialize the DPOptimizerClass.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients)
noise_multiplier: Ratio of the standard deviation to the clipping norm
num_microbatches: How many microbatches into which the minibatch is
split. If None, will default to the size of the minibatch, and
per-example gradients will be computed.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
def compute_gradients(self,
loss,
var_list,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None,
gradient_tape=None):
if callable(loss):
# TF is running in Eager mode
raise NotImplementedError('Vectorized optimizer unavailable for TF2.')
else:
# TF is running in graph mode, check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
batch_size = tf.shape(loss)[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be
# sampling from the dataset without replacement.
microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
def process_microbatch(microbatch_loss):
"""Compute clipped grads for one microbatch."""
microbatch_loss = tf.reduce_mean(microbatch_loss)
grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
microbatch_loss,
var_list,
gate_gradients,
aggregation_method,
colocate_gradients_with_ops,
grad_loss))
grads_list = [
g if g is not None else tf.zeros_like(v)
for (g, v) in zip(list(grads), var_list)
]
# Clip gradients to have L2 norm of l2_norm_clip.
# Here, we use TF primitives rather than the built-in
# tf.clip_by_global_norm() so that operations can be vectorized
# across microbatches.
grads_flat = nest.flatten(grads_list)
squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat]
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
clipped_flat = [g / div for g in grads_flat]
clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat)
return clipped_grads
clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses)
def reduce_noise_normalize_batch(stacked_grads):
summed_grads = tf.reduce_sum(stacked_grads, axis=0)
noise_stddev = self._l2_norm_clip * self._noise_multiplier
noise = tf.random.normal(tf.shape(summed_grads),
stddev=noise_stddev)
noised_grads = summed_grads + noise
return noised_grads / tf.cast(self._num_microbatches, tf.float32)
final_grads = nest.map_structure(reduce_noise_normalize_batch,
clipped_grads)
return list(zip(final_grads, var_list))
return DPOptimizerClass
VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer)
VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer)
VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer)

View file

@ -0,0 +1,204 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import mock
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD
class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
def _loss(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
# Parameters for testing: optimizer, num_microbatches, expected answer.
@parameterized.named_parameters(
('DPGradientDescent 1', VectorizedDPSGD, 1, [-2.5, -2.5]),
('DPGradientDescent 2', VectorizedDPSGD, 2, [-2.5, -2.5]),
('DPGradientDescent 4', VectorizedDPSGD, 4, [-2.5, -2.5]),
('DPAdagrad 1', VectorizedDPAdagrad, 1, [-2.5, -2.5]),
('DPAdagrad 2', VectorizedDPAdagrad, 2, [-2.5, -2.5]),
('DPAdagrad 4', VectorizedDPAdagrad, 4, [-2.5, -2.5]),
('DPAdam 1', VectorizedDPAdam, 1, [-2.5, -2.5]),
('DPAdam 2', VectorizedDPAdam, 2, [-2.5, -2.5]),
('DPAdam 4', VectorizedDPAdam, 4, [-2.5, -2.5]))
def testBaseline(self, cls, num_microbatches, expected_answer):
with self.cached_session() as sess:
var0 = tf.Variable([1.0, 2.0])
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
opt = cls(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
# Expected gradient is sum of differences divided by number of
# microbatches.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testClippingNorm(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
opt = cls(l2_norm_clip=1.0,
noise_multiplier=0.,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
# Expected gradient is sum of differences.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testNoiseMultiplier(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(l2_norm_clip=4.0,
noise_multiplier=8.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(5000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
@mock.patch.object(tf, 'logging')
def testComputeGradientsOverrideWarning(self, mock_logging):
class SimpleOptimizer(tf.train.Optimizer):
def compute_gradients(self):
return 0
dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer)
mock_logging.warning.assert_called_once_with(
'WARNING: Calling make_optimizer_class() on class %s that overrides '
'method compute_gradients(). Check to ensure that '
'make_optimizer_class() does not interfere with overridden version.',
'SimpleOptimizer')
def testEstimator(self):
"""Tests that DP optimizers work with tf.estimator."""
def linear_model_fn(features, labels, mode):
preds = tf.keras.layers.Dense(
1, activation='linear', name='dense').apply(features['x'])
vector_loss = tf.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(vector_loss)
optimizer = VectorizedDPSGD(
l2_norm_clip=1.0,
noise_multiplier=0.,
num_microbatches=1,
learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn)
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = 6.0
train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.1, size=(200, 1)).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={'x': train_data},
y=train_labels,
batch_size=20,
num_epochs=10,
shuffle=True)
linear_regressor.train(input_fn=train_input_fn, steps=100)
self.assertAllClose(
linear_regressor.get_variable_value('dense/kernel'),
true_weights,
atol=1.0)
@parameterized.named_parameters(
('DPGradientDescent', VectorizedDPSGD),
('DPAdagrad', VectorizedDPAdagrad),
('DPAdam', VectorizedDPAdam))
def testDPGaussianOptimizerClass(self, cls):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
opt = cls(
l2_norm_clip=4.0,
noise_multiplier=2.0,
num_microbatches=1,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([0.0], self.evaluate(var0))
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads = []
for _ in range(1000):
grads_and_vars = sess.run(gradient_op)
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
if __name__ == '__main__':
tf.test.main()

View file

@ -16,9 +16,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf # pylint: disable=wrong-import-position import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolt_on import losses # pylint: disable=wrong-import-position from tensorflow_privacy.privacy.bolt_on import losses # pylint: disable=wrong-import-position
from privacy.bolt_on import models # pylint: disable=wrong-import-position from tensorflow_privacy.privacy.bolt_on import models # pylint: disable=wrong-import-position
from privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position
# ------- # -------
# First, we will create a binary classification dataset with a single output # First, we will create a binary classification dataset with a single output
# dimension. The samples for each label are repeated data points at different # dimension. The samples for each label are repeated data points at different

View file

@ -44,10 +44,10 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from privacy.analysis import privacy_ledger from tensorflow_privacy.privacy.analysis import privacy_ledger
from privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
flags.DEFINE_boolean( flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, ' 'dpsgd', True, 'If True, train with DP-SGD. If False, '

View file

@ -26,10 +26,10 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis import privacy_ledger from tensorflow_privacy.privacy.analysis import privacy_ledger
from privacy.analysis.rdp_accountant import compute_rdp_from_ledger from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer

View file

@ -24,9 +24,9 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer

View file

@ -25,9 +25,9 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer

View file

@ -26,9 +26,9 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer_vectorized from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
flags.DEFINE_boolean( flags.DEFINE_boolean(

View file

@ -35,9 +35,9 @@ from distutils.version import LooseVersion
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer