Catches when data is not sufficient for StratifiedKFold split.
PiperOrigin-RevId: 510197242
This commit is contained in:
parent
0c691d0b4d
commit
4dd8d0ffde
2 changed files with 58 additions and 31 deletions
|
@ -51,11 +51,13 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/220394926): Allow users to specify their own attack models.
|
# TODO(b/220394926): Allow users to specify their own attack models.
|
||||||
def _run_trained_attack(attack_input: AttackInputData,
|
def _run_trained_attack(
|
||||||
|
attack_input: AttackInputData,
|
||||||
attack_type: AttackType,
|
attack_type: AttackType,
|
||||||
balance_attacker_training: bool = True,
|
balance_attacker_training: bool = True,
|
||||||
cross_validation_folds: int = 2,
|
cross_validation_folds: int = 2,
|
||||||
backend: Optional[str] = None):
|
backend: Optional[str] = None,
|
||||||
|
) -> Optional[SingleAttackResult]:
|
||||||
"""Classification attack done by ML models."""
|
"""Classification attack done by ML models."""
|
||||||
prepared_attacker_data = models.create_attacker_data(
|
prepared_attacker_data = models.create_attacker_data(
|
||||||
attack_input, balance=balance_attacker_training)
|
attack_input, balance=balance_attacker_training)
|
||||||
|
@ -76,8 +78,9 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
# We use StratifiedKFold to create disjoint subsets of samples. Notice that
|
# We use StratifiedKFold to create disjoint subsets of samples. Notice that
|
||||||
# the index it returns is with respect to the samples shuffled with `indices`.
|
# the index it returns is with respect to the samples shuffled with `indices`.
|
||||||
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
|
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
|
||||||
for train_indices_in_shuffled, test_indices_in_shuffled in kf.split(
|
try:
|
||||||
features[indices], labels[indices]):
|
shuffled_indices = kf.split(features[indices], labels[indices])
|
||||||
|
for train_indices_in_shuffled, test_indices_in_shuffled in shuffled_indices:
|
||||||
# `train_indices_in_shuffled` is with respect to the data shuffled with
|
# `train_indices_in_shuffled` is with respect to the data shuffled with
|
||||||
# `indices`. We convert it to `train_indices` to work with the original
|
# `indices`. We convert it to `train_indices` to work with the original
|
||||||
# data (`features` and 'labels').
|
# data (`features` and 'labels').
|
||||||
|
@ -88,10 +91,10 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
|
|
||||||
# Setup sample weights if provided.
|
# Setup sample weights if provided.
|
||||||
if sample_weights is not None:
|
if sample_weights is not None:
|
||||||
# If sample weights are provided, only the weights at the training indices
|
# If sample weights are provided, only the weights at the training
|
||||||
# are used for training. The weights at the test indices are not used
|
# indices are used for training. The weights at the test indices are not
|
||||||
# during prediction. Not that 'train' and 'test' refer to the data for the
|
# used during prediction. Not that 'train' and 'test' refer to the data
|
||||||
# attack models, not the data for the original models.
|
# for the attack models, not the data for the original models.
|
||||||
sample_weights_train = np.squeeze(sample_weights[train_indices])
|
sample_weights_train = np.squeeze(sample_weights[train_indices])
|
||||||
else:
|
else:
|
||||||
sample_weights_train = None
|
sample_weights_train = None
|
||||||
|
@ -100,9 +103,15 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
attacker.train_model(
|
attacker.train_model(
|
||||||
features[train_indices],
|
features[train_indices],
|
||||||
labels[train_indices],
|
labels[train_indices],
|
||||||
sample_weight=sample_weights_train)
|
sample_weight=sample_weights_train,
|
||||||
|
)
|
||||||
predictions = attacker.predict(features[test_indices])
|
predictions = attacker.predict(features[test_indices])
|
||||||
scores[test_indices] = predictions
|
scores[test_indices] = predictions
|
||||||
|
except ValueError as ve:
|
||||||
|
if 'cannot be greater than the number of members in each class.' in str(ve):
|
||||||
|
logging.warning('kf.split in _run_trained_attack fails with: %s', str(ve))
|
||||||
|
return None
|
||||||
|
raise ValueError(str(ve)) from ve
|
||||||
|
|
||||||
# Predict the left out with the last attacker
|
# Predict the left out with the last attacker
|
||||||
if left_out_indices.size:
|
if left_out_indices.size:
|
||||||
|
@ -322,7 +331,10 @@ def _compute_membership_probability(
|
||||||
if attack_input.loss_train is not None and attack_input.loss_test is not None:
|
if attack_input.loss_train is not None and attack_input.loss_test is not None:
|
||||||
train_values = attack_input.loss_train
|
train_values = attack_input.loss_train
|
||||||
test_values = attack_input.loss_test
|
test_values = attack_input.loss_test
|
||||||
elif attack_input.entropy_train is not None and attack_input.entropy_test is not None:
|
elif (
|
||||||
|
attack_input.entropy_train is not None
|
||||||
|
and attack_input.entropy_test is not None
|
||||||
|
):
|
||||||
train_values = attack_input.entropy_train
|
train_values = attack_input.entropy_train
|
||||||
test_values = attack_input.entropy_test
|
test_values = attack_input.entropy_test
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -398,6 +398,21 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
|
||||||
np.testing.assert_almost_equal(
|
np.testing.assert_almost_equal(
|
||||||
result.roc_curve.get_ppv(), 0.57142, decimal=2)
|
result.roc_curve.get_ppv(), 0.57142, decimal=2)
|
||||||
|
|
||||||
|
def test_run_attacks_insufficient_examples(self):
|
||||||
|
result = mia.run_attacks(
|
||||||
|
get_test_input(1, 100),
|
||||||
|
SlicingSpec(),
|
||||||
|
(
|
||||||
|
AttackType.THRESHOLD_ATTACK,
|
||||||
|
AttackType.LOGISTIC_REGRESSION,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertLen(result.single_attack_results, 1)
|
||||||
|
self.assertEqual(
|
||||||
|
result.single_attack_results[0].attack_type.value,
|
||||||
|
AttackType.THRESHOLD_ATTACK.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue