Introduces an AttackResultsCollection class for the ML Privacy report.

PiperOrigin-RevId: 335858822
This commit is contained in:
David Marn 2020-10-07 06:59:08 -07:00 committed by A. Unique TensorFlower
parent e19c53a78c
commit 703cd413c6
5 changed files with 125 additions and 20 deletions

View file

@ -15,6 +15,8 @@
# Lint as: python3 # Lint as: python3
"""Data structures representing attack inputs, configuration, outputs.""" """Data structures representing attack inputs, configuration, outputs."""
import enum import enum
import glob
import os
import pickle import pickle
from typing import Any, Iterable, Union from typing import Any, Iterable, Union
@ -577,6 +579,41 @@ class AttackResults:
return pickle.load(inp) return pickle.load(inp)
@dataclass
class AttackResultsCollection:
  """A collection of AttackResults.

  The collection can be written to, and read back from, a directory that holds
  one pickle file per AttackResults entry.
  """

  attack_results_list: Iterable[AttackResults]

  def append(self, attack_results: AttackResults):
    """Adds one AttackResults entry to the end of the collection.

    Args:
      attack_results: the entry to add.
    """
    # The field is only annotated as an Iterable and may have been given a
    # tuple or another iterable without append(); switch to a list in that
    # case so that appending works.
    if not isinstance(self.attack_results_list, list):
      self.attack_results_list = list(self.attack_results_list)
    self.attack_results_list.append(attack_results)

  def save(self, dirname):
    """Saves every entry of the collection to its own file in a directory.

    Args:
      dirname: directory that receives one pickle file per entry. The file
        names come from _get_attack_results_filename and end with the position
        of the entry in the collection.
    """
    for i, attack_results in enumerate(self.attack_results_list):
      filepath = os.path.join(dirname,
                              _get_attack_results_filename(attack_results, i))
      attack_results.save(filepath)

  @classmethod
  def load(cls, dirname):
    """Loads an AttackResultsCollection from all files in a directory.

    Files are read in ascending order of the numeric index that ends their
    name ('<index>.pickle' or '<...>_<index>.pickle'), so more than ten entries
    are not reordered the way a plain lexicographic sort would reorder them
    ('10.pickle' before '2.pickle'). Files whose name does not end with such an
    index are read last, in lexicographic order. Sub-directories are skipped.

    Args:
      dirname: directory to read from.

    Returns:
      A new collection holding the unpickled content of every file.
    """

    def _load_order(filepath):
      """Sort key: files with a trailing numeric index first, by that index."""
      stem = os.path.splitext(os.path.basename(filepath))[0]
      index = stem.rpartition('_')[2]
      if index.isdecimal():
        return (0, int(index), filepath)
      return (1, 0, filepath)

    loaded_collection = cls([])
    # Escape the directory name so that characters such as '[' or '*' in it are
    # matched literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(dirname), '*')
    for filepath in sorted(glob.glob(pattern), key=_load_order):
      if not os.path.isfile(filepath):
        continue
      # NOTE: unpickling can execute arbitrary code. Only load directories
      # whose content is trusted.
      with open(filepath, 'rb') as inp:
        loaded_collection.attack_results_list.append(pickle.load(inp))
    return loaded_collection
def _get_attack_results_filename(attack_results: AttackResults, index: int):
  """Builds the pickle file name used to store one AttackResults entry.

  Args:
    attack_results: the entry to name. When its privacy_report_metadata is set,
      the model variant label and the epoch number become part of the name.
    index: number used as the last component of the name.

  Returns:
    '<model_variant_label>_<epoch_num>_<index>.pickle' when metadata is
    present, '<index>.pickle' otherwise.
  """
  report_metadata = attack_results.privacy_report_metadata
  if report_metadata is None:
    return '%s.pickle' % index
  name_parts = (report_metadata.model_variant_label, report_metadata.epoch_num,
                index)
  return '%s_%s_%s.pickle' % name_parts
def get_flattened_attack_metrics(results: AttackResults): def get_flattened_attack_metrics(results: AttackResults):
"""Get flattened attack metrics. """Get flattened attack metrics.

View file

