Internal change.

PiperOrigin-RevId: 335385162
This commit is contained in:
David Marn 2020-10-05 03:54:01 -07:00 committed by A. Unique TensorFlower
parent 9a56402c0d
commit ab1090717c

View file

@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data.
import os import os
import tempfile import tempfile
from absl import app
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -117,6 +118,7 @@ def crossentropy(true_labels, predictions):
keras.backend.variable(predictions))) keras.backend.variable(predictions)))
def main(unused_argv):
epoch_results = [] epoch_results = []
num_epochs = 2 num_epochs = 2
@ -130,8 +132,8 @@ for model_name in models:
models[model_name].fit( models[model_name].fit(
training_features, training_features,
to_categorical(training_labels, num_clusters), to_categorical(training_labels, num_clusters),
validation_data=(test_features, to_categorical(test_labels, validation_data=(test_features,
num_clusters)), to_categorical(test_labels, num_clusters)),
batch_size=64, batch_size=64,
epochs=num_epochs, epochs=num_epochs,
shuffle=True) shuffle=True)
@ -141,8 +143,8 @@ for model_name in models:
# Add metadata to generate a privacy report. # Add metadata to generate a privacy report.
privacy_report_metadata = PrivacyReportMetadata( privacy_report_metadata = PrivacyReportMetadata(
accuracy_train=metrics.accuracy_score(training_labels, accuracy_train=metrics.accuracy_score(
np.argmax(training_pred, axis=1)), training_labels, np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels, accuracy_test=metrics.accuracy_score(test_labels,
np.argmax(test_pred, axis=1)), np.argmax(test_pred, axis=1)),
epoch_num=num_epochs * i, epoch_num=num_epochs * i,
@ -175,6 +177,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "results.pickle") filepath = os.path.join(tmpdirname, "results.pickle")
attack_results.save(filepath) attack_results.save(filepath)
loaded_results = AttackResults.load(filepath) loaded_results = AttackResults.load(filepath)
print(loaded_results.summary(by_slices=False))
# Print attack metrics # Print attack metrics
for attack_result in attack_results.single_attack_results: for attack_result in attack_results.single_attack_results:
@ -189,7 +192,8 @@ max_auc_attacker = attack_results.get_result_with_max_auc()
print("Attack type with max AUC: %s, AUC of %.2f" % print("Attack type with max AUC: %s, AUC of %.2f" %
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage() max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage(
)
print("Attack type with max advantage: %s, Attacker advantage of %.2f" % print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
(max_advantage_attacker.attack_type, (max_advantage_attacker.attack_type,
max_advantage_attacker.roc_curve.get_attacker_advantage())) max_advantage_attacker.roc_curve.get_attacker_advantage()))
@ -209,7 +213,11 @@ print(attack_results.calculate_pd_dataframe())
# Example of ROC curve plotting. # Example of ROC curve plotting.
figure = plotting.plot_roc_curve( figure = plotting.plot_roc_curve(
attack_results.single_attack_results[0].roc_curve) attack_results.single_attack_results[0].roc_curve)
figure.show()
plt.show() plt.show()
# For saving a figure into a file: # For saving a figure into a file:
# plotting.save_plot(figure, <file_path>) # plotting.save_plot(figure, <file_path>)
if __name__ == "__main__":
app.run(main)