Remove GaussianAverageQuery. Users can simply wrap GaussianSumQuery with a NormalizedQuery.
PiperOrigin-RevId: 374784618
This commit is contained in:
parent
1de7e4dde4
commit
e5848656ed
7 changed files with 96 additions and 189 deletions
|
@ -39,7 +39,6 @@ else:
|
||||||
# DPQuery classes
|
# DPQuery classes
|
||||||
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
|
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianAverageQuery
|
|
||||||
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
|
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
|
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
|
||||||
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests for PrivacyLedger."""
|
"""Tests for PrivacyLedger."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
@ -57,10 +56,9 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
||||||
population_size = tf.Variable(0)
|
population_size = tf.Variable(0)
|
||||||
selection_probability = tf.Variable(1.0)
|
selection_probability = tf.Variable(1.0)
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
query = privacy_ledger.QueryWithLedger(query, population_size,
|
||||||
query = privacy_ledger.QueryWithLedger(
|
selection_probability)
|
||||||
query, population_size, selection_probability)
|
|
||||||
|
|
||||||
# First sample.
|
# First sample.
|
||||||
tf.assign(population_size, 10)
|
tf.assign(population_size, 10)
|
||||||
|
@ -93,14 +91,12 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
||||||
population_size = tf.Variable(0)
|
population_size = tf.Variable(0)
|
||||||
selection_probability = tf.Variable(1.0)
|
selection_probability = tf.Variable(1.0)
|
||||||
|
|
||||||
query1 = gaussian_query.GaussianAverageQuery(
|
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=2.0)
|
||||||
l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0)
|
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=1.0)
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0)
|
|
||||||
|
|
||||||
query = nested_query.NestedQuery([query1, query2])
|
query = nested_query.NestedQuery([query1, query2])
|
||||||
query = privacy_ledger.QueryWithLedger(
|
query = privacy_ledger.QueryWithLedger(query, population_size,
|
||||||
query, population_size, selection_probability)
|
selection_probability)
|
||||||
|
|
||||||
record1 = [1.0, [12.0, 9.0]]
|
record1 = [1.0, [12.0, 9.0]]
|
||||||
record2 = [5.0, [1.0, 2.0]]
|
record2 = [5.0, [1.0, 2.0]]
|
||||||
|
|
|
@ -11,9 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""Implements DPQuery interface for Gaussian sum queries."""
|
||||||
"""Implements DPQuery interface for Gaussian average queries.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -25,7 +23,6 @@ import distutils
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
|
||||||
|
|
||||||
|
|
||||||
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
@ -35,8 +32,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_GlobalState = collections.namedtuple(
|
_GlobalState = collections.namedtuple('_GlobalState',
|
||||||
'_GlobalState', ['l2_norm_clip', 'stddev'])
|
['l2_norm_clip', 'stddev'])
|
||||||
|
|
||||||
def __init__(self, l2_norm_clip, stddev):
|
def __init__(self, l2_norm_clip, stddev):
|
||||||
"""Initializes the GaussianSumQuery.
|
"""Initializes the GaussianSumQuery.
|
||||||
|
@ -55,8 +52,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def make_global_state(self, l2_norm_clip, stddev):
|
def make_global_state(self, l2_norm_clip, stddev):
|
||||||
"""Creates a global state from the given parameters."""
|
"""Creates a global state from the given parameters."""
|
||||||
return self._GlobalState(tf.cast(l2_norm_clip, tf.float32),
|
return self._GlobalState(
|
||||||
tf.cast(stddev, tf.float32))
|
tf.cast(l2_norm_clip, tf.float32), tf.cast(stddev, tf.float32))
|
||||||
|
|
||||||
def initial_global_state(self):
|
def initial_global_state(self):
|
||||||
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
return self.make_global_state(self._l2_norm_clip, self._stddev)
|
||||||
|
@ -94,48 +91,17 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||||
return v + tf.random.normal(
|
return v + tf.random.normal(
|
||||||
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
tf.shape(input=v), stddev=global_state.stddev, dtype=v.dtype)
|
||||||
else:
|
else:
|
||||||
random_normal = tf.random_normal_initializer(
|
random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
|
||||||
stddev=global_state.stddev)
|
|
||||||
|
|
||||||
def add_noise(v):
|
def add_noise(v):
|
||||||
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
|
||||||
|
|
||||||
if self._ledger:
|
if self._ledger:
|
||||||
dependencies = [
|
dependencies = [
|
||||||
self._ledger.record_sum_query(
|
self._ledger.record_sum_query(global_state.l2_norm_clip,
|
||||||
global_state.l2_norm_clip, global_state.stddev)
|
global_state.stddev)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
dependencies = []
|
dependencies = []
|
||||||
with tf.control_dependencies(dependencies):
|
with tf.control_dependencies(dependencies):
|
||||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
return tf.nest.map_structure(add_noise, sample_state), global_state
|
||||||
|
|
||||||
|
|
||||||
class GaussianAverageQuery(normalized_query.NormalizedQuery):
|
|
||||||
"""Implements DPQuery interface for Gaussian average queries.
|
|
||||||
|
|
||||||
Accumulates clipped vectors, adds Gaussian noise, and normalizes.
|
|
||||||
|
|
||||||
Note that we use "fixed-denominator" estimation: the denominator should be
|
|
||||||
specified as the expected number of records per sample. Accumulating the
|
|
||||||
denominator separately would also be possible but would be produce a higher
|
|
||||||
variance estimator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
l2_norm_clip,
|
|
||||||
sum_stddev,
|
|
||||||
denominator):
|
|
||||||
"""Initializes the GaussianAverageQuery.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
l2_norm_clip: The clipping norm to apply to the global norm of each
|
|
||||||
record.
|
|
||||||
sum_stddev: The stddev of the noise added to the sum (before
|
|
||||||
normalization).
|
|
||||||
denominator: The normalization constant (applied after noise is added to
|
|
||||||
the sum).
|
|
||||||
"""
|
|
||||||
super(GaussianAverageQuery, self).__init__(
|
|
||||||
numerator_query=GaussianSumQuery(l2_norm_clip, sum_stddev),
|
|
||||||
denominator=denominator)
|
|
||||||
|
|
|
@ -11,8 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""Tests for GaussianSumQuery."""
|
||||||
"""Tests for GaussianAverageQuery."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -34,8 +33,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record1 = tf.constant([2.0, 0.0])
|
record1 = tf.constant([2.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
record2 = tf.constant([-1.0, 1.0])
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
result = sess.run(query_result)
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
|
@ -46,8 +44,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
||||||
l2_norm_clip=5.0, stddev=0.0)
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
result = sess.run(query_result)
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
|
@ -80,8 +77,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record1, record2 = 2.71828, 3.14159
|
record1, record2 = 2.71828, 3.14159
|
||||||
stddev = 1.0
|
stddev = 1.0
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=stddev)
|
||||||
l2_norm_clip=5.0, stddev=stddev)
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
|
|
||||||
noised_sums = []
|
noised_sums = []
|
||||||
|
@ -108,8 +104,7 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
sample_state_2 = get_sample_state(records2)
|
sample_state_2 = get_sample_state(records2)
|
||||||
|
|
||||||
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
||||||
sample_state_1,
|
sample_state_1, sample_state_2)
|
||||||
sample_state_2)
|
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = sess.run(merged)
|
result = sess.run(merged)
|
||||||
|
@ -117,36 +112,6 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected = [3.0, 10.0]
|
expected = [3.0, 10.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
def test_gaussian_average_no_noise(self):
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
|
||||||
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
|
|
||||||
|
|
||||||
query = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
|
||||||
result = sess.run(query_result)
|
|
||||||
expected_average = [1.0, 1.0]
|
|
||||||
self.assertAllClose(result, expected_average)
|
|
||||||
|
|
||||||
def test_gaussian_average_with_noise(self):
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1, record2 = 2.71828, 3.14159
|
|
||||||
sum_stddev = 1.0
|
|
||||||
denominator = 2.0
|
|
||||||
|
|
||||||
query = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
|
||||||
|
|
||||||
noised_averages = []
|
|
||||||
for _ in range(1000):
|
|
||||||
noised_averages.append(sess.run(query_result))
|
|
||||||
|
|
||||||
result_stddev = np.std(noised_averages)
|
|
||||||
avg_stddev = sum_stddev / denominator
|
|
||||||
self.assertNear(result_stddev, avg_stddev, 0.1)
|
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('type_mismatch', [1.0], (1.0,), TypeError),
|
('type_mismatch', [1.0], (1.0,), TypeError),
|
||||||
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
|
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests for NestedQuery."""
|
"""Tests for NestedQuery."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
@ -27,9 +26,9 @@ import tensorflow.compat.v1 as tf
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import nested_query
|
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
|
|
||||||
|
|
||||||
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,10 +36,8 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
query1 = gaussian_query.GaussianSumQuery(
|
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
query2 = gaussian_query.GaussianSumQuery(
|
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
|
||||||
|
|
||||||
query = nested_query.NestedSumQuery([query1, query2])
|
query = nested_query.NestedSumQuery([query1, query2])
|
||||||
|
|
||||||
|
@ -52,29 +49,14 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected = [5.0, [5.0, 5.0]]
|
expected = [5.0, [5.0, 5.0]]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(result, expected)
|
||||||
|
|
||||||
def test_nested_gaussian_average_no_clip_no_noise(self):
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
query1 = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
|
||||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
|
||||||
|
|
||||||
query = nested_query.NestedSumQuery([query1, query2])
|
|
||||||
|
|
||||||
record1 = [1.0, [2.0, 3.0]]
|
|
||||||
record2 = [4.0, [3.0, 2.0]]
|
|
||||||
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, [1.0, 1.0]]
|
|
||||||
self.assertAllClose(result, expected)
|
|
||||||
|
|
||||||
def test_nested_gaussian_average_with_clip_no_noise(self):
|
def test_nested_gaussian_average_with_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
query1 = gaussian_query.GaussianAverageQuery(
|
query1 = normalized_query.NormalizedQuery(
|
||||||
l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0)
|
gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=0.0),
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
denominator=5.0)
|
||||||
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
query2 = normalized_query.NormalizedQuery(
|
||||||
|
gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0),
|
||||||
|
denominator=5.0)
|
||||||
|
|
||||||
query = nested_query.NestedSumQuery([query1, query2])
|
query = nested_query.NestedSumQuery([query1, query2])
|
||||||
|
|
||||||
|
@ -88,15 +70,17 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_complex_nested_query(self):
|
def test_complex_nested_query(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
query_ab = gaussian_query.GaussianSumQuery(
|
query_ab = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
|
||||||
l2_norm_clip=1.0, stddev=0.0)
|
query_c = normalized_query.NormalizedQuery(
|
||||||
query_c = gaussian_query.GaussianAverageQuery(
|
gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0),
|
||||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0)
|
denominator=2.0)
|
||||||
query_d = gaussian_query.GaussianSumQuery(
|
query_d = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
l2_norm_clip=10.0, stddev=0.0)
|
|
||||||
|
|
||||||
query = nested_query.NestedSumQuery(
|
query = nested_query.NestedSumQuery(
|
||||||
[query_ab, {'c': query_c, 'd': [query_d]}])
|
[query_ab, {
|
||||||
|
'c': query_c,
|
||||||
|
'd': [query_d]
|
||||||
|
}])
|
||||||
|
|
||||||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||||
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
||||||
|
@ -108,13 +92,13 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_nested_query_with_noise(self):
|
def test_nested_query_with_noise(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sum_stddev = 2.71828
|
stddev = 2.71828
|
||||||
denominator = 3.14159
|
denominator = 3.14159
|
||||||
|
|
||||||
query1 = gaussian_query.GaussianSumQuery(
|
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=1.5, stddev=stddev)
|
||||||
l2_norm_clip=1.5, stddev=sum_stddev)
|
query2 = normalized_query.NormalizedQuery(
|
||||||
query2 = gaussian_query.GaussianAverageQuery(
|
gaussian_query.GaussianSumQuery(l2_norm_clip=0.5, stddev=stddev),
|
||||||
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
denominator=denominator)
|
||||||
query = nested_query.NestedSumQuery((query1, query2))
|
query = nested_query.NestedSumQuery((query1, query2))
|
||||||
|
|
||||||
record1 = (3.0, [2.0, 1.5])
|
record1 = (3.0, [2.0, 1.5])
|
||||||
|
@ -127,20 +111,20 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
noised_averages.append(tf.nest.flatten(sess.run(query_result)))
|
noised_averages.append(tf.nest.flatten(sess.run(query_result)))
|
||||||
|
|
||||||
result_stddev = np.std(noised_averages, 0)
|
result_stddev = np.std(noised_averages, 0)
|
||||||
avg_stddev = sum_stddev / denominator
|
avg_stddev = stddev / denominator
|
||||||
expected_stddev = [sum_stddev, avg_stddev, avg_stddev]
|
expected_stddev = [stddev, avg_stddev, avg_stddev]
|
||||||
self.assertArrayNear(result_stddev, expected_stddev, 0.1)
|
self.assertArrayNear(result_stddev, expected_stddev, 0.1)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('type_mismatch', [_basic_query], (1.0,), TypeError),
|
('type_mismatch', [_basic_query], (1.0,), TypeError),
|
||||||
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
|
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
|
||||||
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
|
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
|
||||||
def test_record_incompatible_with_query(
|
def test_record_incompatible_with_query(self, queries, record, error_type):
|
||||||
self, queries, record, error_type):
|
|
||||||
with self.assertRaises(error_type):
|
with self.assertRaises(error_type):
|
||||||
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
|
test_utils.run_query(nested_query.NestedSumQuery(queries), [record])
|
||||||
|
|
||||||
def test_raises_with_non_sum(self):
|
def test_raises_with_non_sum(self):
|
||||||
|
|
||||||
class NonSumDPQuery(dp_query.DPQuery):
|
class NonSumDPQuery(dp_query.DPQuery):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -154,6 +138,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
nested_query.NestedSumQuery(non_sum_query)
|
nested_query.NestedSumQuery(non_sum_query)
|
||||||
|
|
||||||
def test_metrics(self):
|
def test_metrics(self):
|
||||||
|
|
||||||
class QueryWithMetric(dp_query.SumAggregationDPQuery):
|
class QueryWithMetric(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
def __init__(self, metric_val):
|
def __init__(self, metric_val):
|
||||||
|
@ -171,8 +156,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
metric_val = nested_a.derive_metrics(global_state)
|
metric_val = nested_a.derive_metrics(global_state)
|
||||||
self.assertEqual(metric_val['metric'], 1)
|
self.assertEqual(metric_val['metric'], 1)
|
||||||
|
|
||||||
nested_b = nested_query.NestedSumQuery(
|
nested_b = nested_query.NestedSumQuery({
|
||||||
{'key1': query1, 'key2': [query2, query3]})
|
'key1': query1,
|
||||||
|
'key2': [query2, query3]
|
||||||
|
})
|
||||||
global_state = nested_b.initial_global_state()
|
global_state = nested_b.initial_global_state()
|
||||||
metric_val = nested_b.derive_metrics(global_state)
|
metric_val = nested_b.derive_metrics(global_state)
|
||||||
self.assertEqual(metric_val['key1/metric'], 1)
|
self.assertEqual(metric_val['key1/metric'], 1)
|
||||||
|
|
|
@ -11,8 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""Tests for NormalizedQuery."""
|
||||||
"""Tests for GaussianAverageQuery."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -32,8 +31,7 @@ class NormalizedQueryTest(tf.test.TestCase):
|
||||||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||||
|
|
||||||
sum_query = gaussian_query.GaussianSumQuery(
|
sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
||||||
l2_norm_clip=5.0, stddev=0.0)
|
|
||||||
query = normalized_query.NormalizedQuery(
|
query = normalized_query.NormalizedQuery(
|
||||||
numerator_query=sum_query, denominator=2.0)
|
numerator_query=sum_query, denominator=2.0)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Implements DPQuery interface for quantile estimator.
|
"""Implements DPQuery interface for quantile estimator.
|
||||||
|
|
||||||
From a starting estimate of the target quantile, the estimate is updated
|
From a starting estimate of the target quantile, the estimate is updated
|
||||||
|
@ -27,23 +26,20 @@ from __future__ import print_function
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||||
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
from tensorflow_privacy.privacy.dp_query import no_privacy_query
|
||||||
|
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||||
|
|
||||||
|
|
||||||
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
"""Iterative process to estimate target quantile of a univariate distribution.
|
"""Iterative process to estimate target quantile of a univariate distribution."""
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_GlobalState = collections.namedtuple(
|
_GlobalState = collections.namedtuple('_GlobalState', [
|
||||||
'_GlobalState', [
|
'current_estimate', 'target_quantile', 'learning_rate',
|
||||||
'current_estimate',
|
'below_estimate_state'
|
||||||
'target_quantile',
|
])
|
||||||
'learning_rate',
|
|
||||||
'below_estimate_state'])
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_SampleParams = collections.namedtuple(
|
_SampleParams = collections.namedtuple(
|
||||||
|
@ -52,8 +48,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
# No separate SampleState-- sample state is just below_estimate_query's
|
# No separate SampleState-- sample state is just below_estimate_query's
|
||||||
# SampleState.
|
# SampleState.
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
initial_estimate,
|
initial_estimate,
|
||||||
target_quantile,
|
target_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
|
@ -65,11 +60,11 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
Args:
|
Args:
|
||||||
initial_estimate: The initial estimate of the quantile.
|
initial_estimate: The initial estimate of the quantile.
|
||||||
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||||
should be found for which approximately 80% of updates are
|
should be found for which approximately 80% of updates are less than the
|
||||||
less than the estimate each round.
|
estimate each round.
|
||||||
learning_rate: The learning rate. A rate of r means that the estimate
|
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||||
will change by a maximum of r at each step (for arithmetic updating) or
|
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||||
by a maximum factor of exp(r) (for geometric updating).
|
maximum factor of exp(r) (for geometric updating).
|
||||||
below_estimate_stddev: The stddev of the noise added to the count of
|
below_estimate_stddev: The stddev of the noise added to the count of
|
||||||
records currently below the estimate. Since the sensitivity of the count
|
records currently below the estimate. Since the sensitivity of the count
|
||||||
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
|
query is 0.5, as a rule of thumb it should be about 0.5 for reasonable
|
||||||
|
@ -90,8 +85,8 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
|
|
||||||
self._geometric_update = geometric_update
|
self._geometric_update = geometric_update
|
||||||
|
|
||||||
def _construct_below_estimate_query(
|
def _construct_below_estimate_query(self, below_estimate_stddev,
|
||||||
self, below_estimate_stddev, expected_num_records):
|
expected_num_records):
|
||||||
# A DPQuery used to estimate the fraction of records that are less than the
|
# A DPQuery used to estimate the fraction of records that are less than the
|
||||||
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
# current quantile estimate. It accumulates an indicator 0/1 of whether each
|
||||||
# record is below the estimate, and normalizes by the expected number of
|
# record is below the estimate, and normalizes by the expected number of
|
||||||
|
@ -101,9 +96,9 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
# affect the count is 0.5. Note that although the l2_norm_clip of the
|
||||||
# below_estimate query is 0.5, no clipping will ever actually occur
|
# below_estimate query is 0.5, no clipping will ever actually occur
|
||||||
# because the value of each record is always +/-0.5.
|
# because the value of each record is always +/-0.5.
|
||||||
return gaussian_query.GaussianAverageQuery(
|
return normalized_query.NormalizedQuery(
|
||||||
l2_norm_clip=0.5,
|
gaussian_query.GaussianSumQuery(
|
||||||
sum_stddev=below_estimate_stddev,
|
l2_norm_clip=0.5, stddev=below_estimate_stddev),
|
||||||
denominator=expected_num_records)
|
denominator=expected_num_records)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
|
@ -140,8 +135,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
|
||||||
def get_noised_result(self, sample_state, global_state):
|
def get_noised_result(self, sample_state, global_state):
|
||||||
below_estimate_result, new_below_estimate_state = (
|
below_estimate_result, new_below_estimate_state = (
|
||||||
self._below_estimate_query.get_noised_result(
|
self._below_estimate_query.get_noised_result(
|
||||||
sample_state,
|
sample_state, global_state.below_estimate_state))
|
||||||
global_state.below_estimate_state))
|
|
||||||
|
|
||||||
# Unshift below_estimate percentile by 0.5. (See comment in initializer.)
|
# Unshift below_estimate percentile by 0.5. (See comment in initializer.)
|
||||||
below_estimate = below_estimate_result + 0.5
|
below_estimate = below_estimate_result + 0.5
|
||||||
|
@ -177,8 +171,7 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
below estimate with an exact denominator.
|
below estimate with an exact denominator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
initial_estimate,
|
initial_estimate,
|
||||||
target_quantile,
|
target_quantile,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
|
@ -188,22 +181,25 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
|
||||||
Args:
|
Args:
|
||||||
initial_estimate: The initial estimate of the quantile.
|
initial_estimate: The initial estimate of the quantile.
|
||||||
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
target_quantile: The target quantile. I.e., a value of 0.8 means a value
|
||||||
should be found for which approximately 80% of updates are
|
should be found for which approximately 80% of updates are less than the
|
||||||
less than the estimate each round.
|
estimate each round.
|
||||||
learning_rate: The learning rate. A rate of r means that the estimate
|
learning_rate: The learning rate. A rate of r means that the estimate will
|
||||||
will change by a maximum of r at each step (for arithmetic updating) or
|
change by a maximum of r at each step (for arithmetic updating) or by a
|
||||||
by a maximum factor of exp(r) (for geometric updating).
|
maximum factor of exp(r) (for geometric updating).
|
||||||
geometric_update: If True, use geometric updating of estimate. Geometric
|
geometric_update: If True, use geometric updating of estimate. Geometric
|
||||||
updating is preferred for non-negative records like vector norms that
|
updating is preferred for non-negative records like vector norms that
|
||||||
could potentially be very large or very close to zero.
|
could potentially be very large or very close to zero.
|
||||||
"""
|
"""
|
||||||
super(NoPrivacyQuantileEstimatorQuery, self).__init__(
|
super(NoPrivacyQuantileEstimatorQuery, self).__init__(
|
||||||
initial_estimate, target_quantile, learning_rate,
|
initial_estimate,
|
||||||
below_estimate_stddev=None, expected_num_records=None,
|
target_quantile,
|
||||||
|
learning_rate,
|
||||||
|
below_estimate_stddev=None,
|
||||||
|
expected_num_records=None,
|
||||||
geometric_update=geometric_update)
|
geometric_update=geometric_update)
|
||||||
|
|
||||||
def _construct_below_estimate_query(
|
def _construct_below_estimate_query(self, below_estimate_stddev,
|
||||||
self, below_estimate_stddev, expected_num_records):
|
expected_num_records):
|
||||||
del below_estimate_stddev
|
del below_estimate_stddev
|
||||||
del expected_num_records
|
del expected_num_records
|
||||||
return no_privacy_query.NoPrivacyAverageQuery()
|
return no_privacy_query.NoPrivacyAverageQuery()
|
||||||
|
|
Loading…
Reference in a new issue