tensorflow_privacy/research/pate_2018/smooth_sensitivity_test.py

127 lines
4.5 KiB
Python
Raw Permalink Normal View History

# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pate.smooth_sensitivity."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
import smooth_sensitivity as pate_ss
class PateSmoothSensitivityTest(unittest.TestCase):
    """Regression tests for the `smooth_sensitivity` module (`pate_ss`).

    Every test calls a `pate_ss` function with fixed inputs and compares the
    result against hard-coded reference values.
    """

    def test_check_conditions(self):
        """Checks the boolean pair `check_conditions` returns for two inputs."""
        self.assertEqual(pate_ss.check_conditions(20, 10, 25.), (True, False))
        self.assertEqual(pate_ss.check_conditions(30, 10, 25.), (True, True))

    def _assert_all_close(self, x, y):
        """Asserts that two numpy arrays are close.

        Args:
          x: Actual values (any sized array-like).
          y: Expected values; the relative tolerance is scaled by these.

        Raises:
          AssertionError: If the lengths differ, or if any element of `x`
            deviates from `y` by more than a relative tolerance of 1e-8.
        """
        self.assertEqual(len(x), len(y))
        # assert_allclose applies the same |x - y| <= atol + rtol * |y| check
        # as np.allclose, but on failure it reports which elements mismatch
        # and by how much instead of a bare "False is not true".
        # equal_nan=False keeps np.allclose's behavior of rejecting NaNs.
        np.testing.assert_allclose(x, y, rtol=1e-8, atol=0, equal_nan=False)

    def test_compute_local_sensitivity_bounds_gnmax(self):
        """Compares `compute_local_sensitivity_bounds_gnmax` to references."""
        counts1 = np.array([10, 0, 0])
        sigma1 = .5
        order1 = 1.5

        answer1 = np.array(
            [3.13503646e-17, 1.60178280e-08, 5.90681786e-03] +
            [5.99981308e+00] * 7)

        # Test for "going right" in the smooth sensitivity computation.
        out1 = pate_ss.compute_local_sensitivity_bounds_gnmax(
            counts1, 10, sigma1, order1)

        self._assert_all_close(out1, answer1)

        counts2 = np.array([1000, 500, 300, 200, 0])
        sigma2 = 250.
        order2 = 10.

        # Test for "going left" in the smooth sensitivity computation.
        out2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
            counts2, 2000, sigma2, order2)
        answer2 = np.array(
            [0.] * 298 + [2.77693450548e-7, 2.10853979548e-6] +
            [2.73113623988e-6] * 1700)
        self._assert_all_close(out2, answer2)

    def test_compute_local_sensitivity_bounds_threshold(self):
        """Compares `compute_local_sensitivity_bounds_threshold` to references.

        Covers a high, a low and a larger-than-the-number-of-teachers
        threshold on integer counts, then fractional counts and threshold.
        """
        counts1_3 = np.array([20, 10, 0])
        num_teachers = sum(counts1_3)
        t1 = 16  # high threshold
        sigma = 2
        order = 10

        out1 = pate_ss.compute_local_sensitivity_bounds_threshold(
            counts1_3, num_teachers, t1, sigma, order)
        answer1 = np.array([0] * 3 + [
            1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
            9.01543247e-02, 1.16054002e-01, 1.42180452e-01, 1.42180452e-01,
            1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
            9.01543266e-02, 1.16054000e-01, 1.42180452e-01, 1.68302106e-01,
            1.93127860e-01
        ] + [0] * 10)
        self._assert_all_close(out1, answer1)

        t2 = 2  # low threshold

        out2 = pate_ss.compute_local_sensitivity_bounds_threshold(
            counts1_3, num_teachers, t2, sigma, order)
        answer2 = np.array([
            1.60212079e-01, 2.07021132e-01, 2.07021132e-01, 1.93127860e-01,
            1.68302106e-01, 1.42180452e-01, 1.16054002e-01, 9.01543247e-02,
            6.45775697e-02, 3.94153241e-02, 1.47826870e-02, 1.48454129e-04
        ] + [0] * 18)
        self._assert_all_close(out2, answer2)

        t3 = 50  # very high threshold (larger than the number of teachers).

        out3 = pate_ss.compute_local_sensitivity_bounds_threshold(
            counts1_3, num_teachers, t3, sigma, order)

        answer3 = np.array([
            1.35750725752e-19, 1.88990500499e-17, 2.05403154065e-15,
            1.74298153642e-13, 1.15489723995e-11, 5.97584949325e-10,
            2.41486826748e-08, 7.62150641922e-07, 1.87846248741e-05,
            0.000360973025976, 0.000360973025976, 2.76377015215e-50,
            1.00904975276e-53, 2.87254164748e-57, 6.37583360761e-61,
            1.10331620211e-64, 1.48844393335e-68, 1.56535552444e-72,
            1.28328011060e-76, 8.20047697109e-81
        ] + [0] * 10)

        self._assert_all_close(out3, answer3)

        # Fractional values.
        counts4 = np.array([19.5, -5.1, 0])
        t4 = 10.1

        out4 = pate_ss.compute_local_sensitivity_bounds_threshold(
            counts4, num_teachers, t4, sigma, order)

        answer4 = np.array([
            0.0620410301, 0.0875807131, 0.113451958, 0.139561671, 0.1657074530,
            0.1908244840, 0.2070270720, 0.207027072, 0.169718100, 0.0575152142,
            0.00678695871
        ] + [0] * 6 + [0.000536304908, 0.0172181073, 0.041909870] + [0] * 10)
        self._assert_all_close(out4, answer4)
if __name__ == "__main__":
    # Run all test cases in this module when the file is executed directly.
    unittest.main()