Adds Privacy vs Utility charts to the Privacy Report for a single model.
PiperOrigin-RevId: 331720083
This commit is contained in:
parent
fc38e3f733
commit
70f9585a24
4 changed files with 126 additions and 16 deletions
|
@ -123,6 +123,16 @@ class AttackType(enum.Enum):
|
||||||
return '%s' % self.name
|
return '%s' % self.name
|
||||||
|
|
||||||
|
|
||||||
|
class PrivacyMetric(enum.Enum):
|
||||||
|
"""An enum for the supported privacy risk metrics."""
|
||||||
|
AUC = 'AUC'
|
||||||
|
ATTACKER_ADVANTAGE = 'Attacker advantage'
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Returns 'AUC' instead of PrivacyMetric.AUC."""
|
||||||
|
return '%s' % self.value
|
||||||
|
|
||||||
|
|
||||||
def _is_integer_type_array(a):
|
def _is_integer_type_array(a):
|
||||||
return np.issubdtype(a.dtype, np.integer)
|
return np.issubdtype(a.dtype, np.integer)
|
||||||
|
|
||||||
|
@ -469,8 +479,8 @@ class AttackResults:
|
||||||
'slice feature': slice_features,
|
'slice feature': slice_features,
|
||||||
'slice value': slice_values,
|
'slice value': slice_values,
|
||||||
'attack type': attack_types,
|
'attack type': attack_types,
|
||||||
'Attacker advantage': advantages,
|
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
|
||||||
'AUC': aucs
|
str(PrivacyMetric.AUC): aucs
|
||||||
})
|
})
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,11 @@ from tensorflow import keras
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
from tensorflow.keras.utils import to_categorical
|
from tensorflow.keras.utils import to_categorical
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
PrivacyReportMetadata
|
PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
@ -109,6 +111,7 @@ def crossentropy(true_labels, predictions):
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
|
|
||||||
# Incrementally train the model and store privacy risk metrics every 10 epochs.
|
# Incrementally train the model and store privacy risk metrics every 10 epochs.
|
||||||
|
num_epochs = 2
|
||||||
for i in range(1, 6):
|
for i in range(1, 6):
|
||||||
model.fit(
|
model.fit(
|
||||||
training_features,
|
training_features,
|
||||||
|
@ -116,7 +119,7 @@ for i in range(1, 6):
|
||||||
validation_data=(test_features, to_categorical(test_labels,
|
validation_data=(test_features, to_categorical(test_labels,
|
||||||
num_clusters)),
|
num_clusters)),
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
epochs=2,
|
epochs=num_epochs,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
|
|
||||||
training_pred = model.predict(training_features)
|
training_pred = model.predict(training_features)
|
||||||
|
@ -128,7 +131,7 @@ for i in range(1, 6):
|
||||||
np.argmax(training_pred, axis=1)),
|
np.argmax(training_pred, axis=1)),
|
||||||
accuracy_test=metrics.accuracy_score(test_labels,
|
accuracy_test=metrics.accuracy_score(test_labels,
|
||||||
np.argmax(test_pred, axis=1)),
|
np.argmax(test_pred, axis=1)),
|
||||||
epoch_num=2 * i,
|
epoch_num=num_epochs * i,
|
||||||
model_variant_label="default")
|
model_variant_label="default")
|
||||||
|
|
||||||
attack_results = mia.run_attacks(
|
attack_results = mia.run_attacks(
|
||||||
|
@ -145,10 +148,13 @@ for i in range(1, 6):
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
epoch_results.append(attack_results)
|
epoch_results.append(attack_results)
|
||||||
|
|
||||||
# Generate privacy report
|
# Generate privacy reports
|
||||||
epoch_figure = privacy_report.plot_by_epochs(epoch_results,
|
epoch_figure = privacy_report.plot_by_epochs(
|
||||||
["Attacker advantage", "AUC"])
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
epoch_figure.show()
|
epoch_figure.show()
|
||||||
|
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
|
privacy_utility_figure.show()
|
||||||
|
|
||||||
# Example of saving the results to the file and loading them back.
|
# Example of saving the results to the file and loading them back.
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
|
|
@ -19,12 +19,57 @@ import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric
|
||||||
|
|
||||||
|
|
||||||
def plot_by_epochs(results: Iterable[AttackResults],
|
def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
privacy_metrics: Iterable[str]) -> plt.Figure:
|
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
|
||||||
"""Plots privacy vulnerabilities by epochs."""
|
"""Plots privacy vulnerabilities vs epoch numbers for a single model variant.
|
||||||
|
|
||||||
|
In case multiple privacy metrics are specified, the plot will feature
|
||||||
|
multiple subplots (one subplot per metrics).
|
||||||
|
Args:
|
||||||
|
results: AttackResults for the plot
|
||||||
|
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pyplot figure with privacy vs accuracy plots.
|
||||||
|
"""
|
||||||
|
|
||||||
_validate_results(results)
|
_validate_results(results)
|
||||||
|
all_results_df = _calculate_combined_df_with_metadata(results)
|
||||||
|
return _generate_subplots(
|
||||||
|
all_results_df=all_results_df,
|
||||||
|
x_axis_metric='Epoch',
|
||||||
|
figure_title='Vulnerability per Epoch',
|
||||||
|
privacy_metrics=privacy_metrics)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_privacy_vs_accuracy_single_model(
|
||||||
|
results: Iterable[AttackResults], privacy_metrics: Iterable[PrivacyMetric]):
|
||||||
|
"""Plots privacy vulnerabilities vs accuracy plots for a single model variant.
|
||||||
|
|
||||||
|
In case multiple privacy metrics are specified, the plot will feature
|
||||||
|
multiple subplots (one subplot per metrics).
|
||||||
|
Args:
|
||||||
|
results: AttackResults for the plot
|
||||||
|
privacy_metrics: List of enumerated privacy metrics that should be plotted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pyplot figure with privacy vs accuracy plots.
|
||||||
|
|
||||||
|
"""
|
||||||
|
_validate_results(results)
|
||||||
|
all_results_df = _calculate_combined_df_with_metadata(results)
|
||||||
|
return _generate_subplots(
|
||||||
|
all_results_df=all_results_df,
|
||||||
|
x_axis_metric='Train accuracy',
|
||||||
|
figure_title='Privacy vs Utility Analysis',
|
||||||
|
privacy_metrics=privacy_metrics)
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
|
||||||
|
"""Adds metadata to the dataframe and concats them together."""
|
||||||
all_results_df = None
|
all_results_df = None
|
||||||
for attack_results in results:
|
for attack_results in results:
|
||||||
attack_results_df = attack_results.calculate_pd_dataframe()
|
attack_results_df = attack_results.calculate_pd_dataframe()
|
||||||
|
@ -32,25 +77,36 @@ def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
== 'entire_dataset']
|
== 'entire_dataset']
|
||||||
attack_results_df.insert(0, 'Epoch',
|
attack_results_df.insert(0, 'Epoch',
|
||||||
attack_results.privacy_report_metadata.epoch_num)
|
attack_results.privacy_report_metadata.epoch_num)
|
||||||
|
attack_results_df.insert(
|
||||||
|
0, 'Train accuracy',
|
||||||
|
attack_results.privacy_report_metadata.accuracy_train)
|
||||||
if all_results_df is None:
|
if all_results_df is None:
|
||||||
all_results_df = attack_results_df
|
all_results_df = attack_results_df
|
||||||
else:
|
else:
|
||||||
all_results_df = pd.concat([all_results_df, attack_results_df],
|
all_results_df = pd.concat([all_results_df, attack_results_df],
|
||||||
ignore_index=True)
|
ignore_index=True)
|
||||||
|
return all_results_df
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
|
||||||
|
figure_title: str,
|
||||||
|
privacy_metrics: Iterable[PrivacyMetric]):
|
||||||
|
"""Create one subplot per privacy metric for a specified x_axis_metric."""
|
||||||
fig, axes = plt.subplots(1, len(privacy_metrics))
|
fig, axes = plt.subplots(1, len(privacy_metrics))
|
||||||
|
# Set a title for the entire group of subplots.
|
||||||
|
fig.suptitle(figure_title)
|
||||||
if len(privacy_metrics) == 1:
|
if len(privacy_metrics) == 1:
|
||||||
axes = (axes,)
|
axes = (axes,)
|
||||||
for i, privacy_metric in enumerate(privacy_metrics):
|
for i, privacy_metric in enumerate(privacy_metrics):
|
||||||
attack_types = all_results_df['attack type'].unique()
|
attack_types = all_results_df['attack type'].unique()
|
||||||
for attack_type in attack_types:
|
for attack_type in attack_types:
|
||||||
axes[i].plot(
|
attack_type_results = all_results_df.loc[all_results_df['attack type'] ==
|
||||||
all_results_df.loc[all_results_df['attack type'] == attack_type]
|
attack_type]
|
||||||
['Epoch'], all_results_df.loc[all_results_df['attack type'] ==
|
axes[i].plot(attack_type_results[x_axis_metric],
|
||||||
attack_type][privacy_metric])
|
attack_type_results[str(privacy_metric)])
|
||||||
axes[i].legend(attack_types)
|
axes[i].legend(attack_types)
|
||||||
axes[i].set_xlabel('Epoch')
|
axes[i].set_xlabel(x_axis_metric)
|
||||||
axes[i].set_title('%s for Entire dataset' % privacy_metric)
|
axes[i].set_title('%s for Entire dataset' % str(privacy_metric))
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,8 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
||||||
# Y axis lists AUC values
|
# Y axis lists AUC values
|
||||||
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
|
# Check the title
|
||||||
|
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
|
||||||
|
|
||||||
def test_multiple_metrics_plot_by_epochs(self):
|
def test_multiple_metrics_plot_by_epochs(self):
|
||||||
fig = privacy_report.plot_by_epochs(
|
fig = privacy_report.plot_by_epochs(
|
||||||
|
@ -98,6 +100,42 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
# Y axis lists privacy metrics
|
# Y axis lists privacy metrics
|
||||||
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
np.testing.assert_array_equal(attacker_advantage_data[1], [0, 1.0])
|
np.testing.assert_array_equal(attacker_advantage_data[1], [0, 1.0])
|
||||||
|
# Check the title
|
||||||
|
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
|
||||||
|
|
||||||
|
def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
|
||||||
|
# Raise error if metadata is missing
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
privacy_report.plot_privacy_vs_accuracy_single_model,
|
||||||
|
(self.attack_results_no_metadata,), ['AUC'])
|
||||||
|
|
||||||
|
def test_single_metric_plot_privacy_vs_accuracy_single_model(self):
|
||||||
|
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
|
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
||||||
|
# extract data from figure.
|
||||||
|
auc_data = fig.gca().lines[0].get_data()
|
||||||
|
# X axis lists epoch values
|
||||||
|
np.testing.assert_array_equal(auc_data[0], [0.4, 0.5])
|
||||||
|
# Y axis lists AUC values
|
||||||
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
|
# Check the title
|
||||||
|
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||||
|
|
||||||
|
def test_multiple_metrics_plot_privacy_vs_accuracy_single_model(self):
|
||||||
|
fig = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
|
(self.results_epoch_10, self.results_epoch_15),
|
||||||
|
['AUC', 'Attacker advantage'])
|
||||||
|
# extract data from figure.
|
||||||
|
auc_data = fig.axes[0].lines[0].get_data()
|
||||||
|
attacker_advantage_data = fig.axes[1].lines[0].get_data()
|
||||||
|
# X axis lists epoch values
|
||||||
|
np.testing.assert_array_equal(auc_data[0], [0.4, 0.5])
|
||||||
|
np.testing.assert_array_equal(attacker_advantage_data[0], [0.4, 0.5])
|
||||||
|
# Y axis lists privacy metrics
|
||||||
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
|
np.testing.assert_array_equal(attacker_advantage_data[1], [0, 1.0])
|
||||||
|
# Check the title
|
||||||
|
self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue