Adds plots for multiple model labels to the ML Privacy Report.

PiperOrigin-RevId: 334179759

parent 837e014107
commit c30c3fcb7a
3 changed files with 116 additions and 44 deletions
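For orientation, here is a minimal end-to-end sketch of the workflow this commit enables: tag each AttackResults with a model_variant_label, collect results for several model variants, and pass them all to privacy_report.plot_by_epochs, which now draws one line per (model variant, attack type) pair. The import paths, the synthetic predictions, and the placeholder accuracies below are assumptions for illustration, not part of this commit.

import numpy as np

# Import paths as used around this release; these are an assumption here and
# may differ in other versions of TensorFlow Privacy.
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack import privacy_report
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec

results = []
for model_name in ("two layer model", "three layer model"):
  for epoch in (10, 20):
    # Synthetic softmax outputs and labels stand in for real model predictions.
    rng = np.random.RandomState(epoch)
    probs_train = rng.dirichlet(np.ones(5), size=100)
    probs_test = rng.dirichlet(np.ones(5), size=100)
    labels_train = rng.randint(5, size=100)
    labels_test = rng.randint(5, size=100)
    metadata = PrivacyReportMetadata(
        accuracy_train=0.9,  # placeholder accuracies for the sketch
        accuracy_test=0.8,
        epoch_num=epoch,
        model_variant_label=model_name)
    results.append(
        mia.run_attacks(
            AttackInputData(
                labels_train=labels_train,
                labels_test=labels_test,
                probs_train=probs_train,
                probs_test=probs_test,
                loss_train=-np.log(probs_train[np.arange(100), labels_train]),
                loss_test=-np.log(probs_test[np.arange(100), labels_test])),
            SlicingSpec(entire_dataset=True),
            attack_types=(AttackType.THRESHOLD_ATTACK,),
            privacy_report_metadata=metadata))

# One subplot per requested metric, one line per "model variant - attack type".
fig = privacy_report.plot_by_epochs(results, ["AUC", "Attacker advantage"])
fig.savefig("epoch_report.png")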
@@ -90,14 +90,23 @@ num_clusters = int(round(np.max(training_labels))) + 1
 # Hint: play with the number of layers to achieve different level of
 # over-fitting and observe its effects on membership inference performance.
-model = keras.models.Sequential([
+three_layer_model = keras.models.Sequential([
     layers.Dense(300, activation="relu"),
     layers.Dense(300, activation="relu"),
     layers.Dense(300, activation="relu"),
     layers.Dense(num_clusters, activation="relu"),
     layers.Softmax()
 ])
-model.compile(
+three_layer_model.compile(
     optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
+
+two_layer_model = keras.models.Sequential([
+    layers.Dense(300, activation="relu"),
+    layers.Dense(300, activation="relu"),
+    layers.Dense(num_clusters, activation="relu"),
+    layers.Softmax()
+])
+two_layer_model.compile(
+    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
@@ -110,43 +119,48 @@ def crossentropy(true_labels, predictions):
 epoch_results = []

-# Incrementally train the model and store privacy risk metrics every 10 epochs.
 num_epochs = 2
-for i in range(1, 6):
-  model.fit(
-      training_features,
-      to_categorical(training_labels, num_clusters),
-      validation_data=(test_features, to_categorical(test_labels,
-                                                     num_clusters)),
-      batch_size=64,
-      epochs=num_epochs,
-      shuffle=True)
+models = {
+    "two layer model": two_layer_model,
+    "three layer model": three_layer_model,
+}
+for model_name in models:
+  # Incrementally train the model and store privacy metrics every num_epochs.
+  for i in range(1, 6):
+    models[model_name].fit(
+        training_features,
+        to_categorical(training_labels, num_clusters),
+        validation_data=(test_features, to_categorical(test_labels,
+                                                       num_clusters)),
+        batch_size=64,
+        epochs=num_epochs,
+        shuffle=True)

-  training_pred = model.predict(training_features)
-  test_pred = model.predict(test_features)
+    training_pred = models[model_name].predict(training_features)
+    test_pred = models[model_name].predict(test_features)

-  # Add metadata to generate a privacy report.
-  privacy_report_metadata = PrivacyReportMetadata(
-      accuracy_train=metrics.accuracy_score(training_labels,
-                                            np.argmax(training_pred, axis=1)),
-      accuracy_test=metrics.accuracy_score(test_labels,
-                                           np.argmax(test_pred, axis=1)),
-      epoch_num=num_epochs * i,
-      model_variant_label="default")
+    # Add metadata to generate a privacy report.
+    privacy_report_metadata = PrivacyReportMetadata(
+        accuracy_train=metrics.accuracy_score(training_labels,
+                                              np.argmax(training_pred, axis=1)),
+        accuracy_test=metrics.accuracy_score(test_labels,
+                                             np.argmax(test_pred, axis=1)),
+        epoch_num=num_epochs * i,
+        model_variant_label=model_name)

-  attack_results = mia.run_attacks(
-      AttackInputData(
-          labels_train=training_labels,
-          labels_test=test_labels,
-          probs_train=training_pred,
-          probs_test=test_pred,
-          loss_train=crossentropy(training_labels, training_pred),
-          loss_test=crossentropy(test_labels, test_pred)),
-      SlicingSpec(entire_dataset=True, by_class=True),
-      attack_types=(AttackType.THRESHOLD_ATTACK,
-                    AttackType.LOGISTIC_REGRESSION),
-      privacy_report_metadata=privacy_report_metadata)
-  epoch_results.append(attack_results)
+    attack_results = mia.run_attacks(
+        AttackInputData(
+            labels_train=training_labels,
+            labels_test=test_labels,
+            probs_train=training_pred,
+            probs_test=test_pred,
+            loss_train=crossentropy(training_labels, training_pred),
+            loss_test=crossentropy(test_labels, test_pred)),
+        SlicingSpec(entire_dataset=True, by_class=True),
+        attack_types=(AttackType.THRESHOLD_ATTACK,
+                      AttackType.LOGISTIC_REGRESSION),
+        privacy_report_metadata=privacy_report_metadata)
+    epoch_results.append(attack_results)

 # Generate privacy reports
 epoch_figure = privacy_report.plot_by_epochs(
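The codelab hunk above feeds per-example losses into AttackInputData through a crossentropy helper that the hunk header names but the diff does not show. A minimal NumPy stand-in is sketched below, assuming predictions are softmax probabilities and true_labels are integer class labels; the codelab's own implementation may differ.

import numpy as np

def crossentropy(true_labels, predictions):
  # Per-example categorical cross-entropy: -log of the probability the model
  # assigns to the true class. Clipping avoids log(0).
  probs_of_true = predictions[np.arange(len(true_labels)), true_labels]
  return -np.log(np.clip(probs_of_true, 1e-12, 1.0))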
@@ -80,6 +80,10 @@ def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
     attack_results_df.insert(
         0, 'Train accuracy',
         attack_results.privacy_report_metadata.accuracy_train)
+    attack_results_df.insert(
+        0, 'legend label',
+        attack_results.privacy_report_metadata.model_variant_label + ' - ' +
+        attack_results_df['attack type'])
     if all_results_df is None:
       all_results_df = attack_results_df
     else:
@@ -98,15 +102,15 @@ def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
   if len(privacy_metrics) == 1:
     axes = (axes,)
   for i, privacy_metric in enumerate(privacy_metrics):
-    attack_types = all_results_df['attack type'].unique()
-    for attack_type in attack_types:
-      attack_type_results = all_results_df.loc[all_results_df['attack type'] ==
-                                               attack_type]
-      axes[i].plot(attack_type_results[x_axis_metric],
-                   attack_type_results[str(privacy_metric)])
-    axes[i].legend(attack_types)
-    axes[i].set_xlabel(x_axis_metric)
-    axes[i].set_title('%s for Entire dataset' % str(privacy_metric))
+    legend_labels = all_results_df['legend label'].unique()
+    for legend_label in legend_labels:
+      single_label_results = all_results_df.loc[all_results_df['legend label']
+                                                == legend_label]
+      axes[i].plot(single_label_results[x_axis_metric],
+                   single_label_results[str(privacy_metric)])
+    axes[i].legend(legend_labels)
+    axes[i].set_xlabel(x_axis_metric)
+    axes[i].set_title('%s for Entire dataset' % str(privacy_metric))

   return fig
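To see what the new 'legend label' column buys: _calculate_combined_df_with_metadata now writes one label per (model variant, attack type) pair, and _generate_subplots draws one line per unique label instead of one per attack type. A small self-contained pandas sketch of that grouping follows; the frame and its values are made up for illustration.

import pandas as pd

# Hypothetical combined results frame, mirroring the columns the report builds.
df = pd.DataFrame({
    'legend label': ['two layer model - THRESHOLD_ATTACK',
                     'three layer model - THRESHOLD_ATTACK',
                     'two layer model - THRESHOLD_ATTACK'],
    'Epoch': [10, 10, 20],
    'AUC': [0.55, 0.60, 0.58],
})

# One curve per legend label, analogous to the plotting loop above.
for legend_label in df['legend label'].unique():
  single_label_results = df.loc[df['legend label'] == legend_label]
  print(legend_label, list(single_label_results['AUC']))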
@@ -67,6 +67,14 @@ class PrivacyReportTest(absltest.TestCase):
             epoch_num=15,
             model_variant_label='default'))

+    self.results_epoch_15_model_2 = AttackResults(
+        single_attack_results=[self.perfect_classifier_result],
+        privacy_report_metadata=PrivacyReportMetadata(
+            accuracy_train=0.6,
+            accuracy_test=0.7,
+            epoch_num=15,
+            model_variant_label='model 2'))
+
     self.attack_results_no_metadata = AttackResults(
         single_attack_results=[self.perfect_classifier_result])
@@ -103,6 +111,29 @@ class PrivacyReportTest(absltest.TestCase):
     # Check the title
     self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')

+  def test_multiple_metrics_plot_by_epochs_multiple_models(self):
+    fig = privacy_report.plot_by_epochs(
+        (self.results_epoch_10, self.results_epoch_15,
+         self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
+    # extract data from figure.
+    auc_data_model_1 = fig.axes[0].lines[0].get_data()
+    auc_data_model_2 = fig.axes[0].lines[1].get_data()
+    attacker_advantage_data_model_1 = fig.axes[1].lines[0].get_data()
+    attacker_advantage_data_model_2 = fig.axes[1].lines[1].get_data()
+    # X axis lists epoch values
+    np.testing.assert_array_equal(auc_data_model_1[0], [10, 15])
+    np.testing.assert_array_equal(auc_data_model_2[0], [15])
+    np.testing.assert_array_equal(attacker_advantage_data_model_1[0], [10, 15])
+    np.testing.assert_array_equal(attacker_advantage_data_model_2[0], [15])
+    # Y axis lists privacy metrics
+    np.testing.assert_array_equal(auc_data_model_1[1], [0.5, 1.0])
+    np.testing.assert_array_equal(auc_data_model_2[1], [1.0])
+    np.testing.assert_array_equal(attacker_advantage_data_model_1[1], [0, 1.0])
+    np.testing.assert_array_equal(attacker_advantage_data_model_2[1], [1.0])
+    # Check the title
+    self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
+
   def test_plot_privacy_vs_accuracy_single_model_no_metadata(self):
     # Raise error if metadata is missing
     self.assertRaises(ValueError,
@@ -137,6 +168,29 @@ class PrivacyReportTest(absltest.TestCase):
     # Check the title
     self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')

+  def test_multiple_metrics_plot_privacy_vs_accuracy_multiple_model(self):
+    fig = privacy_report.plot_privacy_vs_accuracy_single_model(
+        (self.results_epoch_10, self.results_epoch_15,
+         self.results_epoch_15_model_2), ['AUC', 'Attacker advantage'])
+    # extract data from figure.
+    auc_data_model_1 = fig.axes[0].lines[0].get_data()
+    auc_data_model_2 = fig.axes[0].lines[1].get_data()
+    attacker_advantage_data_model_1 = fig.axes[1].lines[0].get_data()
+    attacker_advantage_data_model_2 = fig.axes[1].lines[1].get_data()
+    # X axis lists train accuracy values
+    np.testing.assert_array_equal(auc_data_model_1[0], [0.4, 0.5])
+    np.testing.assert_array_equal(auc_data_model_2[0], [0.6])
+    np.testing.assert_array_equal(attacker_advantage_data_model_1[0],
+                                  [0.4, 0.5])
+    np.testing.assert_array_equal(attacker_advantage_data_model_2[0], [0.6])
+    # Y axis lists privacy metrics
+    np.testing.assert_array_equal(auc_data_model_1[1], [0.5, 1.0])
+    np.testing.assert_array_equal(auc_data_model_2[1], [1.0])
+    np.testing.assert_array_equal(attacker_advantage_data_model_1[1], [0, 1.0])
+    np.testing.assert_array_equal(attacker_advantage_data_model_2[1], [1.0])
+    # Check the title
+    self.assertEqual(fig._suptitle.get_text(), 'Privacy vs Utility Analysis')
+

 if __name__ == '__main__':
   absltest.main()