diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
index 2f7c205..1140611 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
@@ -207,12 +207,16 @@ class AttackInputData:
   @property
   def logits_or_probs_train(self):
     """Returns train logits or probs whatever is not None."""
-    return self.probs_train or self.logits_train
+    if self.logits_train is not None:
+      return self.logits_train
+    return self.probs_train
 
   @property
   def logits_or_probs_test(self):
     """Returns test logits or probs whatever is not None."""
-    return self.probs_test or self.logits_test
+    if self.logits_test is not None:
+      return self.logits_test
+    return self.probs_test
 
   @staticmethod
   def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
index d5c9d1d..caa8ab6 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
@@ -82,6 +82,16 @@ class AttackInputDataTest(absltest.TestCase):
     np.testing.assert_equal(attack_input.get_loss_test().tolist(),
                             [1.0, 4.0, 6.0])
 
+  def test_get_probs_sizes(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
+        probs_test=np.array([[0, 0.0001, 0.9999]]),
+        labels_train=np.array([1, 0]),
+        labels_test=np.array([0]))
+
+    np.testing.assert_equal(attack_input.get_train_size(), 2)
+    np.testing.assert_equal(attack_input.get_test_size(), 1)
+
   def test_get_entropy(self):
     attack_input = AttackInputData(
         logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py
index fe108a6..a6694e4 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py
@@ -76,17 +76,18 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
   return _slice_data_by_indices(data, idx_train, idx_test)
 
 
-def _indices_by_classification(logits, labels, correctly_classified):
-  idx_correct = labels == np.argmax(logits, axis=1)
+def _indices_by_classification(logits_or_probs, labels, correctly_classified):
+  idx_correct = labels == np.argmax(logits_or_probs, axis=1)
   return idx_correct if correctly_classified else np.invert(idx_correct)
 
 
 def _slice_by_classification_correctness(data: AttackInputData,
                                          correctly_classified: bool):
-  idx_train = _indices_by_classification(data.logits_train, data.labels_train,
+  idx_train = _indices_by_classification(data.logits_or_probs_train,
+                                         data.labels_train,
                                          correctly_classified)
-  idx_test = _indices_by_classification(data.logits_test, data.labels_test,
-                                        correctly_classified)
+  idx_test = _indices_by_classification(data.logits_or_probs_test,
+                                        data.labels_test, correctly_classified)
 
   return _slice_data_by_indices(data, idx_train, idx_test)
 
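
Note on the data_structures.py change, with a minimal standalone sketch (illustrative only, not part of the diff; the helper name logits_or_probs is hypothetical): evaluating probs_train or logits_train asks Python for the truth value of the probs array, which numpy rejects for arrays with more than one element, and it also made probs take precedence over logits. The explicit is-not-None check avoids the boolean coercion and prefers logits when both are set.

import numpy as np

probs_train = np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0.0]])
logits_train = None

# Old pattern: `probs or logits` needs bool(probs), which numpy refuses
# for a multi-element array and raises ValueError.
try:
  _ = probs_train or logits_train
except ValueError as err:
  print(err)  # truth value of an array with more than one element is ambiguous

# Pattern used by the new property: an explicit None check, returning
# logits when present and falling back to probs otherwise.
def logits_or_probs(logits, probs):
  return logits if logits is not None else probs

print(logits_or_probs(logits_train, probs_train))  # falls back to probs_train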