# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for audit.py."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import audit

def dummy_train_and_score_function(dataset):
  """Stub train-and-score callback: ignores `dataset` and always scores 0.

  Passed as the callback argument to `audit.AuditAttack` in `get_auditor`.
  """
  _ = dataset  # Intentionally unused.
  score = 0
  return score

def get_auditor():
  """Builds an `audit.AuditAttack` over all-zero data with a stub poisoning.

  Two independent all-zero datasets are created, each a pair of a (5, 2)
  feature array and a length-5 label array. The auditor is constructed on the
  first one and given `poisoning` with keys:
    "data": both datasets, as a 2-tuple.
    "pois": the first feature row and first label of the first dataset.

  Returns:
    The configured `audit.AuditAttack` instance.
  """
  features, labels = np.zeros((5, 2)), np.zeros(5)
  datasets = ((features, labels), (np.zeros((5, 2)), np.zeros(5)))
  poisoning = {"data": datasets, "pois": (features[0], labels[0])}

  auditor = audit.AuditAttack(features, labels,
                              dummy_train_and_score_function)
  auditor.poisoning = poisoning
  return auditor


class AuditParameterizedTest(parameterized.TestCase):
  """Parameterized checks of audit.compute_epsilon_and_acc / compute_results.

  Every numeric comparison uses assertAlmostEqual's default of 7 places.
  """

  @parameterized.named_parameters(
      ('Test0', np.ones(500), np.zeros(500), 0.5, 0.01, 1,
       (4.541915810224092, 0.9894593118113243)),
      ('Test1', np.ones(500), np.zeros(500), 0.5, 0.01, 2,
       (2.27095790511, 0.9894593118113243)),
      ('Test2', np.ones(500), np.ones(500), 0.5, 0.01, 1,
       (0, 0))
  )
  def test_compute_epsilon_and_acc(self, poison_scores, unpois_scores,
                                   threshold, pois_ct, alpha, expected_res):
    """The (eps, acc) pair from compute_epsilon_and_acc matches expected_res."""
    want_eps, want_acc = expected_res
    got_eps, got_acc = audit.compute_epsilon_and_acc(
        poison_scores, unpois_scores, threshold, pois_ct, alpha)
    for got, want in ((got_eps, want_eps), (got_acc, want_acc)):
      self.assertAlmostEqual(got, want)

  # Note: compute_results takes threshold last, unlike
  # compute_epsilon_and_acc. Case Test1 passes threshold=None and still
  # expects 0.5 back (presumably a threshold chosen by audit; confirm there).
  @parameterized.named_parameters(
      ('Test0', [1]*500, [0]*250 + [.5]*250, 1, 0.01, .5,
       (.5, 4.541915810224092, 0.9894593118113243)),
      ('Test1', [1]*500, [0]*250 + [.5]*250, 1, 0.01, None,
       (.5, 4.541915810224092, 0.9894593118113243)),
      ('Test2', [1]*500, [0]*500, 2, 0.01, .5,
       (.5, 2.27095790511, 0.9894593118113243)),
  )
  def test_compute_results(self, poison_scores, unpois_scores, pois_ct,
                           alpha, threshold, expected_res):
    """The (threshold, eps, acc) triple from compute_results matches."""
    want_thresh, want_eps, want_acc = expected_res
    got_thresh, got_eps, got_acc = audit.compute_results(
        poison_scores, unpois_scores, pois_ct, alpha, threshold)
    pairs = ((got_thresh, want_thresh),
             (got_eps, want_eps),
             (got_acc, want_acc))
    for got, want in pairs:
      self.assertAlmostEqual(got, want)


class AuditAttackTest(absltest.TestCase):
  """Non-parameterized tests for audit.AuditAttack."""

  def test_run_experiments(self):
    """run_experiments(100) yields two lists of one hundred zeros.

    The auditor from get_auditor uses a scoring callback that always returns
    0, so all-zero scores are expected (presumably one per trial).
    """
    num_trials = 100
    pois_scores, unpois_scores = get_auditor().run_experiments(num_trials)
    all_zero = [0] * num_trials
    for scores in (pois_scores, unpois_scores):
      self.assertListEqual(scores, all_zero)



if __name__ == '__main__':
  # When executed as a script, hand control to absl's test runner.
  absltest.main()