From 4dd8d0ffde4ddb1575d5c2fc02e0693e08f4f4a1 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Thu, 16 Feb 2023 11:23:38 -0800 Subject: [PATCH] Catches when data is not sufficient for StratifiedKFold split. PiperOrigin-RevId: 510197242 --- .../membership_inference_attack.py | 74 +++++++++++-------- .../membership_inference_attack_test.py | 15 ++++ 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index 37a8fb2..ef3bdc7 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -51,11 +51,13 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: # TODO(b/220394926): Allow users to specify their own attack models. -def _run_trained_attack(attack_input: AttackInputData, - attack_type: AttackType, - balance_attacker_training: bool = True, - cross_validation_folds: int = 2, - backend: Optional[str] = None): +def _run_trained_attack( + attack_input: AttackInputData, + attack_type: AttackType, + balance_attacker_training: bool = True, + cross_validation_folds: int = 2, + backend: Optional[str] = None, +) -> Optional[SingleAttackResult]: """Classification attack done by ML models.""" prepared_attacker_data = models.create_attacker_data( attack_input, balance=balance_attacker_training) @@ -76,33 +78,40 @@ def _run_trained_attack(attack_input: AttackInputData, # We use StratifiedKFold to create disjoint subsets of samples. Notice that # the index it returns is with respect to the samples shuffled with `indices`. 
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False) - for train_indices_in_shuffled, test_indices_in_shuffled in kf.split( - features[indices], labels[indices]): - # `train_indices_in_shuffled` is with respect to the data shuffled with - # `indices`. We convert it to `train_indices` to work with the original - # data (`features` and 'labels'). - train_indices = indices[train_indices_in_shuffled] - test_indices = indices[test_indices_in_shuffled] - # Make sure one sample only got score predicted once - assert np.all(np.isnan(scores[test_indices])) + try: + shuffled_indices = kf.split(features[indices], labels[indices]) + for train_indices_in_shuffled, test_indices_in_shuffled in shuffled_indices: + # `train_indices_in_shuffled` is with respect to the data shuffled with + # `indices`. We convert it to `train_indices` to work with the original + # data (`features` and 'labels'). + train_indices = indices[train_indices_in_shuffled] + test_indices = indices[test_indices_in_shuffled] + # Make sure one sample only got score predicted once + assert np.all(np.isnan(scores[test_indices])) - # Setup sample weights if provided. - if sample_weights is not None: - # If sample weights are provided, only the weights at the training indices - # are used for training. The weights at the test indices are not used - # during prediction. Not that 'train' and 'test' refer to the data for the - # attack models, not the data for the original models. - sample_weights_train = np.squeeze(sample_weights[train_indices]) - else: - sample_weights_train = None + # Setup sample weights if provided. + if sample_weights is not None: + # If sample weights are provided, only the weights at the training + # indices are used for training. The weights at the test indices are not + # used during prediction. Note that 'train' and 'test' refer to the data + # for the attack models, not the data for the original models. 
+ sample_weights_train = np.squeeze(sample_weights[train_indices]) + else: + sample_weights_train = None - attacker = models.create_attacker(attack_type, backend=backend) - attacker.train_model( - features[train_indices], - labels[train_indices], - sample_weight=sample_weights_train) - predictions = attacker.predict(features[test_indices]) - scores[test_indices] = predictions + attacker = models.create_attacker(attack_type, backend=backend) + attacker.train_model( + features[train_indices], + labels[train_indices], + sample_weight=sample_weights_train, + ) + predictions = attacker.predict(features[test_indices]) + scores[test_indices] = predictions + except ValueError as ve: + if 'cannot be greater than the number of members in each class.' in str(ve): + logging.warning('kf.split in _run_trained_attack fails with: %s', str(ve)) + return None + raise ValueError(str(ve)) from ve # Predict the left out with the last attacker if left_out_indices.size: @@ -322,7 +331,10 @@ def _compute_membership_probability( if attack_input.loss_train is not None and attack_input.loss_test is not None: train_values = attack_input.loss_train test_values = attack_input.loss_test - elif attack_input.entropy_train is not None and attack_input.entropy_test is not None: + elif ( + attack_input.entropy_train is not None + and attack_input.entropy_test is not None + ): train_values = attack_input.entropy_train test_values = attack_input.entropy_test else: diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py index fd0f294..c3ba19b 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py @@ -398,6 +398,21 @@ class 
RunAttacksTestOnMultilabelData(absltest.TestCase): np.testing.assert_almost_equal( result.roc_curve.get_ppv(), 0.57142, decimal=2) + def test_run_attacks_insufficient_examples(self): + result = mia.run_attacks( + get_test_input(1, 100), + SlicingSpec(), + ( + AttackType.THRESHOLD_ATTACK, + AttackType.LOGISTIC_REGRESSION, + ), + ) + self.assertLen(result.single_attack_results, 1) + self.assertEqual( + result.single_attack_results[0].attack_type.value, + AttackType.THRESHOLD_ATTACK.value, + ) + if __name__ == '__main__': absltest.main()