forked from 626_privacy/tensorflow_privacy
Internal change.
PiperOrigin-RevId: 321765113
This commit is contained in:
parent
510dd207d5
commit
3549d23da3
1 changed files with 12 additions and 2 deletions
|
@ -29,11 +29,21 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
labels_test=np.array([0, 1])
|
||||
)
|
||||
|
||||
np.testing.assert_almost_equal(
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
||||
np.testing.assert_almost_equal(
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
||||
|
||||
def test_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1.0, 3.0, 6.0]),
|
||||
loss_test=np.array([1.0, 4.0, 6.0]))
|
||||
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_train().tolist(), [1.0, 3.0, 6.0])
|
||||
np.testing.assert_equal(
|
||||
attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue