Modifies Privacy Report metadata and adds an epoch chart.
PiperOrigin-RevId: 331326000
This commit is contained in:
parent
f44b63eb78
commit
fc38e3f733
5 changed files with 217 additions and 31 deletions
|
@ -434,6 +434,9 @@ class PrivacyReportMetadata:
|
||||||
loss_train: float = None
|
loss_train: float = None
|
||||||
loss_test: float = None
|
loss_test: float = None
|
||||||
|
|
||||||
|
model_variant_label: str = 'Default model variant'
|
||||||
|
epoch_num: int = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttackResults:
|
class AttackResults:
|
||||||
|
@ -466,8 +469,8 @@ class AttackResults:
|
||||||
'slice feature': slice_features,
|
'slice feature': slice_features,
|
||||||
'slice value': slice_values,
|
'slice value': slice_values,
|
||||||
'attack type': attack_types,
|
'attack type': attack_types,
|
||||||
'attack advantage': advantages,
|
'Attacker advantage': advantages,
|
||||||
'roc auc': aucs
|
'AUC': aucs
|
||||||
})
|
})
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
|
@ -305,8 +305,8 @@ class AttackResultsTest(absltest.TestCase):
|
||||||
'slice feature': ['correctly_classfied', 'entire_dataset'],
|
'slice feature': ['correctly_classfied', 'entire_dataset'],
|
||||||
'slice value': ['True', ''],
|
'slice value': ['True', ''],
|
||||||
'attack type': ['threshold', 'threshold'],
|
'attack type': ['threshold', 'threshold'],
|
||||||
'attack advantage': [1.0, 0.0],
|
'Attacker advantage': [1.0, 0.0],
|
||||||
'roc auc': [1.0, 0.5]
|
'AUC': [1.0, 0.5]
|
||||||
})
|
})
|
||||||
self.assertTrue(df.equals(df_expected))
|
self.assertTrue(df.equals(df_expected))
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
PrivacyReportMetadata
|
PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
|
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
|
||||||
|
import tensorflow_privacy.privacy.membership_inference_attack.privacy_report as privacy_report
|
||||||
|
|
||||||
|
|
||||||
def generate_random_cluster(center, scale, num_points):
|
def generate_random_cluster(center, scale, num_points):
|
||||||
|
@ -96,16 +97,6 @@ model = keras.models.Sequential([
|
||||||
])
|
])
|
||||||
model.compile(
|
model.compile(
|
||||||
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
|
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
|
||||||
model.fit(
|
|
||||||
training_features,
|
|
||||||
to_categorical(training_labels, num_clusters),
|
|
||||||
validation_data=(test_features, to_categorical(test_labels, num_clusters)),
|
|
||||||
batch_size=64,
|
|
||||||
epochs=2,
|
|
||||||
shuffle=True)
|
|
||||||
|
|
||||||
training_pred = model.predict(training_features)
|
|
||||||
test_pred = model.predict(test_features)
|
|
||||||
|
|
||||||
|
|
||||||
def crossentropy(true_labels, predictions):
|
def crossentropy(true_labels, predictions):
|
||||||
|
@ -115,14 +106,32 @@ def crossentropy(true_labels, predictions):
|
||||||
keras.backend.variable(predictions)))
|
keras.backend.variable(predictions)))
|
||||||
|
|
||||||
|
|
||||||
# Add metadata to generate a privacy report.
|
epoch_results = []
|
||||||
privacy_report_metadata = PrivacyReportMetadata(
|
|
||||||
|
# Incrementally train the model and store privacy risk metrics every 10 epochs.
|
||||||
|
for i in range(1, 6):
|
||||||
|
model.fit(
|
||||||
|
training_features,
|
||||||
|
to_categorical(training_labels, num_clusters),
|
||||||
|
validation_data=(test_features, to_categorical(test_labels,
|
||||||
|
num_clusters)),
|
||||||
|
batch_size=64,
|
||||||
|
epochs=2,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
training_pred = model.predict(training_features)
|
||||||
|
test_pred = model.predict(test_features)
|
||||||
|
|
||||||
|
# Add metadata to generate a privacy report.
|
||||||
|
privacy_report_metadata = PrivacyReportMetadata(
|
||||||
accuracy_train=metrics.accuracy_score(training_labels,
|
accuracy_train=metrics.accuracy_score(training_labels,
|
||||||
np.argmax(training_pred, axis=1)),
|
np.argmax(training_pred, axis=1)),
|
||||||
accuracy_test=metrics.accuracy_score(test_labels,
|
accuracy_test=metrics.accuracy_score(test_labels,
|
||||||
np.argmax(test_pred, axis=1)))
|
np.argmax(test_pred, axis=1)),
|
||||||
|
epoch_num=2 * i,
|
||||||
|
model_variant_label="default")
|
||||||
|
|
||||||
attack_results = mia.run_attacks(
|
attack_results = mia.run_attacks(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
labels_train=training_labels,
|
labels_train=training_labels,
|
||||||
labels_test=test_labels,
|
labels_test=test_labels,
|
||||||
|
@ -131,8 +140,15 @@ attack_results = mia.run_attacks(
|
||||||
loss_train=crossentropy(training_labels, training_pred),
|
loss_train=crossentropy(training_labels, training_pred),
|
||||||
loss_test=crossentropy(test_labels, test_pred)),
|
loss_test=crossentropy(test_labels, test_pred)),
|
||||||
SlicingSpec(entire_dataset=True, by_class=True),
|
SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_types=(AttackType.THRESHOLD_ATTACK, AttackType.LOGISTIC_REGRESSION),
|
attack_types=(AttackType.THRESHOLD_ATTACK,
|
||||||
privacy_report_metadata=None)
|
AttackType.LOGISTIC_REGRESSION),
|
||||||
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
|
epoch_results.append(attack_results)
|
||||||
|
|
||||||
|
# Generate privacy report
|
||||||
|
epoch_figure = privacy_report.plot_by_epochs(epoch_results,
|
||||||
|
["Attacker advantage", "AUC"])
|
||||||
|
epoch_figure.show()
|
||||||
|
|
||||||
# Example of saving the results to the file and loading them back.
|
# Example of saving the results to the file and loading them back.
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Plotting code for ML Privacy Reports."""
|
||||||
|
from typing import Iterable
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
|
||||||
|
|
||||||
|
def plot_by_epochs(results: Iterable[AttackResults],
|
||||||
|
privacy_metrics: Iterable[str]) -> plt.Figure:
|
||||||
|
"""Plots privacy vulnerabilities by epochs."""
|
||||||
|
_validate_results(results)
|
||||||
|
all_results_df = None
|
||||||
|
for attack_results in results:
|
||||||
|
attack_results_df = attack_results.calculate_pd_dataframe()
|
||||||
|
attack_results_df = attack_results_df.loc[attack_results_df['slice feature']
|
||||||
|
== 'entire_dataset']
|
||||||
|
attack_results_df.insert(0, 'Epoch',
|
||||||
|
attack_results.privacy_report_metadata.epoch_num)
|
||||||
|
if all_results_df is None:
|
||||||
|
all_results_df = attack_results_df
|
||||||
|
else:
|
||||||
|
all_results_df = pd.concat([all_results_df, attack_results_df],
|
||||||
|
ignore_index=True)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, len(privacy_metrics))
|
||||||
|
if len(privacy_metrics) == 1:
|
||||||
|
axes = (axes,)
|
||||||
|
for i, privacy_metric in enumerate(privacy_metrics):
|
||||||
|
attack_types = all_results_df['attack type'].unique()
|
||||||
|
for attack_type in attack_types:
|
||||||
|
axes[i].plot(
|
||||||
|
all_results_df.loc[all_results_df['attack type'] == attack_type]
|
||||||
|
['Epoch'], all_results_df.loc[all_results_df['attack type'] ==
|
||||||
|
attack_type][privacy_metric])
|
||||||
|
axes[i].legend(attack_types)
|
||||||
|
axes[i].set_xlabel('Epoch')
|
||||||
|
axes[i].set_title('%s for Entire dataset' % privacy_metric)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_results(results: Iterable[AttackResults]):
|
||||||
|
for attack_results in results:
|
||||||
|
if not attack_results or not attack_results.privacy_report_metadata:
|
||||||
|
raise ValueError('Privacy metadata is not defined.')
|
||||||
|
if not attack_results.privacy_report_metadata.epoch_num:
|
||||||
|
raise ValueError('epoch_num in metadata is not defined.')
|
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.privacy_report."""
|
||||||
|
from absl.testing import absltest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
|
PrivacyReportMetadata
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
|
|
||||||
|
|
||||||
|
class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(PrivacyReportTest, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Classifier that achieves an AUC of 0.5.
|
||||||
|
self.imperfect_classifier_result = SingleAttackResult(
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
|
roc_curve=RocCurve(
|
||||||
|
tpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
fpr=np.array([0.0, 0.5, 1.0]),
|
||||||
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
# Classifier that achieves an AUC of 1.0.
|
||||||
|
self.perfect_classifier_result = SingleAttackResult(
|
||||||
|
slice_spec=SingleSliceSpec(None),
|
||||||
|
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||||
|
roc_curve=RocCurve(
|
||||||
|
tpr=np.array([0.0, 1.0, 1.0]),
|
||||||
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
self.results_epoch_10 = AttackResults(
|
||||||
|
single_attack_results=[self.imperfect_classifier_result],
|
||||||
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
accuracy_train=0.4,
|
||||||
|
accuracy_test=0.3,
|
||||||
|
epoch_num=10,
|
||||||
|
model_variant_label='default'))
|
||||||
|
|
||||||
|
self.results_epoch_15 = AttackResults(
|
||||||
|
single_attack_results=[self.perfect_classifier_result],
|
||||||
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
accuracy_train=0.5,
|
||||||
|
accuracy_test=0.4,
|
||||||
|
epoch_num=15,
|
||||||
|
model_variant_label='default'))
|
||||||
|
|
||||||
|
self.attack_results_no_metadata = AttackResults(
|
||||||
|
single_attack_results=[self.perfect_classifier_result])
|
||||||
|
|
||||||
|
def test_plot_by_epochs_no_metadata(self):
|
||||||
|
# Raise error if metadata is missing
|
||||||
|
self.assertRaises(ValueError, privacy_report.plot_by_epochs,
|
||||||
|
(self.attack_results_no_metadata,), ['AUC'])
|
||||||
|
|
||||||
|
def test_single_metric_plot_by_epochs(self):
|
||||||
|
fig = privacy_report.plot_by_epochs(
|
||||||
|
(self.results_epoch_10, self.results_epoch_15), ['AUC'])
|
||||||
|
# extract data from figure.
|
||||||
|
auc_data = fig.gca().lines[0].get_data()
|
||||||
|
# X axis lists epoch values
|
||||||
|
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
||||||
|
# Y axis lists AUC values
|
||||||
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
|
|
||||||
|
def test_multiple_metrics_plot_by_epochs(self):
|
||||||
|
fig = privacy_report.plot_by_epochs(
|
||||||
|
(self.results_epoch_10, self.results_epoch_15),
|
||||||
|
['AUC', 'Attacker advantage'])
|
||||||
|
# extract data from figure.
|
||||||
|
auc_data = fig.axes[0].lines[0].get_data()
|
||||||
|
attacker_advantage_data = fig.axes[1].lines[0].get_data()
|
||||||
|
# X axis lists epoch values
|
||||||
|
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
||||||
|
np.testing.assert_array_equal(attacker_advantage_data[0], [10, 15])
|
||||||
|
# Y axis lists privacy metrics
|
||||||
|
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||||
|
np.testing.assert_array_equal(attacker_advantage_data[1], [0, 1.0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
Loading…
Reference in a new issue