From d6d70f6211abb577a8144cac66c5a86a69dec61b Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Thu, 10 Dec 2020 18:44:52 -0500 Subject: [PATCH] update data_structures_test --- .../membership_inference_attack/data_structures_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 76df7fc..d0dd7a8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -299,8 +299,8 @@ class SingleRiskScoreResultTest(absltest.TestCase): train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]), test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3])) - self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625]) - self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1]) + self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[1].tolist(), [0.8,0.625]) + self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[2].tolist(), [0.8,1]) class AttackResultsCollectionTest(absltest.TestCase):