diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 76df7fc..d0dd7a8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -299,8 +299,8 @@ class SingleRiskScoreResultTest(absltest.TestCase): train_risk_scores=np.array([0.91,1,0.92,0.82,0.75]), test_risk_scores=np.array([0.81,0.7,0.75,0.25,0.3])) - self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[1], [0.8,0.625]) - self.assertEqual(result.attack_with_varied_thresholds([0.8,0.7])[2], [0.8,1]) + self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[1].tolist(), [0.8,0.625]) + self.assertEqual(result.attack_with_varied_thresholds(np.array([0.8,0.7]))[2].tolist(), [0.8,1]) class AttackResultsCollectionTest(absltest.TestCase):