Adds Privacy Report metadata to AttackResults.

PiperOrigin-RevId: 329871255
This commit is contained in:
A. Unique TensorFlower 2020-09-03 01:04:49 -07:00
parent 8d89ef0a4b
commit 2f0a078dd9
4 changed files with 72 additions and 9 deletions

View file

@ -308,10 +308,6 @@ class SingleAttackResult:
attack_type: AttackType
roc_curve: RocCurve # for drawing and metrics calculation
# TODO(b/162693190): Add more metrics. Think which info we should store
# to derive metrics like f1_score or accuracy. Should we store labels and
# predictions, or rather some aggregate data?
def get_attacker_advantage(self):
return self.roc_curve.get_attacker_advantage()
@ -329,12 +325,26 @@ class SingleAttackResult:
])
@dataclass
class PrivacyReportMetadata:
"""Metadata about the evaluated model.
Used to create a privacy report based on AttackResults.
"""
accuracy_train: float = None
accuracy_test: float = None
loss_train: float = None
loss_test: float = None
@dataclass
class AttackResults:
"""Results from running multiple attacks."""
# add metadata, such as parameters of attack evaluation, input data etc
single_attack_results: Iterable[SingleAttackResult]
privacy_report_metadata: PrivacyReportMetadata = None
def calculate_pd_dataframe(self):
"""Returns all metrics as a Pandas DataFrame."""
slice_features = []

View file

@ -23,6 +23,7 @@ import tempfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
@ -30,6 +31,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
@ -112,6 +115,13 @@ def crossentropy(true_labels, predictions):
keras.backend.variable(predictions)))
# Add metadata to generate a privacy report.
privacy_report_metadata = PrivacyReportMetadata(
accuracy_train=metrics.accuracy_score(training_labels,
np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels,
np.argmax(test_pred, axis=1)))
attack_results = mia.run_attacks(
AttackInputData(
labels_train=training_labels,
@ -121,7 +131,8 @@ attack_results = mia.run_attacks(
loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)),
SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION))
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION),
privacy_report_metadata=None)
# Example of saving the results to the file and loading them back.
with tempfile.TemporaryDirectory() as tmpdirname:

View file

@ -27,6 +27,8 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
@ -103,8 +105,8 @@ def run_attack(attack_input: AttackInputData, attack_type: AttackType):
def run_attacks(
attack_input: AttackInputData,
slicing_spec: SlicingSpec = None,
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,)
) -> AttackResults:
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
"""Run all attacks."""
attack_input.validate()
attack_results = []
@ -118,4 +120,35 @@ def run_attacks(
for attack_type in attack_types:
attack_results.append(run_attack(attack_input_slice, attack_type))
return AttackResults(single_attack_results=attack_results)
privacy_report_metadata = _compute_missing_privacy_report_metadata(
privacy_report_metadata, attack_input)
return AttackResults(
single_attack_results=attack_results,
privacy_report_metadata=privacy_report_metadata)
def _compute_missing_privacy_report_metadata(
metadata: PrivacyReportMetadata,
attack_input: AttackInputData) -> PrivacyReportMetadata:
"""Populates metadata fields if they are missing."""
if metadata is None:
metadata = PrivacyReportMetadata()
if metadata.accuracy_train is None:
metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
attack_input.labels_train)
if metadata.accuracy_test is None:
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
attack_input.labels_test)
if metadata.loss_train is None:
metadata.loss_train = np.average(attack_input.get_loss_train())
if metadata.loss_test is None:
metadata.loss_test = np.average(attack_input.get_loss_test())
return metadata
def _get_accuracy(logits, labels):
"""Computes the accuracy if it is missing."""
if logits is None or labels is None:
return None
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))

View file

@ -72,6 +72,15 @@ class RunAttacksTest(absltest.TestCase):
expected_slice = SingleSliceSpec(SlicingFeature.CLASS, 2)
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
def test_accuracy(self):
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
logits = [[1, -1, -3], [-3, -1, -2], [9, 8, 8.5]]
labels = [0, 1, 2]
self.assertEqual(mia._get_accuracy(predictions, labels), 2 / 3)
self.assertEqual(mia._get_accuracy(logits, labels), 2 / 3)
# If accuracy is already present, simply return it.
self.assertIsNone(mia._get_accuracy(None, labels))
if __name__ == '__main__':
absltest.main()