From 8ec709e3d7c0d32f50868ff4d78f13fc3de1aaf7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Jul 2020 06:37:23 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 321768596 --- .../data_structures_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 25a9c98..23f9ba0 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -44,6 +44,22 @@ class AttackInputDataTest(absltest.TestCase): np.testing.assert_equal( attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + def test_validator(self): + self.assertRaises(ValueError, + AttackInputData(logits_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(labels_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(loss_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(logits_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(labels_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(loss_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData().validate) + if __name__ == '__main__': absltest.main()