Internal change.

PiperOrigin-RevId: 321765113
This commit is contained in:
A. Unique TensorFlower 2020-07-17 06:07:28 -07:00
parent 510dd207d5
commit 3549d23da3

View file

@ -29,11 +29,21 @@ class AttackInputDataTest(absltest.TestCase):
labels_test=np.array([0, 1]) labels_test=np.array([0, 1])
) )
np.testing.assert_almost_equal( np.testing.assert_equal(
attack_input.get_loss_train().tolist(), [0.5, 0.2]) attack_input.get_loss_train().tolist(), [0.5, 0.2])
np.testing.assert_almost_equal( np.testing.assert_equal(
attack_input.get_loss_test().tolist(), [0.2, 0.5]) attack_input.get_loss_test().tolist(), [0.2, 0.5])
def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]),
loss_test=np.array([1.0, 4.0, 6.0]))
np.testing.assert_equal(
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0])
np.testing.assert_equal(
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()