diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
index 55d5656..8f877a7 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
@@ -43,7 +43,9 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
   return SingleSliceSpec()
 
 
-def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_trained_attack(attack_input: AttackInputData,
+                        attack_type: AttackType,
+                        balance_attacker_training: bool = True):
   """Classification attack done by ML models."""
   attacker = None
 
@@ -59,7 +61,8 @@ def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
     raise NotImplementedError('Attack type %s not implemented yet.' %
                               attack_type)
 
-  prepared_attacker_data = models.create_attacker_data(attack_input)
+  prepared_attacker_data = models.create_attacker_data(
+      attack_input, balance=balance_attacker_training)
 
   attacker.train_model(prepared_attacker_data.features_train,
                        prepared_attacker_data.is_training_labels_train)
@@ -94,19 +97,23 @@ def _run_threshold_attack(attack_input: AttackInputData):
       roc_curve=roc_curve)
 
 
-def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_attack(attack_input: AttackInputData,
+                attack_type: AttackType,
+                balance_attacker_training: bool = True):
   attack_input.validate()
   if attack_type.is_trained_attack:
-    return _run_trained_attack(attack_input, attack_type)
+    return _run_trained_attack(attack_input, attack_type,
+                               balance_attacker_training)
 
   return _run_threshold_attack(attack_input)
 
 
-def run_attacks(
-    attack_input: AttackInputData,
-    slicing_spec: SlicingSpec = None,
-    attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
-    privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
+def run_attacks(attack_input: AttackInputData,
+                slicing_spec: SlicingSpec = None,
+                attack_types: Iterable[AttackType] = (
+                    AttackType.THRESHOLD_ATTACK,),
+                privacy_report_metadata: PrivacyReportMetadata = None,
+                balance_attacker_training: bool = True) -> AttackResults:
   """Runs membership inference attacks on a classification model.
 
   It runs attacks specified by attack_types on each attack_input slice which is
@@ -117,6 +124,10 @@ def run_attacks(
     slicing_spec: specifies attack_input slices to run attack on
     attack_types: attacks to run
     privacy_report_metadata: the metadata of the model under attack.
+    balance_attacker_training: Whether the training and test sets for the
+      membership inference attacker should have a balanced (roughly equal)
+      number of samples from the training and test sets used to develop
+      the model under attack.
 
   Returns:
     the attack result.
@@ -131,7 +142,9 @@ def run_attacks(
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
-      attack_results.append(_run_attack(attack_input_slice, attack_type))
+      attack_results.append(
+          _run_attack(attack_input_slice, attack_type,
+                      balance_attacker_training))
 
   privacy_report_metadata = _compute_missing_privacy_report_metadata(
       privacy_report_metadata, attack_input)
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py
index 86004d7..54674e0 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/models.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py
@@ -43,7 +43,8 @@ class AttackerData:
 
 
 def create_attacker_data(attack_input_data: AttackInputData,
-                         test_fraction: float = 0.25) -> AttackerData:
+                         test_fraction: float = 0.25,
+                         balance: bool = True) -> AttackerData:
   """Prepare AttackInputData to train ML attackers.
 
   Combines logits and losses and performs a random train-test split.
@@ -51,6 +52,10 @@ def create_attacker_data(attack_input_data: AttackInputData,
   Args:
     attack_input_data: Original AttackInputData
    test_fraction: Fraction of the dataset to include in the test split.
+    balance: Whether the training and test sets for the membership inference
+      attacker should have a balanced (roughly equal) number of samples
+      from the training and test sets used to develop the model
+      under attack.
 
   Returns:
     AttackerData.
@@ -60,20 +65,33 @@ def create_attacker_data(attack_input_data: AttackInputData,
   attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
                                     attack_input_data.get_loss_test())
 
+  if balance:
+    min_size = min(attack_input_data.get_train_size(),
+                   attack_input_data.get_test_size())
+    attack_input_train = _sample_multidimensional_array(attack_input_train,
+                                                        min_size)
+    attack_input_test = _sample_multidimensional_array(attack_input_test,
+                                                       min_size)
+
   features_all = np.concatenate((attack_input_train, attack_input_test))
 
-  labels_all = np.concatenate(((np.zeros(attack_input_data.get_train_size())),
-                               (np.ones(attack_input_data.get_test_size()))))
+  labels_all = np.concatenate(
+      ((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
 
   # Perform a train-test split
   features_train, features_test, \
   is_training_labels_train, is_training_labels_test = \
     model_selection.train_test_split(
-        features_all, labels_all, test_size=test_fraction)
+        features_all, labels_all, test_size=test_fraction, stratify=labels_all)
   return AttackerData(features_train, is_training_labels_train, features_test,
                       is_training_labels_test)
 
 
+def _sample_multidimensional_array(array, size):
+  indices = np.random.choice(len(array), size, replace=False)
+  return array[indices]
+
+
 def _column_stack(logits, loss):
   """Stacks logits and losses.
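[Note, not part of the patch] For intuition, here is a minimal standalone sketch of what models.create_attacker_data does when balance=True: both membership classes are subsampled, without replacement, down to the smaller class size, and stratify=labels_all then keeps the 50/50 member/non-member mix in each of the attacker's own splits. The toy sizes below (1000 member rows vs. 100 non-member rows) are assumptions for illustration only; the numpy and sklearn calls mirror the diff above.

import numpy as np
from sklearn import model_selection

# Toy attacker features: each row is, e.g., (logit_0, logit_1, loss).
attack_input_train = np.random.randn(1000, 3)  # samples the model trained on
attack_input_test = np.random.randn(100, 3)    # held-out samples

# Balance: sample each class down to the smaller size, without replacement.
min_size = min(len(attack_input_train), len(attack_input_test))
attack_input_train = attack_input_train[
    np.random.choice(len(attack_input_train), min_size, replace=False)]
attack_input_test = attack_input_test[
    np.random.choice(len(attack_input_test), min_size, replace=False)]

features_all = np.concatenate((attack_input_train, attack_input_test))
labels_all = np.concatenate(
    (np.zeros(len(attack_input_train)), np.ones(len(attack_input_test))))

# Stratifying on the labels keeps the member/non-member ratio identical in
# both attacker splits: a 50/50 input stays exactly 50/50 after splitting.
features_train, features_test, labels_train, labels_test = (
    model_selection.train_test_split(
        features_all, labels_all, test_size=0.25, stratify=labels_all))
assert labels_train.mean() == 0.5  # the attacker trains on balanced labels

Without the subsampling step, a model evaluated on, say, 50,000 training and 5,000 test samples would let a trivial "always predict member" attacker reach about 91% accuracy, which is why the new flag defaults to True.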
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
index b5cf1ac..c55ab98 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
@@ -34,19 +34,20 @@ class TrainedAttackerTest(absltest.TestCase):
 
   def test_create_attacker_data_loss_only(self):
     attack_input = AttackInputData(
-        loss_train=np.array([1]), loss_test=np.array([2]))
+        loss_train=np.array([1, 3]), loss_test=np.array([2, 4]))
     attacker_data = models.create_attacker_data(attack_input, 0.5)
-    self.assertLen(attacker_data.features_test, 1)
-    self.assertLen(attacker_data.features_train, 1)
+    self.assertLen(attacker_data.features_test, 2)
+    self.assertLen(attacker_data.features_train, 2)
 
   def test_create_attacker_data_loss_and_logits(self):
     attack_input = AttackInputData(
-        logits_train=np.array([[1, 2], [5, 6]]),
+        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
         logits_test=np.array([[10, 11], [14, 15]]),
-        loss_train=np.array([3, 7]),
+        loss_train=np.array([3, 7, 10]),
         loss_test=np.array([12, 16]))
-    attacker_data = models.create_attacker_data(attack_input, 0.25)
-    self.assertLen(attacker_data.features_test, 1)
+    attacker_data = models.create_attacker_data(
+        attack_input, 0.25, balance=False)
+    self.assertLen(attacker_data.features_test, 2)
     self.assertLen(attacker_data.features_train, 3)
 
     for i, feature in enumerate(attacker_data.features_train):
@@ -54,6 +55,21 @@ class TrainedAttackerTest(absltest.TestCase):
       expected = feature[:2] not in attack_input.logits_train
       self.assertEqual(attacker_data.is_training_labels_train[i], expected)
 
+  def test_balanced_create_attacker_data_loss_and_logits(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
+        logits_test=np.array([[10, 11], [14, 15], [17, 18]]),
+        loss_train=np.array([3, 7, 10]),
+        loss_test=np.array([12, 16, 19]))
+    attacker_data = models.create_attacker_data(attack_input, 0.33)
+    self.assertLen(attacker_data.features_test, 2)
+    self.assertLen(attacker_data.features_train, 4)
+
+    for i, feature in enumerate(attacker_data.features_train):
+      self.assertLen(feature, 3)  # each feature has two logits and one loss
+      expected = feature[:2] not in attack_input.logits_train
+      self.assertEqual(attacker_data.is_training_labels_train[i], expected)
+
 
 if __name__ == '__main__':
   absltest.main()
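[Note, not part of the patch] End to end, the new knob is one extra keyword argument to run_attacks. A minimal usage sketch; it assumes AttackInputData and AttackType are importable from the sibling data_structures module, and the random logits and losses are stand-ins for a real model's outputs:

import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType

# Toy model outputs: 1000 training-set and 100 test-set samples, 10 classes.
attack_input = AttackInputData(
    logits_train=np.random.randn(1000, 10),
    logits_test=np.random.randn(100, 10),
    loss_train=np.random.rand(1000),
    loss_test=np.random.rand(100))

# With the default balance_attacker_training=True, the attacker's data pool
# is balanced at 100 member and 100 non-member samples (before its own 25%
# test split); passing False would use all 1100 samples with a ~10:1 skew.
results = mia.run_attacks(
    attack_input,
    attack_types=(AttackType.LOGISTIC_REGRESSION,),
    balance_attacker_training=True)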