From fc38e3f7337a64bf56d92d968441a886aced0113 Mon Sep 17 00:00:00 2001 From: David Marn Date: Sat, 12 Sep 2020 09:11:27 -0700 Subject: [PATCH] Modifies Privacy Report metadata and adds an epoch chart. PiperOrigin-RevId: 331326000 --- .../data_structures.py | 7 +- .../data_structures_test.py | 4 +- .../membership_inference_attack/example.py | 70 +++++++----- .../privacy_report.py | 63 +++++++++++ .../privacy_report_test.py | 104 ++++++++++++++++++ 5 files changed, 217 insertions(+), 31 deletions(-) create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 0da4372..23274ab 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -434,6 +434,9 @@ class PrivacyReportMetadata: loss_train: float = None loss_test: float = None + model_variant_label: str = 'Default model variant' + epoch_num: int = None + @dataclass class AttackResults: @@ -466,8 +469,8 @@ class AttackResults: 'slice feature': slice_features, 'slice value': slice_values, 'attack type': attack_types, - 'attack advantage': advantages, - 'roc auc': aucs + 'Attacker advantage': advantages, + 'AUC': aucs }) return df diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 6a3be8e..55442ed 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -305,8 +305,8 @@ class AttackResultsTest(absltest.TestCase): 'slice feature': ['correctly_classfied', 'entire_dataset'], 'slice value': ['True', ''], 'attack type': ['threshold', 'threshold'], - 'attack advantage': [1.0, 0.0], - 'roc auc': [1.0, 0.5] + 'Attacker advantage': [1.0, 0.0], + 'AUC': [1.0, 0.5] }) self.assertTrue(df.equals(df_expected)) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index 5a9ff4a..5d4b6b5 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -35,6 +35,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting +import tensorflow_privacy.privacy.membership_inference_attack.privacy_report as privacy_report def generate_random_cluster(center, scale, num_points): @@ -96,16 +97,6 @@ model = keras.models.Sequential([ ]) model.compile( optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]) -model.fit( - training_features, - to_categorical(training_labels, num_clusters), - validation_data=(test_features, to_categorical(test_labels, num_clusters)), - batch_size=64, - epochs=2, - shuffle=True) - -training_pred = model.predict(training_features) -test_pred = model.predict(test_features) def crossentropy(true_labels, predictions): @@ -115,24 +106,49 @@ def crossentropy(true_labels, predictions): keras.backend.variable(predictions))) -# Add metadata to generate a privacy report. -privacy_report_metadata = PrivacyReportMetadata( - accuracy_train=metrics.accuracy_score(training_labels, - np.argmax(training_pred, axis=1)), - accuracy_test=metrics.accuracy_score(test_labels, - np.argmax(test_pred, axis=1))) +epoch_results = [] -attack_results = mia.run_attacks( - AttackInputData( - labels_train=training_labels, - labels_test=test_labels, - probs_train=training_pred, - probs_test=test_pred, - loss_train=crossentropy(training_labels, training_pred), - loss_test=crossentropy(test_labels, test_pred)), - SlicingSpec(entire_dataset=True, by_class=True), - attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION), - privacy_report_metadata=None) +# Incrementally train the model and store privacy risk metrics every 10 epochs. +for i in range(1, 6): + model.fit( + training_features, + to_categorical(training_labels, num_clusters), + validation_data=(test_features, to_categorical(test_labels, + num_clusters)), + batch_size=64, + epochs=2, + shuffle=True) + + training_pred = model.predict(training_features) + test_pred = model.predict(test_features) + + # Add metadata to generate a privacy report. + privacy_report_metadata = PrivacyReportMetadata( + accuracy_train=metrics.accuracy_score(training_labels, + np.argmax(training_pred, axis=1)), + accuracy_test=metrics.accuracy_score(test_labels, + np.argmax(test_pred, axis=1)), + epoch_num=2 * i, + model_variant_label="default") + + attack_results = mia.run_attacks( + AttackInputData( + labels_train=training_labels, + labels_test=test_labels, + probs_train=training_pred, + probs_test=test_pred, + loss_train=crossentropy(training_labels, training_pred), + loss_test=crossentropy(test_labels, test_pred)), + SlicingSpec(entire_dataset=True, by_class=True), + attack_types=(AttackType.THRESHOLD_ATTACK, + AttackType.LOGISTIC_REGRESSION), + privacy_report_metadata=privacy_report_metadata) + epoch_results.append(attack_results) + +# Generate privacy report +epoch_figure = privacy_report.plot_by_epochs(epoch_results, + ["Attacker advantage", "AUC"]) +epoch_figure.show() # Example of saving the results to the file and loading them back. with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py new file mode 100644 index 0000000..0a0b48d --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py @@ -0,0 +1,63 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Plotting code for ML Privacy Reports.""" +from typing import Iterable +import matplotlib.pyplot as plt +import pandas as pd + +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults + + +def plot_by_epochs(results: Iterable[AttackResults], + privacy_metrics: Iterable[str]) -> plt.Figure: + """Plots privacy vulnerabilities by epochs.""" + _validate_results(results) + all_results_df = None + for attack_results in results: + attack_results_df = attack_results.calculate_pd_dataframe() + attack_results_df = attack_results_df.loc[attack_results_df['slice feature'] + == 'entire_dataset'] + attack_results_df.insert(0, 'Epoch', + attack_results.privacy_report_metadata.epoch_num) + if all_results_df is None: + all_results_df = attack_results_df + else: + all_results_df = pd.concat([all_results_df, attack_results_df], + ignore_index=True) + + fig, axes = plt.subplots(1, len(privacy_metrics)) + if len(privacy_metrics) == 1: + axes = (axes,) + for i, privacy_metric in enumerate(privacy_metrics): + attack_types = all_results_df['attack type'].unique() + for attack_type in attack_types: + axes[i].plot( + all_results_df.loc[all_results_df['attack type'] == attack_type] + ['Epoch'], all_results_df.loc[all_results_df['attack type'] == + attack_type][privacy_metric]) + axes[i].legend(attack_types) + axes[i].set_xlabel('Epoch') + axes[i].set_title('%s for Entire dataset' % privacy_metric) + + return fig + + +def _validate_results(results: Iterable[AttackResults]): + for attack_results in results: + if not attack_results or not attack_results.privacy_report_metadata: + raise ValueError('Privacy metadata is not defined.') + if not attack_results.privacy_report_metadata.epoch_num: + raise ValueError('epoch_num in metadata is not defined.') diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py new file mode 100644 index 0000000..f6a193e --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py @@ -0,0 +1,104 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.privacy_report.""" +from absl.testing import absltest +import numpy as np + +from tensorflow_privacy.privacy.membership_inference_attack import privacy_report + +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ + PrivacyReportMetadata +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec + + +class PrivacyReportTest(absltest.TestCase): + + def __init__(self, *args, **kwargs): + super(PrivacyReportTest, self).__init__(*args, **kwargs) + + # Classifier that achieves an AUC of 0.5. + self.imperfect_classifier_result = SingleAttackResult( + slice_spec=SingleSliceSpec(None), + attack_type=AttackType.THRESHOLD_ATTACK, + roc_curve=RocCurve( + tpr=np.array([0.0, 0.5, 1.0]), + fpr=np.array([0.0, 0.5, 1.0]), + thresholds=np.array([0, 1, 2]))) + + # Classifier that achieves an AUC of 1.0. + self.perfect_classifier_result = SingleAttackResult( + slice_spec=SingleSliceSpec(None), + attack_type=AttackType.THRESHOLD_ATTACK, + roc_curve=RocCurve( + tpr=np.array([0.0, 1.0, 1.0]), + fpr=np.array([1.0, 1.0, 0.0]), + thresholds=np.array([0, 1, 2]))) + + self.results_epoch_10 = AttackResults( + single_attack_results=[self.imperfect_classifier_result], + privacy_report_metadata=PrivacyReportMetadata( + accuracy_train=0.4, + accuracy_test=0.3, + epoch_num=10, + model_variant_label='default')) + + self.results_epoch_15 = AttackResults( + single_attack_results=[self.perfect_classifier_result], + privacy_report_metadata=PrivacyReportMetadata( + accuracy_train=0.5, + accuracy_test=0.4, + epoch_num=15, + model_variant_label='default')) + + self.attack_results_no_metadata = AttackResults( + single_attack_results=[self.perfect_classifier_result]) + + def test_plot_by_epochs_no_metadata(self): + # Raise error if metadata is missing + self.assertRaises(ValueError, privacy_report.plot_by_epochs, + (self.attack_results_no_metadata,), ['AUC']) + + def test_single_metric_plot_by_epochs(self): + fig = privacy_report.plot_by_epochs( + (self.results_epoch_10, self.results_epoch_15), ['AUC']) + # extract data from figure. + auc_data = fig.gca().lines[0].get_data() + # X axis lists epoch values + np.testing.assert_array_equal(auc_data[0], [10, 15]) + # Y axis lists AUC values + np.testing.assert_array_equal(auc_data[1], [0.5, 1.0]) + + def test_multiple_metrics_plot_by_epochs(self): + fig = privacy_report.plot_by_epochs( + (self.results_epoch_10, self.results_epoch_15), + ['AUC', 'Attacker advantage']) + # extract data from figure. + auc_data = fig.axes[0].lines[0].get_data() + attacker_advantage_data = fig.axes[1].lines[0].get_data() + # X axis lists epoch values + np.testing.assert_array_equal(auc_data[0], [10, 15]) + np.testing.assert_array_equal(attacker_advantage_data[0], [10, 15]) + # Y axis lists privacy metrics + np.testing.assert_array_equal(auc_data[1], [0.5, 1.0]) + np.testing.assert_array_equal(attacker_advantage_data[1], [0, 1.0]) + + +if __name__ == '__main__': + absltest.main()