Implement possibility to return slice indices.

PiperOrigin-RevId: 540885025
This commit is contained in:
Vadym Doroshenko 2023-06-16 08:22:07 -07:00 committed by A. Unique TensorFlower
parent a4bdb05b62
commit 45da453410
6 changed files with 129 additions and 42 deletions

View file

@ -148,11 +148,16 @@ def main(unused_argv):
labels_train=training_labels, labels_train=training_labels,
labels_test=test_labels, labels_test=test_labels,
probs_train=training_pred, probs_train=training_pred,
probs_test=test_pred), probs_test=test_pred,
),
data_structures.SlicingSpec(entire_dataset=True, by_class=True), data_structures.SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK, attack_types=(
data_structures.AttackType.LOGISTIC_REGRESSION), data_structures.AttackType.THRESHOLD_ATTACK,
privacy_report_metadata=privacy_report_metadata) data_structures.AttackType.LOGISTIC_REGRESSION,
),
privacy_report_metadata=privacy_report_metadata,
return_slice_indices=True,
)
epoch_results.append(attack_results) epoch_results.append(attack_results)
# Generate privacy reports # Generate privacy reports

View file

@ -766,6 +766,11 @@ class SingleAttackResult:
# test set samples will have lower scores than the training set samples. # test set samples will have lower scores than the training set samples.
membership_scores_test: Optional[np.ndarray] = None membership_scores_test: Optional[np.ndarray] = None
# Indices of train and test examples from the input data that were used in
# this attack.
train_indices: Optional[np.ndarray] = None
test_indices: Optional[np.ndarray] = None
def get_attacker_advantage(self): def get_attacker_advantage(self):
return self.roc_curve.get_attacker_advantage() return self.roc_curve.get_attacker_advantage()

View file

