From 2f0a078dd92386429f6687ef063f376a8d9d0ba3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Sep 2020 01:04:49 -0700 Subject: [PATCH] Adds Privacy Report metadata to AttackResults. PiperOrigin-RevId: 329871255 --- .../data_structures.py | 20 +++++++--- .../membership_inference_attack/example.py | 13 ++++++- .../membership_inference_attack_new.py | 39 +++++++++++++++++-- .../membership_inference_attack_new_test.py | 9 +++++ 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 082a90b..a309b35 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -308,10 +308,6 @@ class SingleAttackResult: attack_type: AttackType roc_curve: RocCurve # for drawing and metrics calculation - # TODO(b/162693190): Add more metrics. Think which info we should store - # to derive metrics like f1_score or accuracy. Should we store labels and - # predictions, or rather some aggregate data? - def get_attacker_advantage(self): return self.roc_curve.get_attacker_advantage() @@ -329,12 +325,26 @@ class SingleAttackResult: ]) +@dataclass +class PrivacyReportMetadata: + """Metadata about the evaluated model. + + Used to create a privacy report based on AttackResults. + """ + accuracy_train: float = None + accuracy_test: float = None + + loss_train: float = None + loss_test: float = None + + @dataclass class AttackResults: """Results from running multiple attacks.""" - # add metadata, such as parameters of attack evaluation, input data etc single_attack_results: Iterable[SingleAttackResult] + privacy_report_metadata: PrivacyReportMetadata = None + def calculate_pd_dataframe(self): """Returns all metrics as a Pandas DataFrame.""" slice_features = [] diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 090fd56..84c8468 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -23,6 +23,7 @@ import tempfile import matplotlib.pyplot as plt import numpy as np import pandas as pd +from sklearn import metrics from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.utils import to_categorical @@ -30,6 +31,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ + PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting @@ -112,6 +115,13 @@ def crossentropy(true_labels, predictions): keras.backend.variable(predictions))) +# Add metadata to generate a privacy report. 
+privacy_report_metadata = PrivacyReportMetadata(
+    accuracy_train=metrics.accuracy_score(training_labels,
+                                          np.argmax(training_pred, axis=1)),
+    accuracy_test=metrics.accuracy_score(test_labels,
+                                         np.argmax(test_pred, axis=1)))
+
 attack_results = mia.run_attacks(
     AttackInputData(
         labels_train=training_labels,
@@ -121,7 +131,8 @@ attack_results = mia.run_attacks(
         loss_train=crossentropy(training_labels, training_pred),
         loss_test=crossentropy(test_labels, test_pred)),
     SlicingSpec(entire_dataset=True, by_class=True),
-    attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION))
+    attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION),
+    privacy_report_metadata=privacy_report_metadata)
 
 # Example of saving the results to the file and loading them back.
 with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py
index b0ce912..444b8bf 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py
@@ -27,6 +27,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
+    PrivacyReportMetadata
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
@@ -103,8 +105,8 @@ def run_attack(attack_input: AttackInputData, attack_type: AttackType):
 def run_attacks(
     attack_input: AttackInputData,
     slicing_spec: SlicingSpec = None,
-    attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)
-) -> AttackResults:
+    attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
+    privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
   """Run all attacks."""
   attack_input.validate()
   attack_results = []
@@ -118,4 +120,35 @@ def run_attacks(
     for attack_type in attack_types:
       attack_results.append(run_attack(attack_input_slice, attack_type))
 
-  return AttackResults(single_attack_results=attack_results)
+  privacy_report_metadata = _compute_missing_privacy_report_metadata(
+      privacy_report_metadata, attack_input)
+
+  return AttackResults(
+      single_attack_results=attack_results,
+      privacy_report_metadata=privacy_report_metadata)
+
+
+def _compute_missing_privacy_report_metadata(
+    metadata: PrivacyReportMetadata,
+    attack_input: AttackInputData) -> PrivacyReportMetadata:
+  """Populates metadata fields if they are missing."""
+  if metadata is None:
+    metadata = PrivacyReportMetadata()
+  if metadata.accuracy_train is None:
+    metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
+                                            attack_input.labels_train)
+  if metadata.accuracy_test is None:
+    metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
+                                           attack_input.labels_test)
+  if metadata.loss_train is None:
+    metadata.loss_train = np.average(attack_input.get_loss_train())
+  if metadata.loss_test is None:
+    metadata.loss_test = np.average(attack_input.get_loss_test())
+  return metadata
+
+
+def _get_accuracy(logits, labels):
+  """Computes prediction accuracy from logits/predictions, or None if inputs are missing."""
+  if logits is None or labels is None:
+    return None
+  return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py
index ae291fb..1803fa0 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py
@@ -72,6 +72,15 @@ class RunAttacksTest(absltest.TestCase):
     expected_slice = SingleSliceSpec(SlicingFeature.CLASS, 2)
     self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
 
+  def test_accuracy(self):
+    predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
+    logits = [[1, -1, -3], [-3, -1, -2], [9, 8, 8.5]]
+    labels = [0, 1, 2]
+    self.assertEqual(mia._get_accuracy(predictions, labels), 2 / 3)
+    self.assertEqual(mia._get_accuracy(logits, labels), 2 / 3)
+    # If logits or labels are missing, _get_accuracy returns None.
+    self.assertIsNone(mia._get_accuracy(None, labels))
+
 
 if __name__ == '__main__':
   absltest.main()
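
Usage sketch (editor's note, not part of the commit above): with this change, a caller can pass a partially populated PrivacyReportMetadata to run_attacks and let the missing fields be filled in from the attack input. The snippet below is a minimal illustration only; the arrays train_logits, test_logits, train_labels and test_labels are placeholder data, and it assumes AttackInputData can derive per-example losses from logits and labels via get_loss_train()/get_loss_test(), which the metadata computation in this patch relies on.

    import numpy as np
    from sklearn import metrics

    from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
    from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
    from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
    from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
    from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec

    # Placeholder model outputs: 10-class logits and labels for the train and
    # test splits. Substitute the outputs of a real model here.
    train_logits = np.random.rand(1000, 10)
    test_logits = np.random.rand(1000, 10)
    train_labels = np.random.randint(10, size=1000)
    test_labels = np.random.randint(10, size=1000)

    # Pre-fill only accuracy_train; accuracy_test, loss_train and loss_test are
    # left as None so run_attacks computes them from the attack input.
    metadata = PrivacyReportMetadata(
        accuracy_train=metrics.accuracy_score(train_labels,
                                              np.argmax(train_logits, axis=1)))

    results = mia.run_attacks(
        AttackInputData(
            logits_train=train_logits,
            logits_test=test_logits,
            labels_train=train_labels,
            labels_test=test_labels),
        SlicingSpec(entire_dataset=True),
        attack_types=(AttackType.THRESHOLD_ATTACK,),
        privacy_report_metadata=metadata)

    # All four metadata fields are populated on the returned AttackResults.
    print(results.privacy_report_metadata)

Leaving a field as None is what triggers _compute_missing_privacy_report_metadata to compute it; fields supplied explicitly are kept as-is.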