Catches the case where the data is not sufficient for a StratifiedKFold split.

PiperOrigin-RevId: 510197242
Shuang Song authored on 2023-02-16 11:23:38 -08:00; committed by A. Unique TensorFlower
parent 0c691d0b4d
commit 4dd8d0ffde
2 changed files with 58 additions and 31 deletions
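
For context, a minimal sketch (not part of this commit) of the failure mode being caught: scikit-learn's StratifiedKFold raises a ValueError when every class has fewer members than the requested number of folds, which is what happens when the balanced attacker data ends up with too few examples per membership class. The toy feature/label arrays below are illustrative only.

import numpy as np
from sklearn import model_selection

# Toy attacker data: one "member" and one "non-member" example, so each
# membership class has only a single sample.
features = np.random.rand(2, 4)
labels = np.array([0, 1])

kf = model_selection.StratifiedKFold(n_splits=2, shuffle=False)
try:
  # split() returns a generator, so the check only fires once we iterate.
  list(kf.split(features, labels))
except ValueError as ve:
  # The message contains: 'cannot be greater than the number of members in
  # each class.', which is the substring the new except block matches on.
  print(ve)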

@@ -51,11 +51,13 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
 # TODO(b/220394926): Allow users to specify their own attack models.
-def _run_trained_attack(attack_input: AttackInputData,
-                        attack_type: AttackType,
-                        balance_attacker_training: bool = True,
-                        cross_validation_folds: int = 2,
-                        backend: Optional[str] = None):
+def _run_trained_attack(
+    attack_input: AttackInputData,
+    attack_type: AttackType,
+    balance_attacker_training: bool = True,
+    cross_validation_folds: int = 2,
+    backend: Optional[str] = None,
+) -> Optional[SingleAttackResult]:
   """Classification attack done by ML models."""
   prepared_attacker_data = models.create_attacker_data(
       attack_input, balance=balance_attacker_training)
@@ -76,33 +78,40 @@ def _run_trained_attack(attack_input: AttackInputData,
   # We use StratifiedKFold to create disjoint subsets of samples. Notice that
   # the index it returns is with respect to the samples shuffled with `indices`.
   kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
-  for train_indices_in_shuffled, test_indices_in_shuffled in kf.split(
-      features[indices], labels[indices]):
-    # `train_indices_in_shuffled` is with respect to the data shuffled with
-    # `indices`. We convert it to `train_indices` to work with the original
-    # data (`features` and `labels`).
-    train_indices = indices[train_indices_in_shuffled]
-    test_indices = indices[test_indices_in_shuffled]
-    # Make sure each sample's score is only predicted once.
-    assert np.all(np.isnan(scores[test_indices]))
+  try:
+    shuffled_indices = kf.split(features[indices], labels[indices])
+    for train_indices_in_shuffled, test_indices_in_shuffled in shuffled_indices:
+      # `train_indices_in_shuffled` is with respect to the data shuffled with
+      # `indices`. We convert it to `train_indices` to work with the original
+      # data (`features` and `labels`).
+      train_indices = indices[train_indices_in_shuffled]
+      test_indices = indices[test_indices_in_shuffled]
+      # Make sure each sample's score is only predicted once.
+      assert np.all(np.isnan(scores[test_indices]))

-    # Setup sample weights if provided.
-    if sample_weights is not None:
-      # If sample weights are provided, only the weights at the training indices
-      # are used for training. The weights at the test indices are not used
-      # during prediction. Note that 'train' and 'test' refer to the data for the
-      # attack models, not the data for the original models.
-      sample_weights_train = np.squeeze(sample_weights[train_indices])
-    else:
-      sample_weights_train = None
+      # Setup sample weights if provided.
+      if sample_weights is not None:
+        # If sample weights are provided, only the weights at the training
+        # indices are used for training. The weights at the test indices are not
+        # used during prediction. Note that 'train' and 'test' refer to the data
+        # for the attack models, not the data for the original models.
+        sample_weights_train = np.squeeze(sample_weights[train_indices])
+      else:
+        sample_weights_train = None

-    attacker = models.create_attacker(attack_type, backend=backend)
-    attacker.train_model(
-        features[train_indices],
-        labels[train_indices],
-        sample_weight=sample_weights_train)
-    predictions = attacker.predict(features[test_indices])
-    scores[test_indices] = predictions
+      attacker = models.create_attacker(attack_type, backend=backend)
+      attacker.train_model(
+          features[train_indices],
+          labels[train_indices],
+          sample_weight=sample_weights_train,
+      )
+      predictions = attacker.predict(features[test_indices])
+      scores[test_indices] = predictions
+  except ValueError as ve:
+    if 'cannot be greater than the number of members in each class.' in str(ve):
+      logging.warning('kf.split in _run_trained_attack fails with: %s', str(ve))
+      return None
+    raise ValueError(str(ve)) from ve

   # Predict the left out with the last attacker
   if left_out_indices.size:
@@ -322,7 +331,10 @@ def _compute_membership_probability(
   if attack_input.loss_train is not None and attack_input.loss_test is not None:
     train_values = attack_input.loss_train
     test_values = attack_input.loss_test
-  elif attack_input.entropy_train is not None and attack_input.entropy_test is not None:
+  elif (
+      attack_input.entropy_train is not None
+      and attack_input.entropy_test is not None
+  ):
     train_values = attack_input.entropy_train
     test_values = attack_input.entropy_test
   else:
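
With the return type changed to Optional[SingleAttackResult], a caller can skip trained attacks that cannot be run instead of aborting the whole analysis. A hypothetical caller-side sketch (run_selected_attacks is illustrative only, not part of the library):

def run_selected_attacks(attack_input, attack_types):
  """Illustrative helper: run trained attacks, skipping under-sized ones."""
  results = []
  for attack_type in attack_types:
    result = _run_trained_attack(attack_input, attack_type)
    if result is None:
      # StratifiedKFold could not split the attacker data; skip this attack.
      continue
    results.append(result)
  return results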

@@ -398,6 +398,21 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
     np.testing.assert_almost_equal(
         result.roc_curve.get_ppv(), 0.57142, decimal=2)

+  def test_run_attacks_insufficient_examples(self):
+    result = mia.run_attacks(
+        get_test_input(1, 100),
+        SlicingSpec(),
+        (
+            AttackType.THRESHOLD_ATTACK,
+            AttackType.LOGISTIC_REGRESSION,
+        ),
+    )
+    self.assertLen(result.single_attack_results, 1)
+    self.assertEqual(
+        result.single_attack_results[0].attack_type.value,
+        AttackType.THRESHOLD_ATTACK.value,
+    )
+

 if __name__ == '__main__':
   absltest.main()
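
A small follow-up usage sketch (attribute access assumed from the assertions in the test above): attacks that were skipped for lack of data simply do not appear in single_attack_results, so callers can check which attacks actually ran.

ran = {r.attack_type for r in result.single_attack_results}
# With get_test_input(1, 100) as in the test, only the threshold attack remains:
print(AttackType.LOGISTIC_REGRESSION in ran)  # False
print(AttackType.THRESHOLD_ATTACK in ran)     # True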