@ -23,7 +23,9 @@ import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
@ -206,6 +208,62 @@ class SingleAttackResultTest(absltest.TestCase):
self.assertEqual(result.get_attacker_advantage(), 0.0) self.assertEqual(result.get_attacker_advantage(), 0.0)
class AttackResultsCollectionTest(absltest.TestCase):
  """Checks that an AttackResultsCollection survives a save/load round trip."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # One attack result that is shared by all AttackResults fixtures.
    roc_curve = RocCurve(
        tpr=np.array([0.0, 0.5, 1.0]),
        fpr=np.array([0.0, 0.5, 1.0]),
        thresholds=np.array([0, 1, 2]))
    self.some_attack_result = SingleAttackResult(
        slice_spec=SingleSliceSpec(None),
        attack_type=AttackType.THRESHOLD_ATTACK,
        roc_curve=roc_curve)

    # Fixtures that carry privacy report metadata.
    self.results_epoch_10 = self._results_with_metadata(
        accuracy_train=0.4, accuracy_test=0.3, epoch_num=10)
    self.results_epoch_15 = self._results_with_metadata(
        accuracy_train=0.5, accuracy_test=0.4, epoch_num=15)

    # Fixture without any metadata.
    self.attack_results_no_metadata = AttackResults(
        single_attack_results=[self.some_attack_result])

    self.collection_with_metadata = AttackResultsCollection(
        [self.results_epoch_10, self.results_epoch_15])
    self.collection_no_metadata = AttackResultsCollection(
        [self.attack_results_no_metadata] * 2)

  def _results_with_metadata(self, accuracy_train, accuracy_test, epoch_num):
    """Returns AttackResults of the 'default' model variant with metadata."""
    report_metadata = PrivacyReportMetadata(
        accuracy_train=accuracy_train,
        accuracy_test=accuracy_test,
        epoch_num=epoch_num,
        model_variant_label='default')
    return AttackResults(
        single_attack_results=[self.some_attack_result],
        privacy_report_metadata=report_metadata)

  def _assert_round_trip(self, collection):
    """Saves `collection` to a temporary directory and compares the reload."""
    with tempfile.TemporaryDirectory() as tmpdirname:
      collection.save(tmpdirname)
      loaded_collection = AttackResultsCollection.load(tmpdirname)

    self.assertEqual(repr(collection), repr(loaded_collection))
    self.assertLen(loaded_collection.attack_results_list, 2)

  def test_save_with_metadata(self):
    self._assert_round_trip(self.collection_with_metadata)

  def test_save_no_metadata(self):
    self._assert_round_trip(self.collection_no_metadata)
class AttackResultsTest(absltest.TestCase): class AttackResultsTest(absltest.TestCase):
perfect_classifier_result: SingleAttackResult perfect_classifier_result: SingleAttackResult

View file

@ -32,6 +32,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
@ -119,7 +120,7 @@ def crossentropy(true_labels, predictions):
def main(unused_argv): def main(unused_argv):
epoch_results = [] epoch_results = AttackResultsCollection([])
num_epochs = 2 num_epochs = 2
models = { models = {

View file

@ -19,6 +19,7 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
@ -29,7 +30,7 @@ EPOCH_STR = 'Epoch'
TRAIN_ACCURACY_STR = 'Train accuracy' TRAIN_ACCURACY_STR = 'Train accuracy'
def plot_by_epochs(results: Iterable[AttackResults], def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure: privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant. """Plots privacy vulnerabilities vs epoch numbers for a single model variant.
@ -43,8 +44,9 @@ def plot_by_epochs(results: Iterable[AttackResults],
A pyplot figure with privacy vs accuracy plots. A pyplot figure with privacy vs accuracy plots.
""" """
_validate_results(results) _validate_results(results.attack_results_list)
all_results_df = _calculate_combined_df_with_metadata(results) all_results_df = _calculate_combined_df_with_metadata(
results.attack_results_list)
return _generate_subplots( return _generate_subplots(
all_results_df=all_results_df, all_results_df=all_results_df,
x_axis_metric='Epoch', x_axis_metric='Epoch',
@ -53,7 +55,7 @@ def plot_by_epochs(results: Iterable[AttackResults],
def plot_privacy_vs_accuracy_single_model( def plot_privacy_vs_accuracy_single_model(
results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]): results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant. """Plots privacy vulnerabilities vs accuracy plots for a single model variant.
In case multiple privacy metrics are specified, the plot will feature In case multiple privacy metrics are specified, the plot will feature
@ -66,8 +68,9 @@ def plot_privacy_vs_accuracy_single_model(
A pyplot figure with privacy vs accuracy plots. A pyplot figure with privacy vs accuracy plots.
""" """
_validate_results(results) _validate_results(results.attack_results_list)
all_results_df = _calculate_combined_df_with_metadata(results) all_results_df = _calculate_combined_df_with_metadata(
results.attack_results_list)
return _generate_subplots( return _generate_subplots(
all_results_df=all_results_df, all_results_df=all_results_df,
x_axis_metric='Train accuracy', x_axis_metric='Train accuracy',

View file

@ -20,6 +20,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata PrivacyReportMetadata
@ -80,12 +81,14 @@ class PrivacyReportTest(absltest.TestCase):
def test_plot_by_epochs_no_metadata(self): def test_plot_by_epochs_no_metadata(self):
# Raise error if metadata is missing # Raise error if metadata is missing
self.assertRaises(ValueError, privacy_report.plot_by_epochs, self.assertRaises(
(self.attack_results_no_metadata,), ['AUC']) ValueError, privacy_report.plot_by_epochs,
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
def test_single_metric_plot_by_epochs(self): def test_single_metric_plot_by_epochs(self):
fig = privacy_report.plot_by_epochs( fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15), ['AUC']) AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC'])
# extract data from figure. # extract data from figure.
auc_data = fig.gca().lines[0].get_data() auc_data = fig.gca().lines[0].get_data()
# X axis lists epoch values # X axis lists epoch values
@ -97,7 +100,7 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_by_epochs(self): def test_multiple_metrics_plot_by_epochs(self):
fig = privacy_report.plot_by_epochs( fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15), AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage']) ['AUC', 'Attacker advantage'])
# extract data from figure. # extract data from figure.
auc_data = fig.axes[0].lines[0].get_data() auc_data = fig.axes[0].lines[0].get_data()
@ -113,8 +116,9 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_by_epochs_multiple_models(self): def test_multiple_metrics_plot_by_epochs_multiple_models(self):
fig = privacy_report.plot_by_epochs( fig = privacy_report.plot_by_epochs(
(self.results_epoch_10, self.results_epoch_15, AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage']) self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage'])
# extract data from figure. # extract data from figure.
# extract data from figure. # extract data from figure.
auc_data_model_1 = fig.axes[0].lines[0].get_data() auc_data_model_1 = fig.axes[0].lines[0].get_data()
@ -136,13 +140,14 @@ class PrivacyReportTest(absltest.TestCase):
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self): def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
# Raise error if metadata is missing # Raise error if metadata is missing
self.assertRaises(ValueError, self.assertRaises(
privacy_report.plot_privacy_vs_accuracy_single_model, ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
(self.attack_results_no_metadata,), ['AUC']) AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
def test_single_metric_plot_privacy_vs_accuracy_single_model(self): def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15), ['AUC']) AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC'])
# extract data from figure. # extract data from figure.
auc_data = fig.gca().lines[0].get_data() auc_data = fig.gca().lines[0].get_data()
# X axis lists epoch values # X axis lists epoch values
@ -154,7 +159,7 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self): def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15), AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
['AUC', 'Attacker advantage']) ['AUC', 'Attacker advantage'])
# extract data from figure. # extract data from figure.
auc_data = fig.axes[0].lines[0].get_data() auc_data = fig.axes[0].lines[0].get_data()
@ -170,8 +175,9 @@ class PrivacyReportTest(absltest.TestCase):
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self): def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
fig = privacy_report.plot_privacy_vs_accuracy_single_model( fig = privacy_report.plot_privacy_vs_accuracy_single_model(
(self.results_epoch_10, self.results_epoch_15, AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage']) self.results_epoch_15_model_2)),
['AUC', 'Attacker advantage'])
# extract data from figure. # extract data from figure.
auc_data_model_1 = fig.axes[0].lines[0].get_data() auc_data_model_1 = fig.axes[0].lines[0].get_data()
auc_data_model_2 = fig.axes[0].lines[1].get_data() auc_data_model_2 = fig.axes[0].lines[1].get_data()