@ -33,8 +33,12 @@ def _slice_if_not_none(a, idx):
return None if a is None else a[idx] return None if a is None else a[idx]
def _slice_data_by_indices(data: AttackInputData, idx_train, def _slice_data_by_indices(
idx_test) -> AttackInputData: data: AttackInputData,
idx_train: np.ndarray,
idx_test: np.ndarray,
return_slice_indices: bool,
) -> AttackInputData:
"""Slices train fields with idx_train and test fields with idx_test.""" """Slices train fields with idx_train and test fields with idx_test."""
result = AttackInputData() result = AttackInputData()
@ -68,10 +72,16 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
# of the original dataset. # of the original dataset.
result.multilabel_data = data.is_multilabel_data() result.multilabel_data = data.is_multilabel_data()
if return_slice_indices:
result.train_indices = np.where(idx_train)[0]
result.test_indices = np.where(idx_test)[0]
return result return result
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData: def _slice_by_class(
data: AttackInputData, class_value: int, return_slice_indices: bool = False
) -> AttackInputData:
"""Gets the indices (boolean) for examples belonging to the given class.""" """Gets the indices (boolean) for examples belonging to the given class."""
if not data.is_multilabel_data(): if not data.is_multilabel_data():
idx_train = data.labels_train == class_value idx_train = data.labels_train == class_value
@ -84,11 +94,15 @@ def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
) )
idx_train = data.labels_train[:, class_value].astype(bool) idx_train = data.labels_train[:, class_value].astype(bool)
idx_test = data.labels_test[:, class_value].astype(bool) idx_test = data.labels_test[:, class_value].astype(bool)
return _slice_data_by_indices(data, idx_train, idx_test) return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
def _slice_by_percentiles(data: AttackInputData, from_percentile: float, def _slice_by_percentiles(
to_percentile: float): data: AttackInputData,
from_percentile: float,
to_percentile: float,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices samples by loss percentiles.""" """Slices samples by loss percentiles."""
# Find from_percentile and to_percentile percentiles in losses. # Find from_percentile and to_percentile percentiles in losses.
@ -107,22 +121,31 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
idx_train = (from_loss <= loss_train) & (loss_train <= to_loss) idx_train = (from_loss <= loss_train) & (loss_train <= to_loss)
idx_test = (from_loss <= loss_test) & (loss_test <= to_loss) idx_test = (from_loss <= loss_test) & (loss_test <= to_loss)
return _slice_data_by_indices(data, idx_train, idx_test) return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
def _indices_by_classification(logits_or_probs, labels, correctly_classified): def _indices_by_classification(
logits_or_probs,
labels,
correctly_classified,
):
idx_correct = labels == np.argmax(logits_or_probs, axis=1) idx_correct = labels == np.argmax(logits_or_probs, axis=1)
return idx_correct if correctly_classified else np.invert(idx_correct) return idx_correct if correctly_classified else np.invert(idx_correct)
def _slice_by_classification_correctness(data: AttackInputData, def _slice_by_classification_correctness(
correctly_classified: bool): data: AttackInputData,
correctly_classified: bool,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices attack inputs by whether they were classified correctly. """Slices attack inputs by whether they were classified correctly.
Args: Args:
data: Data to be used as input to the attack models. data: Data to be used as input to the attack models.
correctly_classified: Whether to use the indices corresponding to the correctly_classified: Whether to use the indices corresponding to the
correctly classified samples. correctly classified samples.
return_slice_indices: if true, the returned AttackInputData will include
indices of the train and test data samples that were used for this slice.
Returns: Returns:
AttackInputData object containing the sliced data. AttackInputData object containing the sliced data.
@ -136,13 +159,16 @@ def _slice_by_classification_correctness(data: AttackInputData,
correctly_classified) correctly_classified)
idx_test = _indices_by_classification(data.logits_or_probs_test, idx_test = _indices_by_classification(data.logits_or_probs_test,
data.labels_test, correctly_classified) data.labels_test, correctly_classified)
return _slice_data_by_indices(data, idx_train, idx_test) return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
def _slice_by_custom_indices(data: AttackInputData, def _slice_by_custom_indices(
custom_train_indices: np.ndarray, data: AttackInputData,
custom_test_indices: np.ndarray, custom_train_indices: np.ndarray,
group_value: int) -> AttackInputData: custom_test_indices: np.ndarray,
group_value: int,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Slices attack inputs by custom indices. """Slices attack inputs by custom indices.
Args: Args:
@ -150,6 +176,8 @@ def _slice_by_custom_indices(data: AttackInputData,
custom_train_indices: The group indices of each training example. custom_train_indices: The group indices of each training example.
custom_test_indices: The group indices of each test example. custom_test_indices: The group indices of each test example.
group_value: The group value to pick. group_value: The group value to pick.
return_slice_indices: if true, the returned AttackInputData will include
indices of the train and test data samples that were used for this slice.
Returns: Returns:
AttackInputData object containing the sliced data. AttackInputData object containing the sliced data.
@ -167,12 +195,12 @@ def _slice_by_custom_indices(data: AttackInputData,
f"{test_size}") f"{test_size}")
idx_train = custom_train_indices == group_value idx_train = custom_train_indices == group_value
idx_test = custom_test_indices == group_value idx_test = custom_test_indices == group_value
return _slice_data_by_indices(data, idx_train, idx_test) return _slice_data_by_indices(data, idx_train, idx_test, return_slice_indices)
def get_single_slice_specs( def get_single_slice_specs(
slicing_spec: SlicingSpec, slicing_spec: SlicingSpec, num_classes: Optional[int] = None
num_classes: Optional[int] = None) -> List[SingleSliceSpec]: ) -> List[SingleSliceSpec]:
"""Returns slices of data according to slicing_spec. """Returns slices of data according to slicing_spec.
Args: Args:
@ -242,22 +270,34 @@ def get_single_slice_specs(
return result return result
def get_slice(data: AttackInputData, def get_slice(
slice_spec: SingleSliceSpec) -> AttackInputData: data: AttackInputData,
slice_spec: SingleSliceSpec,
return_slice_indices: bool = False,
) -> AttackInputData:
"""Returns a single slice of data according to slice_spec.""" """Returns a single slice of data according to slice_spec."""
if slice_spec.entire_dataset: if slice_spec.entire_dataset:
data_slice = copy.copy(data) data_slice = copy.copy(data)
elif slice_spec.feature == SlicingFeature.CLASS: elif slice_spec.feature == SlicingFeature.CLASS:
data_slice = _slice_by_class(data, slice_spec.value) data_slice = _slice_by_class(data, slice_spec.value, return_slice_indices)
elif slice_spec.feature == SlicingFeature.PERCENTILE: elif slice_spec.feature == SlicingFeature.PERCENTILE:
from_percentile, to_percentile = slice_spec.value from_percentile, to_percentile = slice_spec.value
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile) data_slice = _slice_by_percentiles(
data, from_percentile, to_percentile, return_slice_indices
)
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED: elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
data_slice = _slice_by_classification_correctness(data, slice_spec.value) data_slice = _slice_by_classification_correctness(
data, slice_spec.value, return_slice_indices
)
elif slice_spec.feature == SlicingFeature.CUSTOM: elif slice_spec.feature == SlicingFeature.CUSTOM:
custom_train_indices, custom_test_indices, group_value = slice_spec.value custom_train_indices, custom_test_indices, group_value = slice_spec.value
data_slice = _slice_by_custom_indices(data, custom_train_indices, data_slice = _slice_by_custom_indices(
custom_test_indices, group_value) data,
custom_train_indices,
custom_test_indices,
group_value,
return_slice_indices,
)
else: else:
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature) raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)

View file

@ -267,10 +267,12 @@ class GetSliceTest(parameterized.TestCase):
self.assertTrue((output.labels_train == [1, 0, 2]).all()) self.assertTrue((output.labels_train == [1, 0, 2]).all())
self.assertTrue((output.labels_test == [1]).all()) self.assertTrue((output.labels_test == [1]).all())
def test_slice_by_correctness(self): @parameterized.parameters(False, True)
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, def test_slice_by_correctness(self, return_slice_indices):
False) slice_spec = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False)
output = get_slice(self.input_data, percentile_slice) output = get_slice(
self.input_data, slice_spec, return_slice_indices=return_slice_indices
)
# Check logits. # Check logits.
self.assertLen(output.logits_train, 2) self.assertLen(output.logits_train, 2)
@ -284,6 +286,14 @@ class GetSliceTest(parameterized.TestCase):
self.assertTrue((output.labels_train == [0, 2]).all()) self.assertTrue((output.labels_train == [0, 2]).all())
self.assertTrue((output.labels_test == [1, 2, 0]).all()) self.assertTrue((output.labels_test == [1, 2, 0]).all())
# Check return indices
if return_slice_indices:
self.assertTrue((output.train_indices == [1, 3]).all())
self.assertTrue((output.test_indices == [0, 1, 2]).all())
else:
self.assertFalse(hasattr(output, 'train_indices'))
self.assertFalse(hasattr(output, 'test_indices'))
def test_slice_by_custom_indices(self): def test_slice_by_custom_indices(self):
custom_train_indices = np.array([2, 2, 100, 4]) custom_train_indices = np.array([2, 2, 100, 4])
custom_test_indices = np.array([100, 2, 2, 2]) custom_test_indices = np.array([100, 2, 2, 2])

View file

@ -254,14 +254,16 @@ def _run_attack(attack_input: AttackInputData,
return _run_threshold_attack(attack_input) return _run_threshold_attack(attack_input)
def run_attacks(attack_input: AttackInputData, def run_attacks(
slicing_spec: SlicingSpec = None, attack_input: AttackInputData,
attack_types: Iterable[AttackType] = ( slicing_spec: SlicingSpec = None,
AttackType.THRESHOLD_ATTACK,), attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
privacy_report_metadata: PrivacyReportMetadata = None, privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True, balance_attacker_training: bool = True,
min_num_samples: int = 1, min_num_samples: int = 1,
backend: Optional[str] = None) -> AttackResults: backend: Optional[str] = None,
return_slice_indices=False,
) -> AttackResults:
"""Runs membership inference attacks on a classification model. """Runs membership inference attacks on a classification model.
It runs attacks specified by attack_types on each attack_input slice which is It runs attacks specified by attack_types on each attack_input slice which is
@ -282,6 +284,9 @@ def run_attacks(attack_input: AttackInputData,
may not support multiprocessing and in those cases the `threading` backend may not support multiprocessing and in those cases the `threading` backend
should be used. See https://joblib.readthedocs.io/en/latest/parallel.html should be used. See https://joblib.readthedocs.io/en/latest/parallel.html
for more details. for more details.
return_slice_indices: if true, the result for each slice will include the
indices of train and test data examples that were used for this slice and
attacks. This does not return indices for the "whole dataset" slice.
Returns: Returns:
the attack result. the attack result.
@ -301,7 +306,9 @@ def run_attacks(attack_input: AttackInputData,
logging.info('Will run %s attacks on each of %s slice specifications.', logging.info('Will run %s attacks on each of %s slice specifications.',
num_attacks, num_slice_specs) num_attacks, num_slice_specs)
for single_slice_spec in input_slice_specs: for single_slice_spec in input_slice_specs:
attack_input_slice = get_slice(attack_input, single_slice_spec) attack_input_slice = get_slice(
attack_input, single_slice_spec, return_slice_indices
)
for attack_type in attack_types: for attack_type in attack_types:
logging.info('Running attack: %s', attack_type.name) logging.info('Running attack: %s', attack_type.name)
attack_result = _run_attack(attack_input_slice, attack_type, attack_result = _run_attack(attack_input_slice, attack_type,
@ -313,6 +320,9 @@ def run_attacks(attack_input: AttackInputData,
'positive predictive value=%s', attack_type.name, 'positive predictive value=%s', attack_type.name,
attack_result.get_auc(), attack_result.get_attacker_advantage(), attack_result.get_auc(), attack_result.get_attacker_advantage(),
attack_result.get_ppv()) attack_result.get_ppv())
if return_slice_indices and not single_slice_spec.entire_dataset:
attack_result.train_indices = attack_input_slice.train_indices
attack_result.test_indices = attack_input_slice.test_indices
attack_results.append(attack_result) attack_results.append(attack_result)
privacy_report_metadata = _compute_missing_privacy_report_metadata( privacy_report_metadata = _compute_missing_privacy_report_metadata(

View file

@ -445,6 +445,23 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
AttackType.THRESHOLD_ATTACK.value, AttackType.THRESHOLD_ATTACK.value,
) )
def test_run_attacks_size_return_indices(self):
result = mia.run_attacks(
get_test_input(100, 100),
SlicingSpec(
entire_dataset=False,
by_class=True,
by_percentiles=True,
by_classification_correctness=True,
),
(AttackType.LOGISTIC_REGRESSION,),
return_slice_indices=True,
)
for attack_result in result.single_attack_results:
self.assertIsNotNone(attack_result.train_indices)
self.assertIsNotNone(attack_result.test_indices)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()