Implement possibility to return slice indices.
PiperOrigin-RevId: 540885025
This commit is contained in:
parent
a4bdb05b62
commit
45da453410
6 changed files with 129 additions and 42 deletions
|
@ -148,11 +148,16 @@ def main(unused_argv):
|
|||
labels_train=training_labels,
|
||||
labels_test=test_labels,
|
||||
probs_train=training_pred,
|
||||
probs_test=test_pred),
|
||||
probs_test=test_pred,
|
||||
),
|
||||
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
|
||||
data_structures.AttackType.LOGISTIC_REGRESSION),
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
attack_types=(
|
||||
data_structures.AttackType.THRESHOLD_ATTACK,
|
||||
data_structures.AttackType.LOGISTIC_REGRESSION,
|
||||
),
|
||||
privacy_report_metadata=privacy_report_metadata,
|
||||
return_slice_indices=True,
|
||||
)
|
||||
epoch_results.append(attack_results)
|
||||
|
||||
# Generate privacy reports
|
||||
|
|
|
@ -766,6 +766,11 @@ class SingleAttackResult:
|
|||
# test set samples will have lower scores than the training set samples.
|
||||
membership_scores_test: Optional[np.ndarray] = None
|
||||
|
||||
# Indices of train and test examples from the input data that were used in
|
||||
# this attack.
|
||||
train_indices: Optional[np.ndarray] = None
|
||||
test_indices: Optional[np.ndarray] = None
|
||||
|
||||
def get_attacker_advantage(self):
|
||||
return self.roc_curve.get_attacker_advantage()
|
||||
|
||||
|
|
|
@ -33,8 +33,12 @@ def _slice_if_not_none(a, idx):
|
|||
return None if a is None else a[idx]
|
||||
|
||||
|
||||
def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||
idx_test) -> AttackInputData:
|
||||
def _slice_data_by_indices(
|
||||
data: AttackInputData,
|
||||
idx_train: np.ndarray,
|
||||
idx_test: np.ndarray,
|
||||
return_slice_indices: bool,
|
||||
) -> AttackInputData:
|
||||
"""Slices train fields with idx_train and test fields with idx_test."""
|
||||
|
||||
result = AttackInputData()
|
||||
|
@ -68,10 +72,16 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
|||
# of the original dataset.
|
||||
result.multilabel_data = data.is_multilabel_data()
|
||||
|
||||
if return_slice_indices:
|
||||
result.train_indices = np.where(idx_train)[0]
|
||||
result.test_indices = np.where(idx_test)[0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||
def _slice_by_class(
|
||||
data: AttackInputData, class_value: int, return_slice_indices: bool = False
|
||||
) -> AttackInputData:
|
||||
"""Gets the indices (boolean) for examples belonging to the given class."""
|
||||
if not data.is_multilabel_data():
|
||||
idx_train = data.labels_train == class_value
|
||||
|
@ -84,11 +94,15 @@ def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
|||
)
|
||||
idx_train = data.labels_train[:, class_value].astype(bool)
|
||||
idx_test = data.labels_test[:, class_value].astype(bool)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
|
||||
|
||||
|
||||
def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
||||
to_percentile: float):
|
||||
def _slice_by_percentiles(
|
||||
data: AttackInputData,
|
||||
from_percentile: float,
|
||||
to_percentile: float,
|
||||
return_slice_indices: bool = False,
|
||||
) -> AttackInputData:
|
||||
"""Slices samples by loss percentiles."""
|
||||
|
||||
# Find from_percentile and to_percentile percentiles in losses.
|
||||
|
@ -107,22 +121,31 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
|||
idx_train = (from_loss <= loss_train) & (loss_train <= to_loss)
|
||||
idx_test = (from_loss <= loss_test) & (loss_test <= to_loss)
|
||||
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
|
||||
|
||||
|
||||
def _indices_by_classification(logits_or_probs, labels, correctly_classified):
|
||||
def _indices_by_classification(
|
||||
logits_or_probs,
|
||||
labels,
|
||||
correctly_classified,
|
||||
):
|
||||
idx_correct = labels == np.argmax(logits_or_probs, axis=1)
|
||||
return idx_correct if correctly_classified else np.invert(idx_correct)
|
||||
|
||||
|
||||
def _slice_by_classification_correctness(data: AttackInputData,
|
||||
correctly_classified: bool):
|
||||
def _slice_by_classification_correctness(
|
||||
data: AttackInputData,
|
||||
correctly_classified: bool,
|
||||
return_slice_indices: bool = False,
|
||||
) -> AttackInputData:
|
||||
"""Slices attack inputs by whether they were classified correctly.
|
||||
|
||||
Args:
|
||||
data: Data to be used as input to the attack models.
|
||||
correctly_classified: Whether to use the indices corresponding to the
|
||||
correctly classified samples.
|
||||
return_slice_indices: if true, the returned AttackInputData will include
|
||||
indices of the train and test data samples that were used for this slice.
|
||||
|
||||
Returns:
|
||||
AttackInputData object containing the sliced data.
|
||||
|
@ -136,13 +159,16 @@ def _slice_by_classification_correctness(data: AttackInputData,
|
|||
correctly_classified)
|
||||
idx_test = _indices_by_classification(data.logits_or_probs_test,
|
||||
data.labels_test, correctly_classified)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
|
||||
|
||||
|
||||
def _slice_by_custom_indices(data: AttackInputData,
|
||||
def _slice_by_custom_indices(
|
||||
data: AttackInputData,
|
||||
custom_train_indices: np.ndarray,
|
||||
custom_test_indices: np.ndarray,
|
||||
group_value: int) -> AttackInputData:
|
||||
group_value: int,
|
||||
return_slice_indices: bool = False,
|
||||
) -> AttackInputData:
|
||||
"""Slices attack inputs by custom indices.
|
||||
|
||||
Args:
|
||||
|
@ -150,6 +176,8 @@ def _slice_by_custom_indices(data: AttackInputData,
|
|||
custom_train_indices: The group indices of each training example.
|
||||
custom_test_indices: The group indices of each test example.
|
||||
group_value: The group value to pick.
|
||||
return_slice_indices: if true, the returned AttackInputData will include
|
||||
indices of the train and test data samples that were used for this slice.
|
||||
|
||||
Returns:
|
||||
AttackInputData object containing the sliced data.
|
||||
|
@ -167,12 +195,12 @@ def _slice_by_custom_indices(data: AttackInputData,
|
|||
f"{test_size}")
|
||||
idx_train = custom_train_indices == group_value
|
||||
idx_test = custom_test_indices == group_value
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
|
||||
|
||||
|
||||
def get_single_slice_specs(
|
||||
slicing_spec: SlicingSpec,
|
||||
num_classes: Optional[int] = None) -> List[SingleSliceSpec]:
|
||||
slicing_spec: SlicingSpec, num_classes: Optional[int] = None
|
||||
) -> List[SingleSliceSpec]:
|
||||
"""Returns slices of data according to slicing_spec.
|
||||
|
||||
Args:
|
||||
|
@ -242,22 +270,34 @@ def get_single_slice_specs(
|
|||
return result
|
||||
|
||||
|
||||
def get_slice(data: AttackInputData,
|
||||
slice_spec: SingleSliceSpec) -> AttackInputData:
|
||||
def get_slice(
|
||||
data: AttackInputData,
|
||||
slice_spec: SingleSliceSpec,
|
||||
return_slice_indices: bool = False,
|
||||
) -> AttackInputData:
|
||||
"""Returns a single slice of data according to slice_spec."""
|
||||
if slice_spec.entire_dataset:
|
||||
data_slice = copy.copy(data)
|
||||
elif slice_spec.feature == SlicingFeature.CLASS:
|
||||
data_slice = _slice_by_class(data, slice_spec.value)
|
||||
data_slice = _slice_by_class(data, slice_spec.value, return_slice_indices)
|
||||
elif slice_spec.feature == SlicingFeature.PERCENTILE:
|
||||
from_percentile, to_percentile = slice_spec.value
|
||||
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
|
||||
data_slice = _slice_by_percentiles(
|
||||
data, from_percentile, to_percentile, return_slice_indices
|
||||
)
|
||||
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
||||
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
||||
data_slice = _slice_by_classification_correctness(
|
||||
data, slice_spec.value, return_slice_indices
|
||||
)
|
||||
elif slice_spec.feature == SlicingFeature.CUSTOM:
|
||||
custom_train_indices, custom_test_indices, group_value = slice_spec.value
|
||||
data_slice = _slice_by_custom_indices(data, custom_train_indices,
|
||||
custom_test_indices, group_value)
|
||||
data_slice = _slice_by_custom_indices(
|
||||
data,
|
||||
custom_train_indices,
|
||||
custom_test_indices,
|
||||
group_value,
|
||||
return_slice_indices,
|
||||
)
|
||||
else:
|
||||
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
|
||||
|
||||
|
|
|
@ -267,10 +267,12 @@ class GetSliceTest(parameterized.TestCase):
|
|||
self.assertTrue((output.labels_train == [1, 0, 2]).all())
|
||||
self.assertTrue((output.labels_test == [1]).all())
|
||||
|
||||
def test_slice_by_correctness(self):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
|
||||
False)
|
||||
output = get_slice(self.input_data, percentile_slice)
|
||||
@parameterized.parameters(False, True)
|
||||
def test_slice_by_correctness(self, return_slice_indices):
|
||||
slice_spec = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False)
|
||||
output = get_slice(
|
||||
self.input_data, slice_spec, return_slice_indices=return_slice_indices
|
||||
)
|
||||
|
||||
# Check logits.
|
||||
self.assertLen(output.logits_train, 2)
|
||||
|
@ -284,6 +286,14 @@ class GetSliceTest(parameterized.TestCase):
|
|||
self.assertTrue((output.labels_train == [0, 2]).all())
|
||||
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
||||
|
||||
# Check return indices
|
||||
if return_slice_indices:
|
||||
self.assertTrue((output.train_indices == [1, 3]).all())
|
||||
self.assertTrue((output.test_indices == [0, 1, 2]).all())
|
||||
else:
|
||||
self.assertFalse(hasattr(output, 'train_indices'))
|
||||
self.assertFalse(hasattr(output, 'test_indices'))
|
||||
|
||||
def test_slice_by_custom_indices(self):
|
||||
custom_train_indices = np.array([2, 2, 100, 4])
|
||||
custom_test_indices = np.array([100, 2, 2, 2])
|
||||
|
|
|
@ -254,14 +254,16 @@ def _run_attack(attack_input: AttackInputData,
|
|||
return _run_threshold_attack(attack_input)
|
||||
|
||||
|
||||
def run_attacks(attack_input: AttackInputData,
|
||||
def run_attacks(
|
||||
attack_input: AttackInputData,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (
|
||||
AttackType.THRESHOLD_ATTACK,),
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True,
|
||||
min_num_samples: int = 1,
|
||||
backend: Optional[str] = None) -> AttackResults:
|
||||
backend: Optional[str] = None,
|
||||
return_slice_indices=False,
|
||||
) -> AttackResults:
|
||||
"""Runs membership inference attacks on a classification model.
|
||||
|
||||
It runs attacks specified by attack_types on each attack_input slice which is
|
||||
|
@ -282,6 +284,9 @@ def run_attacks(attack_input: AttackInputData,
|
|||
may not support multiprocessing and in those cases the `threading` backend
|
||||
should be used. See https://joblib.readthedocs.io/en/latest/parallel.html
|
||||
for more details.
|
||||
return_slice_indices: if true, the result for each slice will include the
|
||||
indices of train and test data examples that were used for this slice and
|
||||
attacks. This does not return indices for the "whole dataset" slice.
|
||||
|
||||
Returns:
|
||||
the attack result.
|
||||
|
@ -301,7 +306,9 @@ def run_attacks(attack_input: AttackInputData,
|
|||
logging.info('Will run %s attacks on each of %s slice specifications.',
|
||||
num_attacks, num_slice_specs)
|
||||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
attack_input_slice = get_slice(
|
||||
attack_input, single_slice_spec, return_slice_indices
|
||||
)
|
||||
for attack_type in attack_types:
|
||||
logging.info('Running attack: %s', attack_type.name)
|
||||
attack_result = _run_attack(attack_input_slice, attack_type,
|
||||
|
@ -313,6 +320,9 @@ def run_attacks(attack_input: AttackInputData,
|
|||
'positive predictive value=%s', attack_type.name,
|
||||
attack_result.get_auc(), attack_result.get_attacker_advantage(),
|
||||
attack_result.get_ppv())
|
||||
if return_slice_indices and not single_slice_spec.entire_dataset:
|
||||
attack_result.train_indices = attack_input_slice.train_indices
|
||||
attack_result.test_indices = attack_input_slice.test_indices
|
||||
attack_results.append(attack_result)
|
||||
|
||||
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
||||
|
|
|
@ -445,6 +445,23 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
|
|||
AttackType.THRESHOLD_ATTACK.value,
|
||||
)
|
||||
|
||||
def test_run_attacks_size_return_indices(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input(100, 100),
|
||||
SlicingSpec(
|
||||
entire_dataset=False,
|
||||
by_class=True,
|
||||
by_percentiles=True,
|
||||
by_classification_correctness=True,
|
||||
),
|
||||
(AttackType.LOGISTIC_REGRESSION,),
|
||||
return_slice_indices=True,
|
||||
)
|
||||
|
||||
for attack_result in result.single_attack_results:
|
||||
self.assertIsNotNone(attack_result.train_indices)
|
||||
self.assertIsNotNone(attack_result.test_indices)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue