diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index 39a4773..e75cccd 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -20,7 +20,7 @@ import glob import logging import os import pickle -from typing import Any, Iterable, MutableSequence, Optional, Union +from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence import numpy as np import pandas as pd @@ -39,6 +39,7 @@ class SlicingFeature(enum.Enum): CLASS = 'class' PERCENTILE = 'percentile' CORRECTLY_CLASSIFIED = 'correctly_classified' + CUSTOM = 'custom' @dataclasses.dataclass @@ -65,6 +66,11 @@ class SingleSliceSpec: if self.feature == SlicingFeature.PERCENTILE: return 'Loss percentiles: %d-%d' % self.value + if self.feature == SlicingFeature.CUSTOM: + custom_train_indices, custom_test_indices, group_value = self.value + return (f'Custom indices: train = {custom_train_indices}, ' + f'test = {custom_test_indices}, group_value = {group_value}') + return '%s=%s' % (self.feature.name, self.value) @@ -91,6 +97,37 @@ class SlicingSpec: # examples will be generated. by_classification_correctness: bool = False + # When both `all_custom_train_indices` and `all_custom_test_indices` are set, + # will slice by custom indices. + # `custom_train_indices` and `custom_test_indices` are sequences containing + # the same number of arrays. Each array indicates the grouping of training and + # test examples, and should have a length equal to the number of training and + # test examples. + # For example, suppose we have 3 training examples (a1, a2, a3), and + # 2 test examples (b1, b2). Then, + # all_custom_train_indices = [np.array([2, 1, 2]), np.array([0, 0, 1])] + # all_custom_test_indices = [np.array([1, 2]), np.array([1, 0])] + # means we are going to consider two ways of slicing them: + # 1. two groups: (a2, b1) corresponding to value 1, (a1, a3, b2) corresponding + # to value 2. + # 2. two groups: (a1, a2, b2) corresponding to value 0, (a3, b1) corresponding + # to value 1. + all_custom_train_indices: Optional[Sequence[np.ndarray]] = None + all_custom_test_indices: Optional[Sequence[np.ndarray]] = None + + def __post_init__(self): + if not self.all_custom_train_indices and not self.all_custom_test_indices: + return + if bool(self.all_custom_train_indices) != bool( + self.all_custom_test_indices): + raise ValueError('custom_train_indices and custom_test_indices must ' + 'be provided or set to None at the same time.') + if len(self.all_custom_train_indices) != len(self.all_custom_test_indices): + raise ValueError('all_custom_train_indices and all_custom_test_indices ' + 'should have the same length, but got' + f'{len(self.all_custom_train_indices)} and ' + f'{len(self.all_custom_test_indices)}.') + def __str__(self): """Only keeps the True values.""" result = ['SlicingSpec('] @@ -107,6 +144,8 @@ class SlicingSpec: result.append(' By percentiles,') if self.by_classification_correctness: result.append(' By classification correctness,') + if self.all_custom_train_indices: + result.append(' By custom indices,') result.append(')') return '\n'.join(result) @@ -123,8 +162,9 @@ class AttackType(enum.Enum): @property def is_trained_attack(self): """Returns whether this type of attack requires training a model.""" - return (self != AttackType.THRESHOLD_ATTACK) and ( - self != AttackType.THRESHOLD_ENTROPY_ATTACK) + # Compare by name instead of the variable itself to support module reload. + return self.name not in (AttackType.THRESHOLD_ATTACK.name, + AttackType.THRESHOLD_ENTROPY_ATTACK.name) def __str__(self): """Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION.""" diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index aad3a09..feeb3fe 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -44,6 +44,8 @@ class SingleSliceSpecTest(parameterized.TestCase): (SlicingFeature.CLASS, 2, 'CLASS=2'), (SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'), (SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'), + (SlicingFeature.CUSTOM, (np.array([1]), np.array([2, 1]), 1), + 'Custom indices: train = [1], test = [2 1], group_value = 1'), ) def testStr(self, feature, value, expected_str): self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index e785dd9..2678c82 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ -26,13 +26,16 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec +_MAX_NUM_OF_SLICES = 1000 + + def _slice_if_not_none(a, idx): return None if a is None else a[idx] def _slice_data_by_indices(data: AttackInputData, idx_train, idx_test) -> AttackInputData: - """Slices train fields with with idx_train and test fields with and idx_test.""" + """Slices train fields with idx_train and test fields with idx_test.""" result = AttackInputData() @@ -128,10 +131,55 @@ def _slice_by_classification_correctness(data: AttackInputData, return _slice_data_by_indices(data, idx_train, idx_test) +def _slice_by_custom_indices(data: AttackInputData, + custom_train_indices: np.ndarray, + custom_test_indices: np.ndarray, + group_value: int) -> AttackInputData: + """Slices attack inputs by custom indices. + + Args: + data: Data to be used as input to the attack models. + custom_train_indices: The group indices of each training example. + custom_test_indices: The group indices of each test example. + group_value: The group value to pick. + + Returns: + AttackInputData object containing the sliced data. + """ + train_size, test_size = data.get_train_size(), data.get_test_size() + if custom_train_indices.shape[0] != train_size: + raise ValueError( + "custom_train_indices should have the same number of elements as " + f"the training data, but got {custom_train_indices.shape} and " + f"{train_size}") + if custom_test_indices.shape[0] != test_size: + raise ValueError( + "custom_test_indices should have the same number of elements as " + f"the test data, but got {custom_test_indices.shape} and " + f"{test_size}") + idx_train = custom_train_indices == group_value + idx_test = custom_test_indices == group_value + return _slice_data_by_indices(data, idx_train, idx_test) + + def get_single_slice_specs( slicing_spec: SlicingSpec, num_classes: Optional[int] = None) -> List[SingleSliceSpec]: - """Returns slices of data according to slicing_spec.""" + """Returns slices of data according to slicing_spec. + + Args: + slicing_spec: the slicing specification + num_classes: number of classes of the examples. Required when slicing by + class. + + Returns: + Slices of data according to the slicing specification. + + Raises: + ValueError: If the number of slices is above `_MAX_NUM_OF_SLICES` when + slicing by class or slicing with custom indices. Or, if `num_classes` is + not provided when slicing by class. + """ result = [] if slicing_spec.entire_dataset: @@ -141,10 +189,12 @@ def get_single_slice_specs( by_class = slicing_spec.by_class if isinstance(by_class, bool): if by_class: - assert num_classes, "When by_class == True, num_classes should be given." - assert 0 <= num_classes <= 1000, ( - f"Too much classes for slicing by classes. " - f"Found {num_classes}.") + if not num_classes: + raise ValueError("When by_class == True, num_classes should be given.") + if not 0 <= num_classes <= _MAX_NUM_OF_SLICES: + raise ValueError(f"Too many classes for slicing by classes. " + f"Found {num_classes}." + f"Should be no more than {_MAX_NUM_OF_SLICES}.") for c in range(num_classes): result.append(SingleSliceSpec(SlicingFeature.CLASS, c)) elif isinstance(by_class, int): @@ -164,6 +214,23 @@ def get_single_slice_specs( result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)) result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False)) + # Create slices by custom indices. + if slicing_spec.all_custom_train_indices: + for custom_train_indices, custom_test_indices in zip( + slicing_spec.all_custom_train_indices, + slicing_spec.all_custom_test_indices): + groups = np.intersect1d( + np.unique(custom_train_indices), + np.unique(custom_test_indices), + assume_unique=True) + if not 0 <= groups.size <= _MAX_NUM_OF_SLICES: + raise ValueError( + f"Too many groups ({groups.size}) for slicing by custom indices. " + f"Should be no more than {_MAX_NUM_OF_SLICES}.") + for g in groups: + result.append( + SingleSliceSpec(SlicingFeature.CUSTOM, + (custom_train_indices, custom_test_indices, g))) return result @@ -179,6 +246,10 @@ def get_slice(data: AttackInputData, data_slice = _slice_by_percentiles(data, from_percentile, to_percentile) elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED: data_slice = _slice_by_classification_correctness(data, slice_spec.value) + elif slice_spec.feature == SlicingFeature.CUSTOM: + custom_train_indices, custom_test_indices, group_value = slice_spec.value + data_slice = _slice_by_custom_indices(data, custom_train_indices, + custom_test_indices, group_value) else: raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py index 0324b9a..db1ddca 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py @@ -13,10 +13,11 @@ # limitations under the License. import logging + from absl.testing import absltest +from absl.testing import parameterized from absl.testing.absltest import mock import numpy as np - from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingFeature @@ -38,7 +39,7 @@ def _are_lists_equal(lhs, rhs) -> bool: return True -class SingleSliceSpecsTest(absltest.TestCase): +class SingleSliceSpecsTest(parameterized.TestCase): """Tests for get_single_slice_specs.""" ENTIRE_DATASET_SLICE = SingleSliceSpec() @@ -95,8 +96,81 @@ class SingleSliceSpecsTest(absltest.TestCase): output = get_single_slice_specs(input_data, n_classes) self.assertLen(output, expected_slices) + @parameterized.parameters( + (np.array([1, 2, 1, 2]), np.array([2, 2, 1, 2]), [1, 2]), + (np.array([0, -1, 2, -1, 2]), np.array([2, 2, -1, 2]), [-1, 2]), + (np.array([1, 2, 1, 2] + list(range(5000))), np.array([2, 2, 1]), [1, 2]), + (np.array([1, 2, 1, 2]), np.array([3, 4]), []), + ) + def test_slicing_by_custom_indices_one_pair(self, custom_train_indices, + custom_test_indices, + expected_groups): + input_data = SlicingSpec( + all_custom_train_indices=[custom_train_indices], + all_custom_test_indices=[custom_test_indices]) + expected = [self.ENTIRE_DATASET_SLICE] + [ + SingleSliceSpec(SlicingFeature.CUSTOM, + (custom_train_indices, custom_test_indices, g)) + for g in expected_groups + ] + output = get_single_slice_specs(input_data) + self.assertTrue(_are_lists_equal(output, expected)) -class GetSliceTest(absltest.TestCase): + def test_slicing_by_custom_indices_multi_pairs(self): + all_custom_train_indices = [ + np.array([1, 2, 1, 2]), + np.array([0, -1, 2, -1, 2]), + np.array([1, 2, 1, 2] + list(range(5000))), + np.array([1, 2, 1, 2]) + ] + all_custom_test_indices = [ + np.array([2, 2, 1, 2]), + np.array([2, 2, -1, 2]), + np.array([2, 2, 1]), + np.array([3, 4]) + ] + expected_group_values = [[1, 2], [-1, 2], [1, 2], []] + + input_data = SlicingSpec( + all_custom_train_indices=all_custom_train_indices, + all_custom_test_indices=all_custom_test_indices) + expected = [self.ENTIRE_DATASET_SLICE] + for custom_train_indices, custom_test_indices, eg in zip( + all_custom_train_indices, all_custom_test_indices, + expected_group_values): + expected.extend([ + SingleSliceSpec(SlicingFeature.CUSTOM, + (custom_train_indices, custom_test_indices, g)) + for g in eg + ]) + output = get_single_slice_specs(input_data) + self.assertTrue(_are_lists_equal(output, expected)) + + @parameterized.parameters( + ([np.array([1, 2])], None), + (None, [np.array([1, 2])]), + ([], [np.array([1, 2])]), + ([np.array([1, 2])], [np.array([1, 2]), + np.array([1, 2])]), + ) + def test_slicing_by_custom_indices_wrong_indices(self, + all_custom_train_indices, + all_custom_test_indices): + self.assertRaises( + ValueError, + SlicingSpec, + all_custom_train_indices=all_custom_train_indices, + all_custom_test_indices=all_custom_test_indices) + + def test_slicing_by_custom_indices_too_many_groups(self): + input_data = SlicingSpec( + all_custom_train_indices=[np.arange(1001), + np.arange(3)], + all_custom_test_indices=[np.arange(1001), np.arange(3)]) + self.assertRaises(ValueError, get_single_slice_specs, input_data) + + +class GetSliceTest(parameterized.TestCase): def __init__(self, methodname): """Initialize the test class.""" @@ -210,6 +284,40 @@ class GetSliceTest(absltest.TestCase): self.assertTrue((output.labels_train == [0, 2]).all()) self.assertTrue((output.labels_test == [1, 2, 0]).all()) + def test_slice_by_custom_indices(self): + custom_train_indices = np.array([2, 2, 100, 4]) + custom_test_indices = np.array([100, 2, 2, 2]) + custom_slice = SingleSliceSpec( + SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2)) + output = get_slice(self.input_data, custom_slice) + np.testing.assert_array_equal(output.logits_train, + np.array([[0, 1, 0], [2, 0, 3]])) + np.testing.assert_array_equal( + output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]])) + np.testing.assert_array_equal(output.probs_train, + np.array([[0, 1, 0], [0.1, 0, 0.7]])) + np.testing.assert_array_equal( + output.probs_test, np.array([[0.1, 0.9, 0], [0.15, 0.85, 0], [0, 0, + 1]])) + np.testing.assert_array_equal(output.labels_train, np.array([1, 0])) + np.testing.assert_array_equal(output.labels_test, np.array([2, 0, 2])) + np.testing.assert_array_equal(output.loss_train, np.array([2, 0.25])) + np.testing.assert_array_equal(output.loss_test, np.array([3.5, 7, 4.5])) + np.testing.assert_array_equal(output.entropy_train, np.array([0.4, 8])) + np.testing.assert_array_equal(output.entropy_test, + np.array([10.5, 4.5, 0.3])) + + @parameterized.parameters( + (np.array([2, 2, 100]), np.array([100, 2, 2])), + (np.array([2, 2, 100, 4]), np.array([100, 2, 2])), + (np.array([2, 100, 4]), np.array([100, 2, 2, 2])), + ) + def test_slice_by_custom_indices_wrong_size(self, custom_train_indices, + custom_test_indices): + custom_slice = SingleSliceSpec( + SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2)) + self.assertRaises(ValueError, get_slice, self.input_data, custom_slice) + class GetSliceTestForMultilabelData(absltest.TestCase): @@ -288,6 +396,26 @@ class GetSliceTestForMultilabelData(absltest.TestCase): False) self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice) + def test_slice_by_custom_indices(self): + custom_train_indices = np.array([2, 2, 100, 4]) + custom_test_indices = np.array([100, 2, 2, 2]) + custom_slice = SingleSliceSpec( + SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2)) + output = get_slice(self.input_data, custom_slice) + # Check logits. + with self.subTest(msg='Check logits'): + np.testing.assert_array_equal(output.logits_train, + np.array([[0, 1, 0], [2, 0, 3]])) + np.testing.assert_array_equal( + output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]])) + + # Check labels. + with self.subTest(msg='Check labels'): + np.testing.assert_array_equal(output.labels_train, + np.array([[0, 1, 1], [1, 0, 1]])) + np.testing.assert_array_equal(output.labels_test, + np.array([[0, 1, 0], [0, 1, 0], [0, 0, 1]])) + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py index 77045b1..a05ef46 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py @@ -285,12 +285,13 @@ class KNearestNeighborsAttacker(TrainedAttacker): def create_attacker(attack_type, backend: Optional[str] = None) -> TrainedAttacker: """Returns the corresponding attacker for the provided attack_type.""" - if attack_type == data_structures.AttackType.LOGISTIC_REGRESSION: + # Compare by name instead of the variable itself to support module reload. + if attack_type.name == data_structures.AttackType.LOGISTIC_REGRESSION.name: return LogisticRegressionAttacker(backend=backend) - if attack_type == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON: + if attack_type.name == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name: return MultilayerPerceptronAttacker(backend=backend) - if attack_type == data_structures.AttackType.RANDOM_FOREST: + if attack_type.name == data_structures.AttackType.RANDOM_FOREST.name: return RandomForestAttacker(backend=backend) - if attack_type == data_structures.AttackType.K_NEAREST_NEIGHBORS: + if attack_type.name == data_structures.AttackType.K_NEAREST_NEIGHBORS.name: return KNearestNeighborsAttacker(backend=backend) raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)