forked from 626_privacy/tensorflow_privacy
Make DPQuery classes (almost) completely functional: the only state from the initializer that is used gets pushed into the initial_global_state.
PiperOrigin-RevId: 248424593
This commit is contained in:
parent 17fefb3895, commit 3908429796
6 changed files with 100 additions and 53 deletions
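The diff below threads an explicit global_state through every DPQuery method instead of keeping run-time parameters on self. For orientation, here is a minimal pure-Python sketch of that functional pattern; the class and method names here are hypothetical stand-ins, not code from this commit:

import collections

# Immutable state container, in the style of the _GlobalState namedtuples below.
_GlobalState = collections.namedtuple('_GlobalState', ['l2_norm_clip', 'stddev'])

class FunctionalQuerySketch(object):
  """Hypothetical stand-in showing the state flow this commit adopts."""

  def __init__(self, l2_norm_clip, stddev):
    # Constructor arguments are stored raw; no casts, no tf.Variables.
    self._l2_norm_clip = l2_norm_clip
    self._stddev = stddev

  def initial_global_state(self):
    # All run-time state is packaged here, once.
    return _GlobalState(float(self._l2_norm_clip), float(self._stddev))

  def adapt(self, global_state, new_clip):
    # Updates return a fresh state instead of assigning to self.
    return global_state._replace(l2_norm_clip=float(new_clip))

query = FunctionalQuerySketch(1.0, 0.5)
state = query.initial_global_state()
state = query.adapt(state, new_clip=2.0)   # caller keeps the returned state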
@@ -257,6 +257,6 @@ class QueryWithLedger(dp_query.DPQuery):
     with tf.control_dependencies([self._ledger.finalize_sample()]):
       return self._query.get_noised_result(sample_state, global_state)
 
-  def set_denominator(self, num_microbatches, microbatch_size=1):
-    self._query.set_denominator(num_microbatches)
+  def set_denominator(self, global_state, num_microbatches, microbatch_size=1):
     self._ledger.set_sample_size(num_microbatches * microbatch_size)
+    return self._query.set_denominator(global_state, num_microbatches)
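Under the new contract, set_denominator still updates the ledger's sample size as a side effect (recording is inherently stateful), but the denominator itself travels through the returned state. A runnable stand-in with all names hypothetical:

import collections

_GlobalState = collections.namedtuple('_GlobalState', ['denominator'])

class LedgerSketch(object):
  """Stand-in for the privacy ledger."""
  def set_sample_size(self, sample_size):
    self.sample_size = sample_size

class InnerQuerySketch(object):
  """Stand-in for the wrapped query with a functional set_denominator."""
  def set_denominator(self, global_state, num_microbatches):
    return global_state._replace(denominator=float(num_microbatches))

# Wrapper logic analogous to QueryWithLedger.set_denominator above:
ledger, inner = LedgerSketch(), InnerQuerySketch()
state = _GlobalState(denominator=None)
num_microbatches, microbatch_size = 16, 1
ledger.set_sample_size(num_microbatches * microbatch_size)
state = inner.set_denominator(state, num_microbatches)
assert state.denominator == 16.0 and ledger.sample_size == 16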
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 from distutils.version import LooseVersion
 import tensorflow as tf
 
@@ -37,6 +39,10 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
   Accumulates clipped vectors, then adds Gaussian noise to the sum.
   """
 
+  # pylint: disable=invalid-name
+  _GlobalState = collections.namedtuple(
+      '_GlobalState', ['l2_norm_clip', 'stddev'])
+
   def __init__(self, l2_norm_clip, stddev, ledger=None):
     """Initializes the GaussianSumQuery.
 
@@ -46,17 +52,26 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
       stddev: The stddev of the noise added to the sum.
       ledger: The privacy ledger to which queries should be recorded.
     """
-    self._l2_norm_clip = tf.cast(l2_norm_clip, tf.float32)
-    self._stddev = tf.cast(stddev, tf.float32)
+    self._l2_norm_clip = l2_norm_clip
+    self._stddev = stddev
     self._ledger = ledger
 
+  def make_global_state(self, l2_norm_clip, stddev):
+    """Creates a global state from the given parameters."""
+    return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
+                             tf.cast(stddev, tf.float32))
+
+  def initial_global_state(self):
+    return self.make_global_state(self._l2_norm_clip, self._stddev)
+
   def derive_sample_params(self, global_state):
-    return self._l2_norm_clip
+    return global_state.l2_norm_clip
 
   def initial_sample_state(self, global_state, template):
     if self._ledger:
       dependencies = [
-          self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
+          self._ledger.record_sum_query(
+              global_state.l2_norm_clip, global_state.stddev)
       ]
     else:
       dependencies = []
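The new make_global_state helper defers all casting to the moment a state is materialized, which also lets an outer query (see the adaptive-clip changes further down) mint fresh Gaussian states with updated parameters each round. A pure-Python sketch of that flow, with float() standing in for tf.cast(..., tf.float32):

import collections

class GaussianSumQuerySketch(object):
  """Illustrative stand-in; plain floats play the role of tensors."""

  _GlobalState = collections.namedtuple('_GlobalState', ['l2_norm_clip', 'stddev'])

  def __init__(self, l2_norm_clip, stddev):
    self._l2_norm_clip = l2_norm_clip   # kept raw until a state is built
    self._stddev = stddev

  def make_global_state(self, l2_norm_clip, stddev):
    return self._GlobalState(float(l2_norm_clip), float(stddev))

  def initial_global_state(self):
    return self.make_global_state(self._l2_norm_clip, self._stddev)

  def derive_sample_params(self, global_state):
    return global_state.l2_norm_clip    # read from the state, not from self

query = GaussianSumQuerySketch(l2_norm_clip=1, stddev=2)
state = query.initial_global_state()
assert query.derive_sample_params(state) == 1.0
# An outer query can re-parameterize without touching the instance:
state = query.make_global_state(l2_norm_clip=0.5, stddev=0.55)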
@@ -89,9 +104,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
     """See base class."""
     if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
       def add_noise(v):
-        return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
+        return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
     else:
-      random_normal = tf.random_normal_initializer(stddev=self._stddev)
+      random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
       def add_noise(v):
         return v + random_normal(tf.shape(v))
 
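With the stddev read from global_state, the same query object can add differently scaled noise on successive rounds. A plain-Python analogue of the add_noise closure (random.gauss standing in for the TF normal sampler; illustrative only):

import random

def make_add_noise(stddev):
  """Builds an add_noise closure over the stddev carried in the state."""
  def add_noise(values):
    return [v + random.gauss(0.0, stddev) for v in values]
  return add_noise

add_noise = make_add_noise(0.5)        # 0.5 stands in for global_state.stddev
noised = add_noise([1.0, 2.0, 3.0])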
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 from distutils.version import LooseVersion
 import tensorflow as tf
 
@@ -33,6 +35,10 @@ else:
 class NormalizedQuery(dp_query.DPQuery):
   """DPQuery for queries with a DPQuery numerator and fixed denominator."""
 
+  # pylint: disable=invalid-name
+  _GlobalState = collections.namedtuple(
+      '_GlobalState', ['numerator_state', 'denominator'])
+
   def __init__(self, numerator_query, denominator):
     """Initializer for NormalizedQuery.
 
@@ -43,22 +49,26 @@ class NormalizedQuery(dp_query.DPQuery):
         called.
     """
     self._numerator = numerator_query
-    self._denominator = (
-        tf.cast(denominator, tf.float32) if denominator is not None else None)
+    self._denominator = denominator
 
   def initial_global_state(self):
     """See base class."""
-    # NormalizedQuery has no global state beyond the numerator state.
-    return self._numerator.initial_global_state()
+    if self._denominator is not None:
+      denominator = tf.cast(self._denominator, tf.float32)
+    else:
+      denominator = None
+    return self._GlobalState(
+        self._numerator.initial_global_state(), denominator)
 
   def derive_sample_params(self, global_state):
     """See base class."""
-    return self._numerator.derive_sample_params(global_state)
+    return self._numerator.derive_sample_params(global_state.numerator_state)
 
   def initial_sample_state(self, global_state, template):
     """See base class."""
     # NormalizedQuery has no sample state beyond the numerator state.
-    return self._numerator.initial_sample_state(global_state, template)
+    return self._numerator.initial_sample_state(
+        global_state.numerator_state, template)
 
   def preprocess_record(self, params, record):
     return self._numerator.preprocess_record(params, record)
@@ -72,16 +82,17 @@ class NormalizedQuery(dp_query.DPQuery):
   def get_noised_result(self, sample_state, global_state):
     """See base class."""
     noised_sum, new_sum_global_state = self._numerator.get_noised_result(
-        sample_state, global_state)
+        sample_state, global_state.numerator_state)
     def normalize(v):
-      return tf.truediv(v, self._denominator)
+      return tf.truediv(v, global_state.denominator)
 
-    return nest.map_structure(normalize, noised_sum), new_sum_global_state
+    return (nest.map_structure(normalize, noised_sum),
+            self._GlobalState(new_sum_global_state, global_state.denominator))
 
   def merge_sample_states(self, sample_state_1, sample_state_2):
     """See base class."""
     return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
 
-  def set_denominator(self, denominator):
-    """Sets the denominator for the NormalizedQuery."""
-    self._denominator = tf.cast(denominator, tf.float32)
+  def set_denominator(self, global_state, denominator):
+    """Returns an updated global_state with the given denominator."""
+    return global_state._replace(denominator=tf.cast(denominator, tf.float32))
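namedtuple._replace is what makes the new set_denominator side-effect free: it returns a copy of the state with one field swapped and leaves the original untouched. A minimal demonstration (illustrative values, float() standing in for the tf.cast above):

import collections

_GlobalState = collections.namedtuple(
    '_GlobalState', ['numerator_state', 'denominator'])

old_state = _GlobalState(numerator_state=(), denominator=None)
new_state = old_state._replace(denominator=float(32))
assert old_state.denominator is None       # original state is unchanged
assert new_state.denominator == 32.0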
@@ -45,7 +45,13 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
 
   # pylint: disable=invalid-name
   _GlobalState = collections.namedtuple(
-      '_GlobalState', ['l2_norm_clip', 'sum_state', 'clipped_fraction_state'])
+      '_GlobalState', [
+          'l2_norm_clip',
+          'noise_multiplier',
+          'target_unclipped_quantile',
+          'learning_rate',
+          'sum_state',
+          'clipped_fraction_state'])
 
   # pylint: disable=invalid-name
   _SampleState = collections.namedtuple(
@@ -75,8 +81,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
         found for which approximately 20% of updates are clipped each round.
       learning_rate: The learning rate for the clipping norm adaptation. A
         rate of r means that the clipping norm will change by a maximum of r at
-        each step. This maximum is attained when |clip - target| is 1.0. Can be
-        a tf.Variable for example to implement a learning rate schedule.
+        each step. This maximum is attained when |clip - target| is 1.0.
       clipped_count_stddev: The stddev of the noise added to the clipped_count.
         Since the sensitivity of the clipped count is 0.5, as a rule of thumb it
         should be about 0.5 for reasonable privacy.
@@ -84,19 +89,14 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
         estimate the clipped count quantile.
       ledger: The privacy ledger to which queries should be recorded.
     """
-    self._initial_l2_norm_clip = tf.cast(initial_l2_norm_clip, tf.float32)
-    self._noise_multiplier = tf.cast(noise_multiplier, tf.float32)
-    self._target_unclipped_quantile = tf.cast(
-        target_unclipped_quantile, tf.float32)
-    self._learning_rate = tf.cast(learning_rate, tf.float32)
+    self._initial_l2_norm_clip = initial_l2_norm_clip
+    self._noise_multiplier = noise_multiplier
+    self._target_unclipped_quantile = target_unclipped_quantile
+    self._learning_rate = learning_rate
 
-    self._l2_norm_clip = tf.Variable(self._initial_l2_norm_clip)
-    self._sum_stddev = tf.Variable(
-        self._initial_l2_norm_clip * self._noise_multiplier)
+    # Initialize sum query's global state with None, to be set later.
     self._sum_query = gaussian_query.GaussianSumQuery(
-        self._l2_norm_clip,
-        self._sum_stddev,
-        ledger)
+        None, None, ledger)
 
   # self._clipped_fraction_query is a DPQuery used to estimate the fraction of
   # records that are clipped. It accumulates an indicator 0/1 of whether each
@@ -115,29 +115,40 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
 
   def initial_global_state(self):
     """See base class."""
+    initial_l2_norm_clip = tf.cast(self._initial_l2_norm_clip, tf.float32)
+    noise_multiplier = tf.cast(self._noise_multiplier, tf.float32)
+    target_unclipped_quantile = tf.cast(self._target_unclipped_quantile,
+                                        tf.float32)
+    learning_rate = tf.cast(self._learning_rate, tf.float32)
+    sum_stddev = initial_l2_norm_clip * noise_multiplier
+
+    sum_query_global_state = self._sum_query.make_global_state(
+        l2_norm_clip=initial_l2_norm_clip,
+        stddev=sum_stddev)
+
     return self._GlobalState(
-        self._initial_l2_norm_clip,
-        self._sum_query.initial_global_state(),
+        initial_l2_norm_clip,
+        noise_multiplier,
+        target_unclipped_quantile,
+        learning_rate,
+        sum_query_global_state,
         self._clipped_fraction_query.initial_global_state())
 
   def derive_sample_params(self, global_state):
     """See base class."""
-    gs = global_state
-
-    # Assign values to variables that inner sum query uses.
-    tf.assign(self._l2_norm_clip, gs.l2_norm_clip)
-    tf.assign(self._sum_stddev, gs.l2_norm_clip * self._noise_multiplier)
-    sum_params = self._sum_query.derive_sample_params(gs.sum_state)
+    sum_params = self._sum_query.derive_sample_params(global_state.sum_state)
     clipped_fraction_params = self._clipped_fraction_query.derive_sample_params(
-        gs.clipped_fraction_state)
+        global_state.clipped_fraction_state)
     return self._SampleParams(sum_params, clipped_fraction_params)
 
   def initial_sample_state(self, global_state, template):
     """See base class."""
-    clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
-        global_state.clipped_fraction_state, tf.constant(0.0))
     sum_state = self._sum_query.initial_sample_state(
         global_state.sum_state, template)
+    clipped_fraction_state = self._clipped_fraction_query.initial_sample_state(
+        global_state.clipped_fraction_state, tf.constant(0.0))
     return self._SampleState(sum_state, clipped_fraction_state)
 
   def preprocess_record(self, params, record):
@@ -187,6 +198,7 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
 
     noised_vectors, sum_state = self._sum_query.get_noised_result(
         sample_state.sum_state, gs.sum_state)
+    del sum_state  # Unused. To be set explicitly later.
 
     clipped_fraction_result, new_clipped_fraction_state = (
         self._clipped_fraction_query.get_noised_result(
@@ -202,15 +214,20 @@ class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
 
     # Loss function is convex, with derivative in [-1, 1], and minimized when
     # the true quantile matches the target.
-    loss_grad = unclipped_quantile - self._target_unclipped_quantile
+    loss_grad = unclipped_quantile - global_state.target_unclipped_quantile
 
-    new_l2_norm_clip = gs.l2_norm_clip - self._learning_rate * loss_grad
+    new_l2_norm_clip = gs.l2_norm_clip - global_state.learning_rate * loss_grad
     new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip)
 
-    new_global_state = self._GlobalState(
-        new_l2_norm_clip,
-        sum_state,
-        new_clipped_fraction_state)
+    new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier
+    new_sum_query_global_state = self._sum_query.make_global_state(
+        l2_norm_clip=new_l2_norm_clip,
+        stddev=new_sum_stddev)
+
+    new_global_state = global_state._replace(
+        l2_norm_clip=new_l2_norm_clip,
+        sum_state=new_sum_query_global_state,
+        clipped_fraction_state=new_clipped_fraction_state)
 
     return noised_vectors, new_global_state
 
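To make the update rule concrete, a worked example with illustrative numbers (plain floats standing in for the tensors above): if fewer records than targeted were left unclipped, the clipping norm grows, and the noise stddev is rebuilt from the new norm.

# Illustrative values only; none of these come from the diff.
unclipped_quantile = 0.3            # observed fraction of unclipped records
target_unclipped_quantile = 0.5
learning_rate = 0.2
noise_multiplier = 1.1
l2_norm_clip = 1.0

loss_grad = unclipped_quantile - target_unclipped_quantile              # -0.2
new_l2_norm_clip = max(0.0, l2_norm_clip - learning_rate * loss_grad)   # 1.04
new_sum_stddev = new_l2_norm_clip * noise_multiplier                    # 1.144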
@@ -270,7 +270,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
       tf.assign(selection_probability, 0.1)
       _, global_state = test_utils.run_query(query, [record1, record2])
 
-      expected_queries = [[0.5, 0.0], [10.0, 10.0]]
+      expected_queries = [[10.0, 10.0], [0.5, 0.0]]
       formatted = ledger.get_formatted_ledger_eager()
       sample_1 = formatted[0]
       self.assertAllClose(sample_1.population_size, 10.0)
@@ -288,7 +288,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
      self.assertAllClose(sample_1.selection_probability, 0.1)
       self.assertAllClose(sample_1.queries, expected_queries)
 
-      expected_queries_2 = [[0.5, 0.0], [9.0, 9.0]]
+      expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]]
       self.assertAllClose(sample_2.population_size, 20.0)
       self.assertAllClose(sample_2.selection_probability, 0.2)
       self.assertAllClose(sample_2.queries, expected_queries_2)
@@ -88,7 +88,9 @@ def make_optimizer_class(cls):
         vector_loss = loss()
         if self._num_microbatches is None:
           self._num_microbatches = tf.shape(vector_loss)[0]
-          self._dp_average_query.set_denominator(self._num_microbatches)
+          self._global_state = self._dp_average_query.set_denominator(
+              self._global_state,
+              self._num_microbatches)
         sample_state = self._dp_average_query.initial_sample_state(
             self._global_state, var_list)
         microbatches_losses = tf.reshape(vector_loss,
@@ -126,7 +128,9 @@ def make_optimizer_class(cls):
         # sampling from the dataset without replacement.
         if self._num_microbatches is None:
           self._num_microbatches = tf.shape(loss)[0]
-          self._dp_average_query.set_denominator(self._num_microbatches)
+          self._global_state = self._dp_average_query.set_denominator(
+              self._global_state,
+              self._num_microbatches)
         microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
         sample_params = (
             self._dp_average_query.derive_sample_params(self._global_state))