forked from 626_privacy/tensorflow_privacy
Introduces an AttackResultsCollection class for the ML Privacy report.
PiperOrigin-RevId: 335858822
This commit is contained in:
parent
e19c53a78c
commit
703cd413c6
5 changed files with 125 additions and 20 deletions
|
@ -15,6 +15,8 @@
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""Data structures representing attack inputs, configuration, outputs."""
|
"""Data structures representing attack inputs, configuration, outputs."""
|
||||||
import enum
|
import enum
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, Union
|
from typing import Any, Iterable, Union
|
||||||
|
|
||||||
|
@ -577,6 +579,41 @@ class AttackResults:
|
||||||
return pickle.load(inp)
|
return pickle.load(inp)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AttackResultsCollection:
|
||||||
|
"""A collection of AttackResults."""
|
||||||
|
attack_results_list: Iterable[AttackResults]
|
||||||
|
|
||||||
|
def append(self, attack_results: AttackResults):
|
||||||
|
self.attack_results_list.append(attack_results)
|
||||||
|
|
||||||
|
def save(self, dirname):
|
||||||
|
"""Saves self to a pickle file."""
|
||||||
|
for i, attack_results in enumerate(self.attack_results_list):
|
||||||
|
filepath = os.path.join(dirname,
|
||||||
|
_get_attack_results_filename(attack_results, i))
|
||||||
|
|
||||||
|
attack_results.save(filepath)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, dirname):
|
||||||
|
"""Loads AttackResultsCollection from all files in a directory."""
|
||||||
|
loaded_collection = AttackResultsCollection([])
|
||||||
|
for filepath in sorted(glob.glob('%s/*' % dirname)):
|
||||||
|
with open(filepath, 'rb') as inp:
|
||||||
|
loaded_collection.attack_results_list.append(pickle.load(inp))
|
||||||
|
return loaded_collection
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attack_results_filename(attack_results: AttackResults, index: int):
|
||||||
|
"""Creates a filename for a specific set of AttackResults."""
|
||||||
|
metadata = attack_results.privacy_report_metadata
|
||||||
|
if metadata is not None:
|
||||||
|
return '%s_%s_%s.pickle' % (metadata.model_variant_label,
|
||||||
|
metadata.epoch_num, index)
|
||||||
|
return '%s.pickle' % index
|
||||||
|
|
||||||
|
|
||||||
def get_flattened_attack_metrics(results: AttackResults):
|
def get_flattened_attack_metrics(results: AttackResults):
|
||||||
"""Get flattened attack metrics.
|
"""Get flattened attack metrics.
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,9 @@ import pandas as pd
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
|
@ -206,6 +208,62 @@ class SingleAttackResultTest(absltest.TestCase):
|
||||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class AttackResultsCollectionTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttackResultsCollectionTest, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.some_attack_result = SingleAttackResult(
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
|
roc_curve=RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
self.results_epoch_10 = AttackResults(
|
||||||
|
single_attack_results=[self.some_attack_result],
|
||||||
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
accuracy_train=0.4,
|
||||||
|
accuracy_test=0.3,
|
||||||
|
epoch_num=10,
|
||||||
|
model_variant_label='default'))
|
||||||
|
|
||||||
|
self.results_epoch_15 = AttackResults(
|
||||||
|
single_attack_results=[self.some_attack_result],
|
||||||
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
accuracy_train=0.5,
|
||||||
|
accuracy_test=0.4,
|
||||||
|
epoch_num=15,
|
||||||
|
model_variant_label='default'))
|
||||||
|
|
||||||
|
self.attack_results_no_metadata = AttackResults(
|
||||||
|
single_attack_results=[self.some_attack_result])
|
||||||
|
|
||||||
|
self.collection_with_metadata = AttackResultsCollection(
|
||||||
|
[self.results_epoch_10, self.results_epoch_15])
|
||||||
|
|
||||||
|
self.collection_no_metadata = AttackResultsCollection(
|
||||||
|
[self.attack_results_no_metadata, self.attack_results_no_metadata])
|
||||||
|
|
||||||
|
def test_save_with_metadata(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
self.collection_with_metadata.save(tmpdirname)
|
||||||
|
loaded_collection = AttackResultsCollection.load(tmpdirname)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
repr(self.collection_with_metadata), repr(loaded_collection))
|
||||||
|
self.assertLen(loaded_collection.attack_results_list, 2)
|
||||||
|
|
||||||
|
def test_save_no_metadata(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
self.collection_no_metadata.save(tmpdirname)
|
||||||
|
loaded_collection = AttackResultsCollection.load(tmpdirname)
|
||||||
|
|
||||||
|
self.assertEqual(repr(self.collection_no_metadata), repr(loaded_collection))
|
||||||
|
self.assertLen(loaded_collection.attack_results_list, 2)
|
||||||
|
|
||||||
|
|
||||||
class AttackResultsTest(absltest.TestCase):
|
class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
perfect_classifier_result: SingleAttackResult
|
perfect_classifier_result: SingleAttackResult
|
||||||
|
|
|
@ -32,6 +32,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import membership_in
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
|
@ -119,7 +120,7 @@ def crossentropy(true_labels, predictions):
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
epoch_results = []
|
epoch_results = AttackResultsCollection([])
|
||||||
|
|
||||||
num_epochs = 2
|
num_epochs = 2
|
||||||
models = {
|
models = {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsDFColumns
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import ENTIRE_DATASET_SLICE_STR
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||||
|
@ -29,7 +30,7 @@ EPOCH_STR = 'Epoch'
|
||||||
TRAIN_ACCURACY_STR = 'Train accuracy'
|
TRAIN_ACCURACY_STR = 'Train accuracy'
|
||||||
|
|
||||||
|
|
||||||
def plot_by_epochs(results: Iterable[AttackResults],
|
def plot_by_epochs(results: AttackResultsCollection,
|
||||||
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||||
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
||||||
|
|
||||||
|
@ -43,8 +44,9 @@ def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
A pyplot figure with privacy vs accuracy plots.
|
A pyplot figure with privacy vs accuracy plots.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_validate_results(results)
|
_validate_results(results.attack_results_list)
|
||||||
all_results_df = _calculate_combined_df_with_metadata(results)
|
all_results_df = _calculate_combined_df_with_metadata(
|
||||||
|
results.attack_results_list)
|
||||||
return _generate_subplots(
|
return _generate_subplots(
|
||||||
all_results_df=all_results_df,
|
all_results_df=all_results_df,
|
||||||
x_axis_metric='Epoch',
|
x_axis_metric='Epoch',
|
||||||
|
@ -53,7 +55,7 @@ def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
|
|
||||||
|
|
||||||
def plot_privacy_vs_accuracy_single_model(
|
def plot_privacy_vs_accuracy_single_model(
|
||||||
results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]):
|
results: AttackResultsCollection, privacy_metrics: Iterable[PrivacyMetric]):
|
||||||
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
||||||
|
|
||||||
In case multiple privacy metrics are specified, the plot will feature
|
In case multiple privacy metrics are specified, the plot will feature
|
||||||
|
@ -66,8 +68,9 @@ def plot_privacy_vs_accuracy_single_model(
|
||||||
A pyplot figure with privacy vs accuracy plots.
|
A pyplot figure with privacy vs accuracy plots.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_validate_results(results)
|
_validate_results(results.attack_results_list)
|
||||||
all_results_df = _calculate_combined_df_with_metadata(results)
|
all_results_df = _calculate_combined_df_with_metadata(
|
||||||
|
results.attack_results_list)
|
||||||
return _generate_subplots(
|
return _generate_subplots(
|
||||||
all_results_df=all_results_df,
|
all_results_df=all_results_df,
|
||||||
x_axis_metric='Train accuracy',
|
x_axis_metric='Train accuracy',
|
||||||
|
|
|
@ -20,6 +20,7 @@ import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
|
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
PrivacyReportMetadata
|
PrivacyReportMetadata
|
||||||
|
@ -80,12 +81,14 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_plot_by_epochs_no_metadata(self):
|
def test_plot_by_epochs_no_metadata(self):
|
||||||
# Raise error if metadata is missing
|
# Raise error if metadata is missing
|
||||||
self.assertRaises(ValueError, privacy_report.plot_by_epochs,
|
self.assertRaises(
|
||||||
(self.attack_results_no_metadata,), ['AUC'])
|
ValueError, privacy_report.plot_by_epochs,
|
||||||
|
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||||
|
|
||||||
def test_single_metric_plot_by_epochs(self):
|
def test_single_metric_plot_by_epochs(self):
|
||||||
fig = privacy_report.plot_by_epochs(
|
fig = privacy_report.plot_by_epochs(
|
||||||
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
|
['AUC'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data = fig.gca().lines[0].get_data()
|
auc_data = fig.gca().lines[0].get_data()
|
||||||
# X axis lists epoch values
|
# X axis lists epoch values
|
||||||
|
@ -97,7 +100,7 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_multiple_metrics_plot_by_epochs(self):
|
def test_multiple_metrics_plot_by_epochs(self):
|
||||||
fig = privacy_report.plot_by_epochs(
|
fig = privacy_report.plot_by_epochs(
|
||||||
(self.results_epoch_10, self.results_epoch_15),
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
['AUC', 'Attacker advantage'])
|
['AUC', 'Attacker advantage'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data = fig.axes[0].lines[0].get_data()
|
auc_data = fig.axes[0].lines[0].get_data()
|
||||||
|
@ -113,8 +116,9 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_multiple_metrics_plot_by_epochs_multiple_models(self):
|
def test_multiple_metrics_plot_by_epochs_multiple_models(self):
|
||||||
fig = privacy_report.plot_by_epochs(
|
fig = privacy_report.plot_by_epochs(
|
||||||
(self.results_epoch_10, self.results_epoch_15,
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||||
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
|
self.results_epoch_15_model_2)),
|
||||||
|
['AUC', 'Attacker advantage'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
||||||
|
@ -136,13 +140,14 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
||||||
# Raise error if metadata is missing
|
# Raise error if metadata is missing
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(
|
||||||
privacy_report.plot_privacy_vs_accuracy_single_model,
|
ValueError, privacy_report.plot_privacy_vs_accuracy_single_model,
|
||||||
(self.attack_results_no_metadata,), ['AUC'])
|
AttackResultsCollection((self.attack_results_no_metadata,)), ['AUC'])
|
||||||
|
|
||||||
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
|
['AUC'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data = fig.gca().lines[0].get_data()
|
auc_data = fig.gca().lines[0].get_data()
|
||||||
# X axis lists epoch values
|
# X axis lists epoch values
|
||||||
|
@ -154,7 +159,7 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
(self.results_epoch_10, self.results_epoch_15),
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||||
['AUC', 'Attacker advantage'])
|
['AUC', 'Attacker advantage'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data = fig.axes[0].lines[0].get_data()
|
auc_data = fig.axes[0].lines[0].get_data()
|
||||||
|
@ -170,8 +175,9 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
|
||||||
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
(self.results_epoch_10, self.results_epoch_15,
|
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15,
|
||||||
self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
|
self.results_epoch_15_model_2)),
|
||||||
|
['AUC', 'Attacker advantage'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
auc_data_model_1 = fig.axes[0].lines[0].get_data()
|
||||||
auc_data_model_2 = fig.axes[0].lines[1].get_data()
|
auc_data_model_2 = fig.axes[0].lines[1].get_data()
|
||||||
|
|
Loading…
Reference in a new issue