forked from 626_privacy/tensorflow_privacy
PiperOrigin-RevId: 224061027
parent c0a43b2178
commit afb8189dba
8 changed files with 590 additions and 37 deletions
@@ -4,10 +4,5 @@ This repository will contain implementations of TensorFlow optimizers that
support training machine learning models with (differential) privacy, as well
as tutorials and analysis tools for computing the privacy guarantees provided.

-The content of this repository will superseed the following existing repository:
+The content of this repository will supersede the following existing repository:
https://github.com/tensorflow/models/tree/master/research/differential_privacy

# Contact

* Steve Chien (schien@google.com)
* Nicolas Papernot (@npapernot)
privacy/optimizers/BUILD (new file, 55 lines)
@@ -0,0 +1,55 @@
licenses(["notice"])  # Apache 2.0

py_library(
    name = "gaussian_average_query",
    srcs = ["gaussian_average_query.py"],
    deps = [
        ":private_queries",
        "//third_party/py/tensorflow",
    ],
)

py_library(
    name = "dp_optimizers",
    deps = [
        ":dp_adam",
        ":dp_gradient_descent",
    ],
)

py_library(
    name = "dp_adam",
    srcs = [
        "dp_adam.py",
    ],
    deps = [
        ":gaussian_average_query",
        "//third_party/py/tensorflow",
    ],
)

py_library(
    name = "dp_gradient_descent",
    srcs = [
        "dp_gradient_descent.py",
    ],
    deps = [
        ":gaussian_average_query",
        "//third_party/py/tensorflow",
    ],
)

py_test(
    name = "dp_optimizer_test",
    srcs = ["dp_optimizer_test.py"],
    deps = [
        ":dp_optimizers",
        "//third_party/py/absl/testing:parameterized",
        "//third_party/py/tensorflow",
    ],
)

py_library(
    name = "private_queries",
    srcs = ["private_queries.py"],
)
privacy/optimizers/dp_adam.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPAdamOptimizer for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import tensorflow_privacy.privacy.optimizers.gaussian_average_query as ph


class DPAdamOptimizer(tf.train.AdamOptimizer):
  """Optimizer that implements the DP Adam algorithm."""

  def __init__(self,
               learning_rate,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               l2_norm_clip=1e9,
               noise_multiplier=0.0,
               nb_microbatches=1,
               name='DPAdam'):
    """Construct a new DP Adam optimizer.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate to
        use.
      beta1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      l2_norm_clip: Clipping parameter for DP-SGD.
      noise_multiplier: Noise multiplier for DP-SGD.
      nb_microbatches: Number of microbatches in which to split the input.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "DPAdam". @compatibility(eager) When eager
        execution is enabled, `learning_rate` can be a callable that takes no
        arguments and returns the actual value to use. This can be useful for
        changing these values across different invocations of optimizer
        functions. @end_compatibility
    """
    super(DPAdamOptimizer, self).__init__(
        learning_rate,
        beta1,
        beta2,
        epsilon,
        use_locking,
        name)
    stddev = l2_norm_clip * noise_multiplier
    self._nb_microbatches = nb_microbatches
    self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
                                                   nb_microbatches)
    self._ph_global_state = self._privacy_helper.initial_global_state()

  def compute_gradients(self,
                        loss,
                        var_list,
                        gate_gradients=tf.train.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    # Note: it would be closer to the correct i.i.d. sampling of records if
    # we sampled each microbatch from the appropriate binomial distribution,
    # although that still wouldn't be quite correct because it would be sampling
    # from the dataset without replacement.
    microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
    sample_params = (
        self._privacy_helper.derive_sample_params(self._ph_global_state))

    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      grads, _ = zip(*super(DPAdamOptimizer, self).compute_gradients(
          tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
          aggregation_method, colocate_gradients_with_ops, grad_loss))
      sample_state = self._privacy_helper.accumulate_record(
          sample_params, sample_state, grads)
      return [tf.add(i, 1), sample_state]

    i = tf.constant(0)

    if var_list is None:
      var_list = (
          tf.trainable_variables() +
          tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    sample_state = self._privacy_helper.initial_sample_state(
        self._ph_global_state, var_list)

    # Use of while_loop here requires that sample_state be a nested structure of
    # tensors. In general, we would prefer to allow it to be an arbitrary
    # opaque type.
    _, final_state = tf.while_loop(
        lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
        [i, sample_state])
    final_grads, self._ph_global_state = (
        self._privacy_helper.get_noised_average(final_state,
                                                self._ph_global_state))

    return zip(final_grads, var_list)
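The constructor docstring lists the DP-SGD knobs but not how the optimizer is meant to be driven. Below is a hypothetical call-site sketch, not part of this commit: the placeholder shapes, toy model, and parameter values are illustrative assumptions. The essential point is that compute_gradients must receive a vector of per-example losses so it can be reshaped into nb_microbatches groups.

# Hypothetical usage sketch (not part of this commit); assumes TF 1.x graph mode.
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_adam

features = tf.placeholder(tf.float32, [16, 10])
labels = tf.placeholder(tf.float32, [16, 1])
logits = tf.layers.dense(features, 1)  # toy linear model

# Per-example (vector) loss: do not reduce to a scalar mean before passing it in.
vector_loss = tf.squared_difference(labels, logits)[:, 0]

opt = dp_adam.DPAdamOptimizer(
    learning_rate=0.001,
    l2_norm_clip=1.0,       # each microbatch gradient is clipped to this L2 norm
    noise_multiplier=1.1,   # noise stddev = l2_norm_clip * noise_multiplier
    nb_microbatches=4)      # the 16 examples are split into 4 microbatches

grads_and_vars = opt.compute_gradients(vector_loss, tf.trainable_variables())
train_op = opt.apply_gradients(list(grads_and_vars))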
privacy/optimizers/dp_gradient_descent.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPGradientDescentOptimizer for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import tensorflow_privacy.privacy.optimizers.gaussian_average_query as ph


class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer):
  """Optimizer that implements the DP gradient descent algorithm."""

  def __init__(self,
               learning_rate,
               use_locking=False,
               l2_norm_clip=1e9,
               noise_multiplier=0.0,
               nb_microbatches=1,
               name='DPGradientDescent'):
    """Construct a new DP gradient descent optimizer.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate to
        use.
      use_locking: If True use locks for update operations.
      l2_norm_clip: Clipping parameter for DP-SGD.
      noise_multiplier: Noise multiplier for DP-SGD.
      nb_microbatches: Number of microbatches in which to split the input.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "DPGradientDescent". @compatibility(eager) When
        eager execution is enabled, `learning_rate` can be a callable that takes
        no arguments and returns the actual value to use. This can be useful for
        changing these values across different invocations of optimizer
        functions. @end_compatibility
    """
    super(DPGradientDescentOptimizer, self).__init__(learning_rate, use_locking,
                                                     name)
    stddev = l2_norm_clip * noise_multiplier
    self._nb_microbatches = nb_microbatches
    self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
                                                   nb_microbatches)
    self._ph_global_state = self._privacy_helper.initial_global_state()

  def compute_gradients(self,
                        loss,
                        var_list,
                        gate_gradients=tf.train.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    # Note: it would be closer to the correct i.i.d. sampling of records if
    # we sampled each microbatch from the appropriate binomial distribution,
    # although that still wouldn't be quite correct because it would be sampling
    # from the dataset without replacement.
    microbatches_losses = tf.reshape(loss, [self._nb_microbatches, -1])
    sample_params = (
        self._privacy_helper.derive_sample_params(self._ph_global_state))

    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      grads, _ = zip(*super(DPGradientDescentOptimizer, self).compute_gradients(
          tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
          aggregation_method, colocate_gradients_with_ops, grad_loss))
      sample_state = self._privacy_helper.accumulate_record(
          sample_params, sample_state, grads)
      return [tf.add(i, 1), sample_state]

    i = tf.constant(0)

    if var_list is None:
      var_list = (
          tf.trainable_variables() +
          tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    sample_state = self._privacy_helper.initial_sample_state(
        self._ph_global_state, var_list)

    # Use of while_loop here requires that sample_state be a nested structure of
    # tensors. In general, we would prefer to allow it to be an arbitrary
    # opaque type.
    _, final_state = tf.while_loop(
        lambda i, _: tf.less(i, self._nb_microbatches), process_microbatch,
        [i, sample_state])
    final_grads, self._ph_global_state = (
        self._privacy_helper.get_noised_average(final_state,
                                                self._ph_global_state))

    return zip(final_grads, var_list)
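DPGradientDescentOptimizer repeats the same per-microbatch clip/accumulate/noise pattern as DPAdamOptimizer, only on top of plain gradient descent. For intuition, here is a rough NumPy sketch of the quantity one compute_gradients call produces, treating each microbatch gradient as a single flat vector; this is an illustrative sketch of the math under those assumptions, not the TensorFlow implementation above.

# Rough NumPy sketch of the DP gradient computation (for intuition only).
import numpy as np

def dp_gradient(microbatch_grads, l2_norm_clip, noise_multiplier, rng=np.random):
  """microbatch_grads: list of flat gradient vectors, one per microbatch."""
  total = np.zeros_like(microbatch_grads[0])
  for g in microbatch_grads:
    norm = np.linalg.norm(g)
    scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
    total += g * scale                               # clip each microbatch gradient, then sum
  noise = rng.normal(0.0, l2_norm_clip * noise_multiplier, size=total.shape)
  return (total + noise) / len(microbatch_grads)     # noised average over microbatches

print(dp_gradient([np.array([3.0, 4.0]), np.array([0.3, 0.4])],
                  l2_norm_clip=1.0, noise_multiplier=0.0))  # -> [0.45 0.6]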
privacy/optimizers/dp_optimizer_test.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for differentially private optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.optimizers import dp_adam
from tensorflow_privacy.privacy.optimizers import dp_gradient_descent


def loss(val0, val1):
  """Loss function that is minimized at the mean of the input points."""
  return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)


class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  # Parameters for testing: optimizer, nb_microbatches, expected answer.
  @parameterized.named_parameters(
      ('DPGradientDescent 1', dp_gradient_descent.DPGradientDescentOptimizer, 1,
       [-10.0, -10.0]),
      ('DPGradientDescent 2', dp_gradient_descent.DPGradientDescentOptimizer, 2,
       [-5.0, -5.0]),
      ('DPGradientDescent 4', dp_gradient_descent.DPGradientDescentOptimizer, 4,
       [-2.5, -2.5]),
      ('DPAdam 1', dp_adam.DPAdamOptimizer, 1, [-10.0, -10.0]),
      ('DPAdam 2', dp_adam.DPAdamOptimizer, 2, [-5.0, -5.0]),
      ('DPAdam 4', dp_adam.DPAdamOptimizer, 4, [-2.5, -2.5]))
  def testBaseline(self, cls, nb_microbatches, expected_answer):
    with self.cached_session() as sess:
      var0 = tf.Variable([1.0, 2.0])
      data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])

      opt = cls(learning_rate=2.0, nb_microbatches=nb_microbatches)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))

      # Expected gradient is sum of differences divided by number of
      # microbatches.
      gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
      grads_and_vars = sess.run(gradient_op)
      self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])

  @parameterized.named_parameters(
      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
      ('DPAdam', dp_adam.DPAdamOptimizer))
  def testClippingNorm(self, cls):
    with self.cached_session() as sess:
      var0 = tf.Variable([0.0, 0.0])
      data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])

      opt = cls(learning_rate=2.0, l2_norm_clip=1.0, nb_microbatches=1)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([0.0, 0.0], self.evaluate(var0))

      # Expected gradient is sum of differences.
      gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
      grads_and_vars = sess.run(gradient_op)
      self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])

  @parameterized.named_parameters(
      ('DPGradientDescent', dp_gradient_descent.DPGradientDescentOptimizer),
      ('DPAdam', dp_adam.DPAdamOptimizer))
  def testNoiseMultiplier(self, cls):
    with self.cached_session() as sess:
      var0 = tf.Variable([0.0])
      data0 = tf.Variable([[0.0]])

      opt = cls(
          learning_rate=2.0,
          l2_norm_clip=4.0,
          noise_multiplier=2.0,
          nb_microbatches=1)
      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([0.0], self.evaluate(var0))

      gradient_op = opt.compute_gradients(loss(data0, var0), [var0])
      grads = []
      for _ in xrange(1000):
        grads_and_vars = sess.run(gradient_op)
        grads.append(grads_and_vars[0][0])

      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
      self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)


if __name__ == '__main__':
  tf.test.main()
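The expected answers in testBaseline follow from the loss above: for loss(x, v) = 0.5 * ||x - v||^2 the per-example gradient with respect to v is (v - x), so over the four data points at var0 = [1, 2] the summed gradient is 4*[1, 2] - [14, 18] = [-10, -10], and GaussianAverageQuery then divides by nb_microbatches, giving [-5, -5] and [-2.5, -2.5] for 2 and 4 microbatches. testClippingNorm works the same way: the summed gradient [-9, -12] has L2 norm 15 and is rescaled to norm 1, giving [-0.6, -0.8]. A small NumPy check of that arithmetic (illustrative only, not part of the test file):

import numpy as np

var0 = np.array([1.0, 2.0])
data0 = np.array([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
summed = np.sum(var0 - data0, axis=0)        # gradient of 0.5*sum||x - v||^2 wrt v
print(summed, summed / 2, summed / 4)        # [-10. -10.] [-5. -5.] [-2.5 -2.5]

grad = np.sum(np.array([0.0, 0.0]) - np.array([[3.0, 4.0], [6.0, 8.0]]), axis=0)
print(grad / np.linalg.norm(grad))           # clipped to unit norm: [-0.6 -0.8]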
privacy/optimizers/gaussian_average_query.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements PrivateQuery interface for Gaussian average queries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow as tf

from tensorflow_privacy.privacy.optimizers import private_queries


class GaussianAverageQuery(private_queries.PrivateAverageQuery):
  """Implements PrivateQuery interface for Gaussian average queries.

  Accumulates clipped vectors, then adds Gaussian noise to the average.
  """

  # pylint: disable=invalid-name
  _GlobalState = collections.namedtuple(
      '_GlobalState', ['l2_norm_clip', 'stddev', 'denominator'])

  def __init__(self, l2_norm_clip, stddev, denominator):
    """Initializes the GaussianAverageQuery."""
    self._l2_norm_clip = l2_norm_clip
    self._stddev = stddev
    self._denominator = denominator

  def initial_global_state(self):
    """Returns the initial global state for the PrivacyHelper."""
    return self._GlobalState(
        float(self._l2_norm_clip), float(self._stddev),
        float(self._denominator))

  def derive_sample_params(self, global_state):
    """Given the global state, derives parameters to use for the next sample.

    Args:
      global_state: The current global state.

    Returns:
      Parameters to use to process records in the next sample.
    """
    return global_state.l2_norm_clip

  def initial_sample_state(self, global_state, tensors):
    """Returns an initial state to use for the next sample.

    Args:
      global_state: The current global state.
      tensors: A structure of tensors used as a template to create the initial
        sample state.

    Returns: An initial sample state.
    """
    del global_state  # unused.
    return tf.contrib.framework.nest.map_structure(tf.zeros_like, tensors)

  def accumulate_record(self, params, sample_state, record):
    """Accumulates a single record into the sample state.

    Args:
      params: The parameters for the sample.
      sample_state: The current sample state.
      record: The record to accumulate.

    Returns:
      The updated sample state.
    """
    l2_norm_clip = params
    clipped, _ = tf.clip_by_global_norm(record, l2_norm_clip)
    return tf.contrib.framework.nest.map_structure(tf.add, sample_state,
                                                   clipped)

  def get_noised_average(self, sample_state, global_state):
    """Gets noised average after all records of sample have been accumulated.

    Args:
      sample_state: The sample state after all records have been accumulated.
      global_state: The global state.

    Returns:
      A tuple (estimate, new_global_state) where "estimate" is the estimated
      average of the records and "new_global_state" is the updated global
      state.
    """
    def noised_average(v):
      return tf.truediv(
          v + tf.random_normal(tf.shape(v), stddev=self._stddev),
          global_state.denominator)

    return (tf.contrib.framework.nest.map_structure(noised_average,
                                                    sample_state), global_state)
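Because GaussianAverageQuery works on nested structures of tensors, clipping uses tf.clip_by_global_norm, i.e. the joint L2 norm of every tensor in a record rather than a per-tensor norm. A minimal standalone sketch of the query lifecycle follows; the concrete values, the stddev of 0 (to make the result deterministic), and the session-style driving are illustrative assumptions, not part of the commit.

# Minimal sketch of driving GaussianAverageQuery directly (TF 1.x session style assumed).
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import gaussian_average_query

query = gaussian_average_query.GaussianAverageQuery(
    l2_norm_clip=1.0, stddev=0.0, denominator=2.0)

global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)   # here: just the clip norm

records = [[tf.constant([6.0, 8.0])], [tf.constant([1.0, 0.0])]]
sample_state = query.initial_sample_state(global_state, records[0])
for record in records:
  # [6, 8] has global norm 10 and is scaled to [0.6, 0.8]; [1, 0] is within the clip.
  sample_state = query.accumulate_record(params, sample_state, record)

average, global_state = query.get_noised_average(sample_state, global_state)
with tf.Session() as sess:
  print(sess.run(average))  # [array([0.8, 0.4])] = ([0.6, 0.8] + [1.0, 0.0]) / 2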
privacy/optimizers/private_queries.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An interface for differentially private query mechanisms."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc


class PrivateQuery(object):
  """Interface for differentially private query mechanisms."""

  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def initial_global_state(self):
    """Returns the initial global state for the PrivateQuery."""
    pass

  @abc.abstractmethod
  def derive_sample_params(self, global_state):
    """Given the global state, derives parameters to use for the next sample.

    Args:
      global_state: The current global state.

    Returns:
      Parameters to use to process records in the next sample.
    """
    pass

  @abc.abstractmethod
  def initial_sample_state(self, global_state, tensors):
    """Returns an initial state to use for the next sample.

    Args:
      global_state: The current global state.
      tensors: A structure of tensors used as a template to create the initial
        sample state.

    Returns: An initial sample state.
    """
    pass

  @abc.abstractmethod
  def accumulate_record(self, params, sample_state, record):
    """Accumulates a single record into the sample state.

    Args:
      params: The parameters for the sample.
      sample_state: The current sample state.
      record: The record to accumulate.

    Returns:
      The updated sample state.
    """
    pass


class PrivateAverageQuery(PrivateQuery):
  """Interface for differentially private mechanisms to compute an average."""

  @abc.abstractmethod
  def get_noised_average(self, sample_state, global_state):
    """Gets average estimate after all records of sample have been accumulated.

    Args:
      sample_state: The sample state after all records have been accumulated.
      global_state: The global state.

    Returns:
      A tuple (estimate, new_global_state) where "estimate" is the estimated
      average of the records and "new_global_state" is the updated global
      state.
    """
    pass
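GaussianAverageQuery above is the one concrete implementation of this interface in the commit; a second, hypothetical implementation makes the contract clearer. The skeleton below is illustrative only and not part of the commit: it provides every method a PrivateAverageQuery subclass must supply, but computes a plain clipped average with no noise, so it is not actually differentially private.

# Hypothetical example implementation (not in this commit): a clipped, un-noised average.
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import private_queries


class ClippedAverageQuery(private_queries.PrivateAverageQuery):
  """Averages clipped records without adding noise (NOT differentially private)."""

  def __init__(self, l2_norm_clip, denominator):
    self._l2_norm_clip = l2_norm_clip
    self._denominator = denominator

  def initial_global_state(self):
    return (float(self._l2_norm_clip), float(self._denominator))

  def derive_sample_params(self, global_state):
    return global_state[0]  # the clip norm

  def initial_sample_state(self, global_state, tensors):
    del global_state  # unused.
    return tf.contrib.framework.nest.map_structure(tf.zeros_like, tensors)

  def accumulate_record(self, params, sample_state, record):
    clipped, _ = tf.clip_by_global_norm(record, params)
    return tf.contrib.framework.nest.map_structure(tf.add, sample_state, clipped)

  def get_noised_average(self, sample_state, global_state):
    denominator = global_state[1]
    average = tf.contrib.framework.nest.map_structure(
        lambda v: v / denominator, sample_state)
    return average, global_state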
(deleted file, 31 lines)
@@ -1,31 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

FLAGS = flags.FLAGS


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')


if __name__ == '__main__':
  app.run(main)