Introduces an AttackResultsCollection class for the ML Privacy report.
PiperOrigin-RevId: 335858822
This commit is contained in:
parent
e19c53a78c
commit
703cd413c6
5 changed files with 125 additions and 20 deletions
|
@ -15,6 +15,8 @@
|
|||
# Lint as: python3
|
||||
"""Data structures representing attack inputs, configuration, outputs."""
|
||||
import enum
|
||||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
|
||||
|
@ -577,6 +579,41 @@ class AttackResults:
|
|||
return pickle.load(inp)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackResultsCollection:
|
||||
"""A collection of AttackResults."""
|
||||
attack_results_list: Iterable[AttackResults]
|
||||
|
||||
def append(self, attack_results: AttackResults):
|
||||
self.attack_results_list.append(attack_results)
|
||||
|
||||
def save(self, dirname):
|
||||
"""Saves self to a pickle file."""
|
||||
for i, attack_results in enumerate(self.attack_results_list):
|
||||
filepath = os.path.join(dirname,
|
||||
_get_attack_results_filename(attack_results, i))
|
||||
|
||||
attack_results.save(filepath)
|
||||
|
||||
@classmethod
|
||||
def load(cls, dirname):
|
||||
"""Loads AttackResultsCollection from all files in a directory."""
|
||||
loaded_collection = AttackResultsCollection([])
|
||||
for filepath in sorted(glob.glob('%s/*' % dirname)):
|
||||
with open(filepath, 'rb') as inp:
|
||||
loaded_collection.attack_results_list.append(pickle.load(inp))
|
||||
return loaded_collection
|
||||
|
||||
|
||||
def _get_attack_results_filename(attack_results: AttackResults, index: int):
|
||||
"""Creates a filename for a specific set of AttackResults."""
|
||||
metadata = attack_results.privacy_report_metadata
|
||||
if metadata is not None:
|
||||
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
|
||||
metadata.epoch_num, index)
|
||||
return '%s.pickle' % index
|
||||
|
||||
|
||||
def get_flattened_attack_metrics(results: AttackResults):
|
||||
"""Get flattened attack metrics.
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@ import pandas as pd
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
|
@ -206,6 +208,62 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
||||
class AttackResultsCollectionTest(absltest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttackResultsCollectionTest, self).__init__(*args, **kwargs)
|
||||
|
||||
self.some_attack_result = SingleAttackResult(
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
|
||||
self.results_epoch_10 = AttackResults(
|
||||
single_attack_results=[self.some_attack_result],
|
||||
privacy_report_metadata=PrivacyReportMetadata(
|
||||
accuracy_train=0.4,
|
||||
accuracy_test=0.3,
|
||||
epoch_num=10,
|
||||
model_variant_label='default'))
|
||||
|
||||
self.results_epoch_15 = AttackResults(
|
||||
single_attack_results=[self.some_attack_result],
|
||||
privacy_report_metadata=PrivacyReportMetadata(
|
||||
accuracy_train=0.5,
|
||||
accuracy_test=0.4,
|
||||
epoch_num=15,
|
||||
model_variant_label='default'))
|
||||
|
||||
self.attack_results_no_metadata = AttackResults(
|
||||
single_attack_results=[self.some_attack_result])
|
||||
|
||||
self.collection_with_metadata = AttackResultsCollection(
|
||||
[self.results_epoch_10, self.results_epoch_15])
|
||||
|
||||
self.collection_no_metadata = AttackResultsCollection(
|
||||
[self.attack_results_no_metadata, self.attack_results_no_metadata])
|
||||
|
||||
def test_save_with_metadata(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.collection_with_metadata.save(tmpdirname)
|
||||
loaded_collection = AttackResultsCollection.load(tmpdirname)
|
||||
|
||||
self.assertEqual(
|
||||
repr(self.collection_with_metadata), repr(loaded_collection))
|
||||
self.assertLen(loaded_collection.attack_results_list, 2)
|
||||
|
||||
def test_save_no_metadata(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.collection_no_metadata.save(tmpdirname)
|
||||
loaded_collection = AttackResultsCollection.load(tmpdirname)
|
||||
|
||||
self.assertEqual(repr(self.collection_no_metadata), repr(loaded_collection))
|
||||
self.assertLen(loaded_collection.attack_results_list, 2)
|
||||
|
||||
|
||||
class AttackResultsTest(absltest.TestCase):
|
||||
|
||||
perfect_classifier_result: SingleAttackResult
|
||||
|
|
|
@ -32,6 +32,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
|
|||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
|
@ -119,7 +120,7 @@ def crossentropy(true_labels, predictions):
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
epoch_results = []
|
||||
epoch_results = AttackResultsCollection([])
|
||||
|
||||
num_epochs = 2
|
||||
models = {
|
||||
|
|
|
@ -19,6 +19,7 @@ import matplotlib.pyplot as plt
|
|||
import pandas as pd
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||
|
@ -29,7 +30,7 @@ EPOCH_STR = 'Epoch'
|
|||
TRAIN_ACCURACY_STR = 'Train accuracy'
|
||||
|
||||
|
||||
def plot_by_epochs(results: Iterable[AttackResults],
|
||||
def plot_by_epochs(results: AttackResultsCollection,
|
||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
||||
|
||||
|
@ -43,8 +44,9 @@ def plot_by_epochs(results: Iterable[AttackResults],
|
|||
A pyplot figure with privacy vs accuracy plots.
|
||||
"""
|
||||
|
||||
_validate_results(results)
|
||||
all_results_df = _calculate_combined_df_with_metadata(results)
|
||||
_validate_results(results.attack_results_list)
|
||||
all_results_df = _calculate_combined_df_with_metadata(
|
||||
results.attack_results_list)
|
||||
return _generate_subplots(
|
||||
all_results_df=all_results_df,
|
||||
x_axis_metric='Epoch',
|
||||
|
@ -53,7 +55,7 @@ def plot_by_epochs(results: Iterable[AttackResults],
|
|||
|
||||
|
||||
def plot_privacy_vs_accuracy_single_model(
|
||||
results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]):
|
||||
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
|
||||
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
||||
|
||||
In case multiple privacy metrics are specified, the plot will feature
|
||||
|
@ -66,8 +68,9 @@ def plot_privacy_vs_accuracy_single_model(
|
|||
A pyplot figure with privacy vs accuracy plots.
|
||||
|
||||
"""
|
||||
_validate_results(results)
|
||||
all_results_df = _calculate_combined_df_with_metadata(results)
|
||||
_validate_results(results.attack_results_list)
|
||||
all_results_df = _calculate_combined_df_with_metadata(
|
||||
results.attack_results_list)
|
||||
return _generate_subplots(
|
||||
all_results_df=all_results_df,
|
||||
x_axis_metric='Train accuracy',
|
||||
|
|
|
@ -20,6 +20,7 @@ import numpy as np
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
|
@ -80,12 +81,14 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_plot_by_epochs_no_metadata(self):
|
||||
# Raise error if metadata is missing
|
||||
self.assertRaises(ValueError, privacy_report.plot_by_epochs,
|
||||
(self.attack_results_no_metadata,), ['AUC'])
|
||||
self.assertRaises(
|
||||
ValueError, privacy_report.plot_by_epochs,
|
||||
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||
|
||||
def test_single_metric_plot_by_epochs(self):
|
||||
fig = privacy_report.plot_by_epochs(
|
||||
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC'])
|
||||
# extract data from figure.
|
||||
auc_data = fig.gca().lines[0].get_data()
|
||||
# X axis lists epoch values
|
||||
|
@ -97,7 +100,7 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_multiple_metrics_plot_by_epochs(self):
|
||||
fig = privacy_report.plot_by_epochs(
|
||||
(self.results_epoch_10, self.results_epoch_15),
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
# extract data from figure.
|
||||
auc_data = fig.axes[0].lines[0].get_data()
|
||||
|
@ -113,8 +116,9 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_multiple_metrics_plot_by_epochs_multiple_models(self):
|
||||
fig = privacy_report.plot_by_epochs(
|
||||
(self.results_epoch_10, self.results_epoch_15,
|
||||
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||
self.results_epoch_15_model_2)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
# extract data from figure.
|
||||
# extract data from figure.
|
||||
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
||||
|
@ -136,13 +140,14 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
||||
# Raise error if metadata is missing
|
||||
self.assertRaises(ValueError,
|
||||
privacy_report.plot_privacy_vs_accuracy_single_model,
|
||||
(self.attack_results_no_metadata,), ['AUC'])
|
||||
self.assertRaises(
|
||||
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
|
||||
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||
|
||||
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC'])
|
||||
# extract data from figure.
|
||||
auc_data = fig.gca().lines[0].get_data()
|
||||
# X axis lists epoch values
|
||||
|
@ -154,7 +159,7 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
(self.results_epoch_10, self.results_epoch_15),
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
# extract data from figure.
|
||||
auc_data = fig.axes[0].lines[0].get_data()
|
||||
|
@ -170,8 +175,9 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||
(self.results_epoch_10, self.results_epoch_15,
|
||||
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||
self.results_epoch_15_model_2)),
|
||||
['AUC', 'Attacker advantage'])
|
||||
# extract data from figure.
|
||||
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
||||
auc_data_model_2 = fig.axes[0].lines[1].get_data()
|
||||
|
|
Loading…
Reference in a new issue