Introduces an AttackResultsCollection class for the ML Privacy report.

PiperOrigin-RevId: 335858822
This commit is contained in:
David Marn 2020-10-07 06:59:08 -07:00 committed by A. Unique TensorFlower
parent e19c53a78c
commit 703cd413c6
5 changed files with 125 additions and 20 deletions

View file

@ -15,6 +15,8 @@
# Lint as: python3
"""Data structures representing attack inputs, configuration, outputs."""
import enum
import glob
import os
import pickle
from typing import Any, Iterable, Union
@ -577,6 +579,41 @@ class AttackResults:
return pickle.load(inp)
@dataclass
class AttackResultsCollection:
"""A collection of AttackResults."""
attack_results_list: Iterable[AttackResults]
def append(self, attack_results: AttackResults):
self.attack_results_list.append(attack_results)
def save(self, dirname):
"""Saves self to a pickle file."""
for i, attack_results in enumerate(self.attack_results_list):
filepath = os.path.join(dirname,
_get_attack_results_filename(attack_results, i))
attack_results.save(filepath)
@classmethod
def load(cls, dirname):
"""Loads AttackResultsCollection from all files in a directory."""
loaded_collection = AttackResultsCollection([])
for filepath in sorted(glob.glob('%s/*' % dirname)):
with open(filepath, 'rb') as inp:
loaded_collection.attack_results_list.append(pickle.load(inp))
return loaded_collection
def _get_attack_results_filename(attack_results: AttackResults, index: int):
"""Creates a filename for a specific set of AttackResults."""
metadata = attack_results.privacy_report_metadata
if metadata is not None:
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
metadata.epoch_num, index)
return '%s.pickle' % index
def get_flattened_attack_metrics(results: AttackResults):
"""Get flattened attack metrics.

View file

@ -23,7 +23,9 @@ import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
@ -206,6 +208,62 @@ class SingleAttackResultTest(absltest.TestCase):
self.assertEqual(result.get_attacker_advantage(), 0.0)
class AttackResultsCollectionTest(absltest.TestCase):
def __init__(self, *args, **kwargs):
super(AttackResultsCollectionTest, self).__init__(*args, **kwargs)
self.some_attack_result = SingleAttackResult(
slice_spec=SingleSliceSpec(None),
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=RocCurve(
tpr=np.array([0.0, 0.5, 1.0]),
fpr=np.array([0.0, 0.5, 1.0]),
thresholds=np.array([0, 1, 2])))
self.results_epoch_10 = AttackResults(
single_attack_results=[self.some_attack_result],
privacy_report_metadata=PrivacyReportMetadata(
accuracy_train=0.4,
accuracy_test=0.3,
epoch_num=10,
model_variant_label='default'))
self.results_epoch_15 = AttackResults(
single_attack_results=[self.some_attack_result],
privacy_report_metadata=PrivacyReportMetadata(
accuracy_train=0.5,
accuracy_test=0.4,
epoch_num=15,
model_variant_label='default'))
self.attack_results_no_metadata = AttackResults(
single_attack_results=[self.some_attack_result])
self.collection_with_metadata = AttackResultsCollection(
[self.results_epoch_10, self.results_epoch_15])
self.collection_no_metadata = AttackResultsCollection(
[self.attack_results_no_metadata, self.attack_results_no_metadata])
def test_save_with_metadata(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.collection_with_metadata.save(tmpdirname)
loaded_collection = AttackResultsCollection.load(tmpdirname)
self.assertEqual(
repr(self.collection_with_metadata), repr(loaded_collection))
self.assertLen(loaded_collection.attack_results_list, 2)
def test_save_no_metadata(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.collection_no_metadata.save(tmpdirname)
loaded_collection = AttackResultsCollection.load(tmpdirname)
self.assertEqual(repr(self.collection_no_metadata), repr(loaded_collection))
self.assertLen(loaded_collection.attack_results_list, 2)
class AttackResultsTest(absltest.TestCase):
perfect_classifier_result: SingleAttackResult

View file

@ -32,6 +32,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
@ -119,7 +120,7 @@ def crossentropy(true_labels, predictions):
def main(unused_argv):
epoch_results = []
epoch_results = AttackResultsCollection([])
num_epochs = 2
models = {

View file

@ -19,6 +19,7 @@ import matplotlib.pyplot as plt
import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
@ -29,7 +30,7 @@ EPOCH_STR = 'Epoch'
TRAIN_ACCURACY_STR = 'Train accuracy'
def plot_by_epochs(results: Iterable[AttackResults],
def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
@ -43,8 +44,9 @@ def plot_by_epochs(results: Iterable[AttackResults],
A pyplot figure with privacy vs accuracy plots.
"""
_validate_results(results)
all_results_df = _calculate_combined_df_with_metadata(results)
_validate_results(results.attack_results_list)
all_results_df = _calculate_combined_df_with_metadata(
results.attack_results_list)
return _generate_subplots(
all_results_df=all_results_df,
x_axis_metric='Epoch',
@ -53,7 +55,7 @@ def plot_by_epochs(results: Iterable[AttackResults],
def plot_privacy_vs_accuracy_single_model(
results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]):
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
In case multiple privacy metrics are specified, the plot will feature
@ -66,8 +68,9 @@ def plot_privacy_vs_accuracy_single_model(
A pyplot figure with privacy vs accuracy plots.
"""
_validate_results(results)
all_results_df = _calculate_combined_df_with_metadata(results)
_validate_results(results.attack_results_list)
all_results_df = _calculate_combined_df_with_metadata(
results.attack_results_list)
return _generate_subplots(
all_results_df=all_results_df,
x_axis_metric='Train accuracy',

View file

@ -20,6 +20,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata
@ -80,12 +81,14 @@ class PrivacyReportTest(absltest.TestCase):
def test_plot_by_epochs_no_metadata(self):
# Raise error if metadata is missing
self.assertRaises(ValueError, privacy_report.plot_by_epochs,
(self.attack_results_no_metadata,), ['AUC'])
self.assertRaises(
ValueError, privacy_report.plot_by_epochs,
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
def test_single_metric_plot_by_epochs(self):
fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC'])
# extract data from figure.
auc_data = fig.gca().lines[0].get_data()
# X axis lists epoch values
@ -97,7 +100,7 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_by_epochs(self):
fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15),
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage'])
# extract data from figure.
auc_data = fig.axes[0].lines[0].get_data()
@ -113,8 +116,9 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_by_epochs_multiple_models(self):
fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage'])
# extract data from figure.
# extract data from figure.
auc_data_model_1 = fig.axes[0].lines[0].get_data()
@ -136,13 +140,14 @@ class PrivacyReportTest(absltest.TestCase):
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
# Raise error if metadata is missing
self.assertRaises(ValueError,
privacy_report.plot_privacy_vs_accuracy_single_model,
(self.attack_results_no_metadata,), ['AUC'])
self.assertRaises(
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC'])
# extract data from figure.
auc_data = fig.gca().lines[0].get_data()
# X axis lists epoch values
@ -154,7 +159,7 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15),
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage'])
# extract data from figure.
auc_data = fig.axes[0].lines[0].get_data()
@ -170,8 +175,9 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage'])
# extract data from figure.
auc_data_model_1 = fig.axes[0].lines[0].get_data()
auc_data_model_2 = fig.axes[0].lines[1].get_data()