forked from 626_privacy/tensorflow_privacy
Adds Privacy Report metadata to AttackResults.
PiperOrigin-RevId: 329871255
This commit is contained in:
parent
8d89ef0a4b
commit
2f0a078dd9
4 changed files with 72 additions and 9 deletions
|
@ -308,10 +308,6 @@ class SingleAttackResult:
|
|||
attack_type: AttackType
|
||||
roc_curve: RocCurve # for drawing and metrics calculation
|
||||
|
||||
# TODO(b/162693190): Add more metrics. Think which info we should store
|
||||
# to derive metrics like f1_score or accuracy. Should we store labels and
|
||||
# predictions, or rather some aggregate data?
|
||||
|
||||
def get_attacker_advantage(self):
|
||||
return self.roc_curve.get_attacker_advantage()
|
||||
|
||||
|
@ -329,12 +325,26 @@ class SingleAttackResult:
|
|||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrivacyReportMetadata:
|
||||
"""Metadata about the evaluated model.
|
||||
|
||||
Used to create a privacy report based on AttackResults.
|
||||
"""
|
||||
accuracy_train: float = None
|
||||
accuracy_test: float = None
|
||||
|
||||
loss_train: float = None
|
||||
loss_test: float = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackResults:
|
||||
"""Results from running multiple attacks."""
|
||||
# add metadata, such as parameters of attack evaluation, input data etc
|
||||
single_attack_results: Iterable[SingleAttackResult]
|
||||
|
||||
privacy_report_metadata: PrivacyReportMetadata = None
|
||||
|
||||
def calculate_pd_dataframe(self):
|
||||
"""Returns all metrics as a Pandas DataFrame."""
|
||||
slice_features = []
|
||||
|
|
|
@ -23,6 +23,7 @@ import tempfile
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn import metrics
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.utils import to_categorical
|
||||
|
@ -30,6 +31,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
|
||||
|
||||
|
@ -112,6 +115,13 @@ def crossentropy(true_labels, predictions):
|
|||
keras.backend.variable(predictions)))
|
||||
|
||||
|
||||
# Add metadata to generate a privacy report.
|
||||
privacy_report_metadata = PrivacyReportMetadata(
|
||||
accuracy_train=metrics.accuracy_score(training_labels,
|
||||
np.argmax(training_pred, axis=1)),
|
||||
accuracy_test=metrics.accuracy_score(test_labels,
|
||||
np.argmax(test_pred, axis=1)))
|
||||
|
||||
attack_results = mia.run_attacks(
|
||||
AttackInputData(
|
||||
labels_train=training_labels,
|
||||
|
@ -121,7 +131,8 @@ attack_results = mia.run_attacks(
|
|||
loss_train=crossentropy(training_labels, training_pred),
|
||||
loss_test=crossentropy(test_labels, test_pred)),
|
||||
SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION))
|
||||
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION),
|
||||
privacy_report_metadata=None)
|
||||
|
||||
# Example of saving the results to the file and loading them back.
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
|
@ -27,6 +27,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
|
@ -103,8 +105,8 @@ def run_attack(attack_input: AttackInputData, attack_type: AttackType):
|
|||
def run_attacks(
|
||||
attack_input: AttackInputData,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)
|
||||
) -> AttackResults:
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
|
||||
"""Run all attacks."""
|
||||
attack_input.validate()
|
||||
attack_results = []
|
||||
|
@ -118,4 +120,35 @@ def run_attacks(
|
|||
for attack_type in attack_types:
|
||||
attack_results.append(run_attack(attack_input_slice, attack_type))
|
||||
|
||||
return AttackResults(single_attack_results=attack_results)
|
||||
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
||||
privacy_report_metadata, attack_input)
|
||||
|
||||
return AttackResults(
|
||||
single_attack_results=attack_results,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
|
||||
|
||||
def _compute_missing_privacy_report_metadata(
|
||||
metadata: PrivacyReportMetadata,
|
||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||
"""Populates metadata fields if they are missing."""
|
||||
if metadata is None:
|
||||
metadata = PrivacyReportMetadata()
|
||||
if metadata.accuracy_train is None:
|
||||
metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
|
||||
attack_input.labels_train)
|
||||
if metadata.accuracy_test is None:
|
||||
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
||||
attack_input.labels_test)
|
||||
if metadata.loss_train is None:
|
||||
metadata.loss_train = np.average(attack_input.get_loss_train())
|
||||
if metadata.loss_test is None:
|
||||
metadata.loss_test = np.average(attack_input.get_loss_test())
|
||||
return metadata
|
||||
|
||||
|
||||
def _get_accuracy(logits, labels):
|
||||
"""Computes the accuracy if it is missing."""
|
||||
if logits is None or labels is None:
|
||||
return None
|
||||
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
|
||||
|
|
|
@ -72,6 +72,15 @@ class RunAttacksTest(absltest.TestCase):
|
|||
expected_slice = SingleSliceSpec(SlicingFeature.CLASS, 2)
|
||||
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
|
||||
|
||||
def test_accuracy(self):
|
||||
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
|
||||
logits = [[1, -1, -3], [-3, -1, -2], [9, 8, 8.5]]
|
||||
labels = [0, 1, 2]
|
||||
self.assertEqual(mia._get_accuracy(predictions, labels), 2 / 3)
|
||||
self.assertEqual(mia._get_accuracy(logits, labels), 2 / 3)
|
||||
# If accuracy is already present, simply return it.
|
||||
self.assertIsNone(mia._get_accuracy(None, labels))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue