From 3549d23da3c72bdf355c5479c65ff24eb7430eca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Jul 2020 06:07:28 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 321765113 --- .../data_structures_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 721f9e1..25a9c98 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -29,11 +29,21 @@ class AttackInputDataTest(absltest.TestCase): labels_test=np.array([0, 1]) ) - np.testing.assert_almost_equal( + np.testing.assert_equal( attack_input.get_loss_train().tolist(), [0.5, 0.2]) - np.testing.assert_almost_equal( + np.testing.assert_equal( attack_input.get_loss_test().tolist(), [0.2, 0.5]) + def test_get_loss_explicitly_provided(self): + attack_input = AttackInputData( + loss_train=np.array([1.0, 3.0, 6.0]), + loss_test=np.array([1.0, 4.0, 6.0])) + + np.testing.assert_equal( + attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0]) + np.testing.assert_equal( + attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + if __name__ == '__main__': absltest.main()