Internal change.
PiperOrigin-RevId: 321768596
This commit is contained in:
parent
3549d23da3
commit
8ec709e3d7
1 changed files with 16 additions and 0 deletions
|
@ -44,6 +44,22 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
np.testing.assert_equal(
|
np.testing.assert_equal(
|
||||||
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
|
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
|
||||||
|
|
||||||
|
def test_validator(self):
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(logits_train=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(labels_train=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(loss_train=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(logits_test=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(labels_test=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(loss_test=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData().validate)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue