From ca347b8995cbee924dd121afa941583097ffd3fd Mon Sep 17 00:00:00 2001 From: David Marn Date: Thu, 15 Apr 2021 02:18:36 -0700 Subject: [PATCH] Trained attackers no longer fail when labels are missing. PiperOrigin-RevId: 368598111 --- .../data_structures.py | 18 +++++++++++++--- .../membership_inference_attack.py | 21 ++++++++++++------- .../membership_inference_attack_test.py | 15 +++++++++++++ 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 72474c8..1e2f075 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -249,8 +249,14 @@ class AttackInputData: return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1) def get_loss_train(self): - """Calculates (if needed) cross-entropy losses for the training set.""" + """Calculates (if needed) cross-entropy losses for the training set. + + Returns: + Loss (or None if neither the loss nor the labels are present). + """ if self.loss_train is None: + if self.labels_train is None: + return None if self.logits_train is not None: self.loss_train = utils.log_loss_from_logits(self.labels_train, self.logits_train) @@ -259,9 +265,15 @@ class AttackInputData: return self.loss_train def get_loss_test(self): - """Calculates (if needed) cross-entropy losses for the test set.""" + """Calculates (if needed) cross-entropy losses for the test set. + + Returns: + Loss (or None if neither the loss nor the labels are present). + """ if self.loss_test is None: - if self.logits_train is not None: + if self.labels_test is None: + return None + if self.logits_test is not None: self.loss_test = utils.log_loss_from_logits(self.labels_test, self.logits_test) else: diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 6b670be..0914c03 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -29,8 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ - PrivacyReportMetadata +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult @@ -93,11 +92,15 @@ def _run_trained_attack(attack_input: AttackInputData, def _run_threshold_attack(attack_input: AttackInputData): + """Runs a threshold attack on loss.""" ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size() + loss_train = attack_input.get_loss_train() + loss_test = attack_input.get_loss_test() + if loss_train is None or loss_test is None: + raise ValueError('Not possible to run threshold attack without losses.') fpr, tpr, thresholds = metrics.roc_curve( np.concatenate((np.zeros(ntrain), np.ones(ntest))), - np.concatenate( - (attack_input.get_loss_train(), attack_input.get_loss_test()))) + np.concatenate((loss_train, loss_test))) roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) @@ -313,10 +316,12 @@ def _compute_missing_privacy_report_metadata( if metadata.accuracy_test is None: metadata.accuracy_test = _get_accuracy(attack_input.logits_test, attack_input.labels_test) - if metadata.loss_train is None: - metadata.loss_train = np.average(attack_input.get_loss_train()) - if metadata.loss_test is None: - metadata.loss_test = np.average(attack_input.get_loss_test()) + loss_train = attack_input.get_loss_train() + loss_test = attack_input.get_loss_test() + if metadata.loss_train is None and loss_train is not None: + metadata.loss_train = np.average(loss_train) + if metadata.loss_test is None and loss_test is not None: + metadata.loss_test = np.average(loss_test) return metadata diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index 6d46d80..5b8d82b 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -36,6 +36,14 @@ def get_test_input(n_train, n_test): labels_test=np.array([i % 5 for i in range(n_test)])) +def get_test_input_logits_only(n_train, n_test): + """Get example input logits for attacks.""" + rng = np.random.RandomState(4) + return AttackInputData( + logits_train=rng.randn(n_train, 5) + 0.2, + logits_test=rng.randn(n_test, 5) + 0.2) + + class RunAttacksTest(absltest.TestCase): def test_run_attacks_size(self): @@ -45,6 +53,13 @@ class RunAttacksTest(absltest.TestCase): self.assertLen(result.single_attack_results, 2) + def test_trained_attacks_logits_only_size(self): + result = mia.run_attacks( + get_test_input_logits_only(100, 100), SlicingSpec(), + (AttackType.LOGISTIC_REGRESSION,)) + + self.assertLen(result.single_attack_results, 1) + def test_run_attack_trained_sets_attack_type(self): result = mia._run_attack( get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)