diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 721f9e1..25a9c98 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -29,11 +29,21 @@ class AttackInputDataTest(absltest.TestCase): labels_test=np.array([0, 1]) ) - np.testing.assert_almost_equal( + np.testing.assert_equal( attack_input.get_loss_train().tolist(), [0.5, 0.2]) - np.testing.assert_almost_equal( + np.testing.assert_equal( attack_input.get_loss_test().tolist(), [0.2, 0.5]) + def test_get_loss_explicitly_provided(self): + attack_input = AttackInputData( + loss_train=np.array([1.0, 3.0, 6.0]), + loss_test=np.array([1.0, 4.0, 6.0])) + + np.testing.assert_equal( + attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0]) + np.testing.assert_equal( + attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + if __name__ == '__main__': absltest.main()