# Source snapshot 93e9585f18 (PiperOrigin-RevId: 229241117): 126 lines, 4.5 KiB, Python.
# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
|
|
"""Tests for pate.smooth_sensitivity."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import unittest
|
|
import numpy as np
|
|
|
|
import smooth_sensitivity as pate_ss
|
|
|
|
|
|
class PateSmoothSensitivityTest(unittest.TestCase):
|
|
|
|
def test_check_conditions(self):
|
|
self.assertEqual(pate_ss.check_conditions(20, 10, 25.), (True, False))
|
|
self.assertEqual(pate_ss.check_conditions(30, 10, 25.), (True, True))
|
|
|
|
def _assert_all_close(self, x, y):
|
|
"""Asserts that two numpy arrays are close."""
|
|
self.assertEqual(len(x), len(y))
|
|
self.assertTrue(np.allclose(x, y, rtol=1e-8, atol=0))
|
|
|
|
def test_compute_local_sensitivity_bounds_gnmax(self):
|
|
counts1 = np.array([10, 0, 0])
|
|
sigma1 = .5
|
|
order1 = 1.5
|
|
|
|
answer1 = np.array(
|
|
[3.13503646e-17, 1.60178280e-08, 5.90681786e-03] + [5.99981308e+00] * 7)
|
|
|
|
# Test for "going right" in the smooth sensitivity computation.
|
|
out1 = pate_ss.compute_local_sensitivity_bounds_gnmax(
|
|
counts1, 10, sigma1, order1)
|
|
|
|
self._assert_all_close(out1, answer1)
|
|
|
|
counts2 = np.array([1000, 500, 300, 200, 0])
|
|
sigma2 = 250.
|
|
order2 = 10.
|
|
|
|
# Test for "going left" in the smooth sensitivity computation.
|
|
out2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
|
|
counts2, 2000, sigma2, order2)
|
|
|
|
answer2 = np.array([0.] * 298 + [2.77693450548e-7, 2.10853979548e-6] +
|
|
[2.73113623988e-6] * 1700)
|
|
self._assert_all_close(out2, answer2)
|
|
|
|
def test_compute_local_sensitivity_bounds_threshold(self):
|
|
counts1_3 = np.array([20, 10, 0])
|
|
num_teachers = sum(counts1_3)
|
|
t1 = 16 # high threshold
|
|
sigma = 2
|
|
order = 10
|
|
|
|
out1 = pate_ss.compute_local_sensitivity_bounds_threshold(
|
|
counts1_3, num_teachers, t1, sigma, order)
|
|
answer1 = np.array([0] * 3 + [
|
|
1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
|
|
9.01543247e-02, 1.16054002e-01, 1.42180452e-01, 1.42180452e-01,
|
|
1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
|
|
9.01543266e-02, 1.16054000e-01, 1.42180452e-01, 1.68302106e-01,
|
|
1.93127860e-01
|
|
] + [0] * 10)
|
|
self._assert_all_close(out1, answer1)
|
|
|
|
t2 = 2 # low threshold
|
|
|
|
out2 = pate_ss.compute_local_sensitivity_bounds_threshold(
|
|
counts1_3, num_teachers, t2, sigma, order)
|
|
answer2 = np.array([
|
|
1.60212079e-01, 2.07021132e-01, 2.07021132e-01, 1.93127860e-01,
|
|
1.68302106e-01, 1.42180452e-01, 1.16054002e-01, 9.01543247e-02,
|
|
6.45775697e-02, 3.94153241e-02, 1.47826870e-02, 1.48454129e-04
|
|
] + [0] * 18)
|
|
self._assert_all_close(out2, answer2)
|
|
|
|
t3 = 50 # very high threshold (larger than the number of teachers).
|
|
|
|
out3 = pate_ss.compute_local_sensitivity_bounds_threshold(
|
|
counts1_3, num_teachers, t3, sigma, order)
|
|
|
|
answer3 = np.array([
|
|
1.35750725752e-19, 1.88990500499e-17, 2.05403154065e-15,
|
|
1.74298153642e-13, 1.15489723995e-11, 5.97584949325e-10,
|
|
2.41486826748e-08, 7.62150641922e-07, 1.87846248741e-05,
|
|
0.000360973025976, 0.000360973025976, 2.76377015215e-50,
|
|
1.00904975276e-53, 2.87254164748e-57, 6.37583360761e-61,
|
|
1.10331620211e-64, 1.48844393335e-68, 1.56535552444e-72,
|
|
1.28328011060e-76, 8.20047697109e-81
|
|
] + [0] * 10)
|
|
|
|
self._assert_all_close(out3, answer3)
|
|
|
|
# Fractional values.
|
|
counts4 = np.array([19.5, -5.1, 0])
|
|
t4 = 10.1
|
|
out4 = pate_ss.compute_local_sensitivity_bounds_threshold(
|
|
counts4, num_teachers, t4, sigma, order)
|
|
|
|
answer4 = np.array([
|
|
0.0620410301, 0.0875807131, 0.113451958, 0.139561671, 0.1657074530,
|
|
0.1908244840, 0.2070270720, 0.207027072, 0.169718100, 0.0575152142,
|
|
0.00678695871
|
|
] + [0] * 6 + [0.000536304908, 0.0172181073, 0.041909870] + [0] * 10)
|
|
self._assert_all_close(out4, answer4)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the unittest command-line runner over the tests defined in this
    # module when the file is executed directly.
    unittest.main()
|