Catches when data is not sufficient for StratifiedKFold split.
PiperOrigin-RevId: 510197242
This commit is contained in:
parent
0c691d0b4d
commit
4dd8d0ffde
2 changed files with 58 additions and 31 deletions
|
@ -51,11 +51,13 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
|||
|
||||
|
||||
# TODO(b/220394926): Allow users to specify their own attack models.
|
||||
def _run_trained_attack(attack_input: AttackInputData,
|
||||
def _run_trained_attack(
|
||||
attack_input: AttackInputData,
|
||||
attack_type: AttackType,
|
||||
balance_attacker_training: bool = True,
|
||||
cross_validation_folds: int = 2,
|
||||
backend: Optional[str] = None):
|
||||
backend: Optional[str] = None,
|
||||
) -> Optional[SingleAttackResult]:
|
||||
"""Classification attack done by ML models."""
|
||||
prepared_attacker_data = models.create_attacker_data(
|
||||
attack_input, balance=balance_attacker_training)
|
||||
|
@ -76,8 +78,9 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
# We use StratifiedKFold to create disjoint subsets of samples. Notice that
|
||||
# the index it returns is with respect to the samples shuffled with `indices`.
|
||||
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
|
||||
for train_indices_in_shuffled, test_indices_in_shuffled in kf.split(
|
||||
features[indices], labels[indices]):
|
||||
try:
|
||||
shuffled_indices = kf.split(features[indices], labels[indices])
|
||||
for train_indices_in_shuffled, test_indices_in_shuffled in shuffled_indices:
|
||||
# `train_indices_in_shuffled` is with respect to the data shuffled with
|
||||
# `indices`. We convert it to `train_indices` to work with the original
|
||||
# data (`features` and 'labels').
|
||||
|
@ -88,10 +91,10 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
|
||||
# Setup sample weights if provided.
|
||||
if sample_weights is not None:
|
||||
# If sample weights are provided, only the weights at the training indices
|
||||
# are used for training. The weights at the test indices are not used
|
||||
# during prediction. Not that 'train' and 'test' refer to the data for the
|
||||
# attack models, not the data for the original models.
|
||||
# If sample weights are provided, only the weights at the training
|
||||
# indices are used for training. The weights at the test indices are not
|
||||
# used during prediction. Not that 'train' and 'test' refer to the data
|
||||
# for the attack models, not the data for the original models.
|
||||
sample_weights_train = np.squeeze(sample_weights[train_indices])
|
||||
else:
|
||||
sample_weights_train = None
|
||||
|
@ -100,9 +103,15 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
attacker.train_model(
|
||||
features[train_indices],
|
||||
labels[train_indices],
|
||||
sample_weight=sample_weights_train)
|
||||
sample_weight=sample_weights_train,
|
||||
)
|
||||
predictions = attacker.predict(features[test_indices])
|
||||
scores[test_indices] = predictions
|
||||
except ValueError as ve:
|
||||
if 'cannot be greater than the number of members in each class.' in str(ve):
|
||||
logging.warning('kf.split in _run_trained_attack fails with: %s', str(ve))
|
||||
return None
|
||||
raise ValueError(str(ve)) from ve
|
||||
|
||||
# Predict the left out with the last attacker
|
||||
if left_out_indices.size:
|
||||
|
@ -322,7 +331,10 @@ def _compute_membership_probability(
|
|||
if attack_input.loss_train is not None and attack_input.loss_test is not None:
|
||||
train_values = attack_input.loss_train
|
||||
test_values = attack_input.loss_test
|
||||
elif attack_input.entropy_train is not None and attack_input.entropy_test is not None:
|
||||
elif (
|
||||
attack_input.entropy_train is not None
|
||||
and attack_input.entropy_test is not None
|
||||
):
|
||||
train_values = attack_input.entropy_train
|
||||
test_values = attack_input.entropy_test
|
||||
else:
|
||||
|
|
|
@ -398,6 +398,21 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
|
|||
np.testing.assert_almost_equal(
|
||||
result.roc_curve.get_ppv(), 0.57142, decimal=2)
|
||||
|
||||
def test_run_attacks_insufficient_examples(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input(1, 100),
|
||||
SlicingSpec(),
|
||||
(
|
||||
AttackType.THRESHOLD_ATTACK,
|
||||
AttackType.LOGISTIC_REGRESSION,
|
||||
),
|
||||
)
|
||||
self.assertLen(result.single_attack_results, 1)
|
||||
self.assertEqual(
|
||||
result.single_attack_results[0].attack_type.value,
|
||||
AttackType.THRESHOLD_ATTACK.value,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue