# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for pate.smooth_sensitivity."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import numpy as np

import smooth_sensitivity as pate_ss


class PateSmoothSensitivityTest(unittest.TestCase):
  """Regression tests that compare smooth_sensitivity outputs to stored values."""

  def test_check_conditions(self):
    """Checks the pair returned by check_conditions for two argument sets.

    The two calls differ only in their first argument (20 vs. 30).
    """
    self.assertEqual(pate_ss.check_conditions(20, 10, 25.), (True, False))
    self.assertEqual(pate_ss.check_conditions(30, 10, 25.), (True, True))

  def _assert_all_close(self, x, y):
    """Asserts that two numpy arrays have equal length and are close.

    Closeness is judged elementwise with a relative tolerance of 1e-8 and no
    absolute tolerance, i.e. |x - y| <= 1e-8 * |y|. Consequently an expected
    value of 0 only matches an exact 0. NaNs never compare as close.

    Args:
      x: The computed array.
      y: The expected array; the tolerance is scaled by its magnitudes.

    Raises:
      AssertionError: If the lengths differ or any element is outside the
        tolerance. The message lists the mismatching elements.
    """
    self.assertEqual(len(x), len(y))
    # np.testing.assert_allclose applies the same criterion as np.allclose but
    # reports which elements differ and by how much, instead of the bare
    # "False is not true" produced by assertTrue(np.allclose(...)).
    # equal_nan=False keeps np.allclose's behavior of rejecting NaNs.
    np.testing.assert_allclose(x, y, rtol=1e-8, atol=0, equal_nan=False)

  def test_compute_local_sensitivity_bounds_gnmax(self):
    """Compares compute_local_sensitivity_bounds_gnmax to stored outputs."""
    counts1 = np.array([10, 0, 0])
    sigma1 = .5
    order1 = 1.5

    answer1 = np.array(
        [3.13503646e-17, 1.60178280e-08, 5.90681786e-03] + [5.99981308e+00] * 7)

    # Test for "going right" in the smooth sensitivity computation.
    out1 = pate_ss.compute_local_sensitivity_bounds_gnmax(
        counts1, 10, sigma1, order1)

    self._assert_all_close(out1, answer1)

    counts2 = np.array([1000, 500, 300, 200, 0])
    sigma2 = 250.
    order2 = 10.

    # Test for "going left" in the smooth sensitivity computation.
    out2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
        counts2, 2000, sigma2, order2)

    answer2 = np.array([0.] * 298 + [2.77693450548e-7, 2.10853979548e-6] +
                       [2.73113623988e-6] * 1700)
    self._assert_all_close(out2, answer2)

  def test_compute_local_sensitivity_bounds_threshold(self):
    """Compares compute_local_sensitivity_bounds_threshold to stored outputs.

    Covers a high, a low, and a larger-than-num_teachers threshold on the same
    counts, followed by a case with fractional counts and threshold.
    """
    counts1_3 = np.array([20, 10, 0])
    num_teachers = sum(counts1_3)
    t1 = 16  # high threshold
    sigma = 2
    order = 10

    out1 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t1, sigma, order)
    answer1 = np.array([0] * 3 + [
        1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
        9.01543247e-02, 1.16054002e-01, 1.42180452e-01, 1.42180452e-01,
        1.48454129e-04, 1.47826870e-02, 3.94153241e-02, 6.45775697e-02,
        9.01543266e-02, 1.16054000e-01, 1.42180452e-01, 1.68302106e-01,
        1.93127860e-01
    ] + [0] * 10)
    self._assert_all_close(out1, answer1)

    t2 = 2  # low threshold

    out2 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t2, sigma, order)
    answer2 = np.array([
        1.60212079e-01, 2.07021132e-01, 2.07021132e-01, 1.93127860e-01,
        1.68302106e-01, 1.42180452e-01, 1.16054002e-01, 9.01543247e-02,
        6.45775697e-02, 3.94153241e-02, 1.47826870e-02, 1.48454129e-04
    ] + [0] * 18)
    self._assert_all_close(out2, answer2)

    t3 = 50  # very high threshold (larger than the number of teachers).

    out3 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts1_3, num_teachers, t3, sigma, order)

    answer3 = np.array([
        1.35750725752e-19, 1.88990500499e-17, 2.05403154065e-15,
        1.74298153642e-13, 1.15489723995e-11, 5.97584949325e-10,
        2.41486826748e-08, 7.62150641922e-07, 1.87846248741e-05,
        0.000360973025976, 0.000360973025976, 2.76377015215e-50,
        1.00904975276e-53, 2.87254164748e-57, 6.37583360761e-61,
        1.10331620211e-64, 1.48844393335e-68, 1.56535552444e-72,
        1.28328011060e-76, 8.20047697109e-81
    ] + [0] * 10)

    self._assert_all_close(out3, answer3)

    # Fractional values.
    # NOTE(review): num_teachers is still the sum of counts1_3 (30) here, not
    # the sum of counts4 -- presumably intentional; confirm against the
    # implementation.
    counts4 = np.array([19.5, -5.1, 0])
    t4 = 10.1
    out4 = pate_ss.compute_local_sensitivity_bounds_threshold(
        counts4, num_teachers, t4, sigma, order)

    answer4 = np.array([
        0.0620410301, 0.0875807131, 0.113451958, 0.139561671, 0.1657074530,
        0.1908244840, 0.2070270720, 0.207027072, 0.169718100, 0.0575152142,
        0.00678695871
    ] + [0] * 6 + [0.000536304908, 0.0172181073, 0.041909870] + [0] * 10)
    self._assert_all_close(out4, answer4)

# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()