diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index ae2974b..0da4372 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -161,6 +161,11 @@ class AttackInputData: logits_train: np.ndarray = None logits_test: np.ndarray = None + # Predicted probabilities for each class. They can be derived from logits, + # so they can be set only if logits are not explicitly provided. + probs_train: np.ndarray = None + probs_test: np.ndarray = None + # Contains ground-truth classes. Classes are assumed to be integers starting # from 0. labels_train: np.ndarray = None @@ -185,6 +190,16 @@ class AttackInputData: 'Please set labels_train and labels_test') return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1 + @property + def logits_or_probs_train(self): + """Returns train probs if they are set, otherwise train logits.""" + return self.logits_train if self.probs_train is None else self.probs_train + + @property + def logits_or_probs_test(self): + """Returns test probs if they are set, otherwise test logits.""" + return self.logits_test if self.probs_test is None else self.probs_test + @staticmethod def _get_entropy(logits: np.ndarray, true_labels: np.ndarray): """Computes the prediction entropy (by Song and Mittal).""" @@ -199,7 +214,7 @@ class AttackInputData: # See the Equation (7) in https://arxiv.org/pdf/2003.10595.pdf return np.sum(np.multiply(probs, _log_value(probs)), axis=1) else: - # When given the groud truth label, we compute the + # When given the ground truth label, we compute the # modified prediction entropy. 
# See the Equation (8) in https://arxiv.org/pdf/2003.10595.pdf log_probs = _log_value(probs) @@ -218,15 +233,21 @@ class AttackInputData: def get_loss_train(self): """Calculates (if needed) cross-entropy losses for the training set.""" if self.loss_train is None: - self.loss_train = utils.log_loss_from_logits(self.labels_train, - self.logits_train) + if self.logits_train is not None: + self.loss_train = utils.log_loss_from_logits(self.labels_train, + self.logits_train) + else: + self.loss_train = utils.log_loss(self.labels_train, self.probs_train) return self.loss_train def get_loss_test(self): """Calculates (if needed) cross-entropy losses for the test set.""" if self.loss_test is None: - self.loss_test = utils.log_loss_from_logits(self.labels_test, - self.logits_test) + if self.logits_test is not None: + self.loss_test = utils.log_loss_from_logits(self.labels_test, + self.logits_test) + else: + self.loss_test = utils.log_loss(self.labels_test, self.probs_test) return self.loss_test def get_entropy_train(self): @@ -267,6 +288,13 @@ class AttackInputData: raise ValueError( 'logits_train and logits_test should both be either set or unset') + if (self.probs_train is None) != (self.probs_test is None): + raise ValueError( + 'probs_train and probs_test should both be either set or unset') + + if (self.logits_train is not None) and (self.probs_train is not None): + raise ValueError('Logits and probs can not be both set') + if (self.labels_train is None) != (self.labels_test is None): raise ValueError( 'labels_train and labels_test should both be either set or unset') @@ -286,6 +314,8 @@ class AttackInputData: _is_np_array(self.logits_train, 'logits_train') _is_np_array(self.logits_test, 'logits_test') + _is_np_array(self.probs_train, 'probs_train') + _is_np_array(self.probs_test, 'probs_test') _is_np_array(self.labels_train, 'labels_train') _is_np_array(self.labels_test, 'labels_test') _is_np_array(self.loss_train, 'loss_train') @@ -295,6 +325,8 @@ class 
AttackInputData: _is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test, 'logits_test') + _is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test, + 'probs_test') _is_array_one_dimensional(self.loss_train, 'loss_train') _is_array_one_dimensional(self.loss_test, 'loss_test') _is_array_one_dimensional(self.entropy_train, 'entropy_train') @@ -311,6 +343,8 @@ class AttackInputData: _append_array_shape(self.entropy_test, 'entropy_test', result) _append_array_shape(self.logits_train, 'logits_train', result) _append_array_shape(self.logits_test, 'logits_test', result) + _append_array_shape(self.probs_train, 'probs_train', result) + _append_array_shape(self.probs_test, 'probs_test', result) _append_array_shape(self.labels_train, 'labels_train', result) _append_array_shape(self.labels_test, 'labels_test', result) result.append(')') diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index a052e97..6a3be8e 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -46,7 +46,7 @@ class SingleSliceSpecTest(parameterized.TestCase): class AttackInputDataTest(absltest.TestCase): - def test_get_loss(self): + def test_get_loss_from_logits(self): attack_input = AttackInputData( logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]), logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]), @@ -58,6 +58,18 @@ class AttackInputDataTest(absltest.TestCase): np.testing.assert_allclose( attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7) + def test_get_loss_from_probs(self): + attack_input = AttackInputData( + probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]), + probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]), + labels_train=np.array([1, 0]), + labels_test=np.array([0, 2])) + + 
np.testing.assert_allclose( + attack_input.get_loss_train(), [2.30258509, 0.2231436], atol=1e-7) + np.testing.assert_allclose( + attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7) + def test_get_loss_explicitly_provided(self): attack_input = AttackInputData( loss_train=np.array([1.0, 3.0, 6.0]), @@ -99,6 +111,8 @@ class AttackInputDataTest(absltest.TestCase): def test_validator(self): self.assertRaises(ValueError, AttackInputData(logits_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(probs_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(labels_train=np.array([])).validate) self.assertRaises(ValueError, @@ -107,6 +121,8 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData(entropy_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(logits_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(probs_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(labels_test=np.array([])).validate) self.assertRaises(ValueError, @@ -114,6 +130,14 @@ class AttackInputDataTest(absltest.TestCase): self.assertRaises(ValueError, AttackInputData(entropy_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData().validate) + # Tests that having both logits and probs set is not allowed. 
+ self.assertRaises( + ValueError, + AttackInputData( + logits_train=np.array([]), + logits_test=np.array([]), + probs_train=np.array([]), + probs_test=np.array([])).validate) class RocCurveTest(absltest.TestCase): diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py index d62be31..8e5e4b0 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py @@ -38,11 +38,13 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, # Slice train data. result.logits_train = _slice_if_not_none(data.logits_train, idx_train) + result.probs_train = _slice_if_not_none(data.probs_train, idx_train) result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train) # Slice test data. result.logits_test = _slice_if_not_none(data.logits_test, idx_test) + result.probs_test = _slice_if_not_none(data.probs_test, idx_test) result.labels_test = _slice_if_not_none(data.labels_test, idx_test) result.loss_test = _slice_if_not_none(data.loss_test, idx_test) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py index e6570a3..75b8a3f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py @@ -106,13 +106,24 @@ class GetSliceTest(absltest.TestCase): # Create test data for 3 class classification task. 
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]]) logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]]) + probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0], + [0.3, 0.7, 0]]) + probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0], + [0, 0, 1]]) labels_train = np.array([1, 0, 1, 2]) labels_test = np.array([1, 2, 0, 2]) loss_train = np.array([2, 0.25, 4, 3]) loss_test = np.array([0.5, 3.5, 7, 4.5]) - self.input_data = AttackInputData(logits_train, logits_test, labels_train, - labels_test, loss_train, loss_test) + self.input_data = AttackInputData( + logits_train=logits_train, + logits_test=logits_test, + probs_train=probs_train, + probs_test=probs_test, + labels_train=labels_train, + labels_test=labels_test, + loss_train=loss_train, + loss_test=loss_test) def test_slice_entire_dataset(self): entire_dataset_slice = SingleSliceSpec() @@ -131,6 +142,11 @@ class GetSliceTest(absltest.TestCase): self.assertLen(output.logits_test, 1) self.assertTrue((output.logits_train[1] == [4, 5, 0]).all()) + # Check probs. + self.assertLen(output.probs_train, 2) + self.assertLen(output.probs_test, 1) + self.assertTrue((output.probs_train[1] == [0.4, 0.6, 0]).all()) + # Check labels. 
self.assertLen(output.labels_train, 2) self.assertLen(output.labels_test, 1) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 84c8468..5a9ff4a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -101,7 +101,7 @@ model.fit( to_categorical(training_labels, num_clusters), validation_data=(test_features, to_categorical(test_labels, num_clusters)), batch_size=64, - epochs=10, + epochs=2, shuffle=True) training_pred = model.predict(training_features) @@ -126,8 +126,8 @@ attack_results = mia.run_attacks( AttackInputData( labels_train=training_labels, labels_test=test_labels, - logits_train=training_pred, - logits_test=test_pred, + probs_train=training_pred, + probs_test=test_pred, loss_train=crossentropy(training_labels, training_pred), loss_test=crossentropy(test_labels, test_pred)), SlicingSpec(entire_dataset=True, by_class=True), diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py index 42881f4..5e695b9 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py @@ -28,9 +28,10 @@ def get_test_input(n_train, n_test): """Get example inputs for attacks.""" rng = np.random.RandomState(4) return AttackInputData( - rng.randn(n_train, 5) + 0.2, - rng.randn(n_test, 5) + 0.2, np.array([i % 5 for i in range(n_train)]), - np.array([i % 5 for i in range(n_test)])) + logits_train=rng.randn(n_train, 5) + 0.2, + logits_test=rng.randn(n_test, 5) + 0.2, + labels_train=np.array([i % 5 for i in range(n_train)]), + labels_test=np.array([i % 5 for i in range(n_test)])) class 
RunAttacksTest(absltest.TestCase): diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 0f4c238..86004d7 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -55,9 +55,9 @@ def create_attacker_data(attack_input_data: AttackInputData, Returns: AttackerData. """ - attack_input_train = _column_stack(attack_input_data.logits_train, + attack_input_train = _column_stack(attack_input_data.logits_or_probs_train, attack_input_data.get_loss_train()) - attack_input_test = _column_stack(attack_input_data.logits_test, + attack_input_test = _column_stack(attack_input_data.logits_or_probs_test, attack_input_data.get_loss_test()) features_all = np.concatenate((attack_input_train, attack_input_test))