Catches when data is not sufficient for StratifiedKFold split.

PiperOrigin-RevId: 510197242
This commit is contained in:
Shuang Song 2023-02-16 11:23:38 -08:00 committed by A. Unique TensorFlower
parent 0c691d0b4d
commit 4dd8d0ffde
2 changed files with 58 additions and 31 deletions

View file

@ -51,11 +51,13 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
# TODO(b/220394926): Allow users to specify their own attack models. # TODO(b/220394926): Allow users to specify their own attack models.
def _run_trained_attack(attack_input: AttackInputData, def _run_trained_attack(
attack_input: AttackInputData,
attack_type: AttackType, attack_type: AttackType,
balance_attacker_training: bool = True, balance_attacker_training: bool = True,
cross_validation_folds: int = 2, cross_validation_folds: int = 2,
backend: Optional[str] = None): backend: Optional[str] = None,
) -> Optional[SingleAttackResult]:
"""Classification attack done by ML models.""" """Classification attack done by ML models."""
prepared_attacker_data = models.create_attacker_data( prepared_attacker_data = models.create_attacker_data(
attack_input, balance=balance_attacker_training) attack_input, balance=balance_attacker_training)
@ -76,8 +78,9 @@ def _run_trained_attack(attack_input: AttackInputData,
# We use StratifiedKFold to create disjoint subsets of samples. Notice that # We use StratifiedKFold to create disjoint subsets of samples. Notice that
# the index it returns is with respect to the samples shuffled with `indices`. # the index it returns is with respect to the samples shuffled with `indices`.
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False) kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
for train_indices_in_shuffled, test_indices_in_shuffled in kf.split( try:
features[indices], labels[indices]): shuffled_indices = kf.split(features[indices], labels[indices])
for train_indices_in_shuffled, test_indices_in_shuffled in shuffled_indices:
# `train_indices_in_shuffled` is with respect to the data shuffled with # `train_indices_in_shuffled` is with respect to the data shuffled with
# `indices`. We convert it to `train_indices` to work with the original # `indices`. We convert it to `train_indices` to work with the original
# data (`features` and 'labels'). # data (`features` and 'labels').
@ -88,10 +91,10 @@ def _run_trained_attack(attack_input: AttackInputData,
# Setup sample weights if provided. # Setup sample weights if provided.
if sample_weights is not None: if sample_weights is not None:
# If sample weights are provided, only the weights at the training indices # If sample weights are provided, only the weights at the training
# are used for training. The weights at the test indices are not used # indices are used for training. The weights at the test indices are not
# during prediction. Not that 'train' and 'test' refer to the data for the # used during prediction. Not that 'train' and 'test' refer to the data
# attack models, not the data for the original models. # for the attack models, not the data for the original models.
sample_weights_train = np.squeeze(sample_weights[train_indices]) sample_weights_train = np.squeeze(sample_weights[train_indices])
else: else:
sample_weights_train = None sample_weights_train = None
@ -100,9 +103,15 @@ def _run_trained_attack(attack_input: AttackInputData,
attacker.train_model( attacker.train_model(
features[train_indices], features[train_indices],
labels[train_indices], labels[train_indices],
sample_weight=sample_weights_train) sample_weight=sample_weights_train,
)
predictions = attacker.predict(features[test_indices]) predictions = attacker.predict(features[test_indices])
scores[test_indices] = predictions scores[test_indices] = predictions
except ValueError as ve:
if 'cannot be greater than the number of members in each class.' in str(ve):
logging.warning('kf.split in _run_trained_attack fails with: %s', str(ve))
return None
raise ValueError(str(ve)) from ve
# Predict the left out with the last attacker # Predict the left out with the last attacker
if left_out_indices.size: if left_out_indices.size:
@ -322,7 +331,10 @@ def _compute_membership_probability(
if attack_input.loss_train is not None and attack_input.loss_test is not None: if attack_input.loss_train is not None and attack_input.loss_test is not None:
train_values = attack_input.loss_train train_values = attack_input.loss_train
test_values = attack_input.loss_test test_values = attack_input.loss_test
elif attack_input.entropy_train is not None and attack_input.entropy_test is not None: elif (
attack_input.entropy_train is not None
and attack_input.entropy_test is not None
):
train_values = attack_input.entropy_train train_values = attack_input.entropy_train
test_values = attack_input.entropy_test test_values = attack_input.entropy_test
else: else:

View file

@ -398,6 +398,21 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
result.roc_curve.get_ppv(), 0.57142, decimal=2) result.roc_curve.get_ppv(), 0.57142, decimal=2)
def test_run_attacks_insufficient_examples(self):
result = mia.run_attacks(
get_test_input(1, 100),
SlicingSpec(),
(
AttackType.THRESHOLD_ATTACK,
AttackType.LOGISTIC_REGRESSION,
),
)
self.assertLen(result.single_attack_results, 1)
self.assertEqual(
result.single_attack_results[0].attack_type.value,
AttackType.THRESHOLD_ATTACK.value,
)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()