Internal change.

PiperOrigin-RevId: 335385162
This commit is contained in:
David Marn 2020-10-05 03:54:01 -07:00 committed by A. Unique TensorFlower
parent 9a56402c0d
commit ab1090717c

View file

@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data.
import os
import tempfile
from absl import app
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
@ -117,21 +118,22 @@ def crossentropy(true_labels, predictions):
keras.backend.variable(predictions)))
epoch_results = []
def main(unused_argv):
epoch_results = []
num_epochs = 2
models = {
num_epochs = 2
models = {
"two layer model": two_layer_model,
"three layer model": three_layer_model,
}
for model_name in models:
}
for model_name in models:
# Incrementally train the model and store privacy metrics every num_epochs.
for i in range(1, 6):
models[model_name].fit(
training_features,
to_categorical(training_labels, num_clusters),
validation_data=(test_features, to_categorical(test_labels,
num_clusters)),
validation_data=(test_features,
to_categorical(test_labels, num_clusters)),
batch_size=64,
epochs=num_epochs,
shuffle=True)
@ -141,8 +143,8 @@ for model_name in models:
# Add metadata to generate a privacy report.
privacy_report_metadata = PrivacyReportMetadata(
accuracy_train=metrics.accuracy_score(training_labels,
np.argmax(training_pred, axis=1)),
accuracy_train=metrics.accuracy_score(
training_labels, np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels,
np.argmax(test_pred, axis=1)),
epoch_num=num_epochs * i,
@ -162,22 +164,23 @@ for model_name in models:
privacy_report_metadata=privacy_report_metadata)
epoch_results.append(attack_results)
# Generate privacy reports
epoch_figure = privacy_report.plot_by_epochs(
# Generate privacy reports
epoch_figure = privacy_report.plot_by_epochs(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
privacy_utility_figure.show()
privacy_utility_figure.show()
# Example of saving the results to the file and loading them back.
with tempfile.TemporaryDirectory() as tmpdirname:
# Example of saving the results to the file and loading them back.
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "results.pickle")
attack_results.save(filepath)
loaded_results = AttackResults.load(filepath)
print(loaded_results.summary(by_slices=False))
# Print attack metrics
for attack_result in attack_results.single_attack_results:
# Print attack metrics
for attack_result in attack_results.single_attack_results:
print("Slice: %s" % attack_result.slice_spec)
print("Attack type: %s" % attack_result.attack_type)
print("AUC: %.2f" % attack_result.roc_curve.get_auc())
@ -185,31 +188,36 @@ for attack_result in attack_results.single_attack_results:
print("Attacker advantage: %.2f\n" %
attack_result.roc_curve.get_attacker_advantage())
max_auc_attacker = attack_results.get_result_with_max_auc()
print("Attack type with max AUC: %s, AUC of %.2f" %
max_auc_attacker = attack_results.get_result_with_max_auc()
print("Attack type with max AUC: %s, AUC of %.2f" %
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage()
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage(
)
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
(max_advantage_attacker.attack_type,
max_advantage_attacker.roc_curve.get_attacker_advantage()))
# Print summary
print("Summary without slices: \n")
print(attack_results.summary(by_slices=False))
# Print summary
print("Summary without slices: \n")
print(attack_results.summary(by_slices=False))
print("Summary by slices: \n")
print(attack_results.summary(by_slices=True))
print("Summary by slices: \n")
print(attack_results.summary(by_slices=True))
# Print pandas data frame
print("Pandas frame: \n")
pd.set_option("display.max_rows", None, "display.max_columns", None)
print(attack_results.calculate_pd_dataframe())
# Print pandas data frame
print("Pandas frame: \n")
pd.set_option("display.max_rows", None, "display.max_columns", None)
print(attack_results.calculate_pd_dataframe())
# Example of ROC curve plotting.
figure = plotting.plot_roc_curve(
# Example of ROC curve plotting.
figure = plotting.plot_roc_curve(
attack_results.single_attack_results[0].roc_curve)
plt.show()
figure.show()
plt.show()
# For saving a figure into a file:
# plotting.save_plot(figure, <file_path>)
# For saving a figure into a file:
# plotting.save_plot(figure, <file_path>)
if __name__ == "__main__":
app.run(main)