diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index ddbc0c6..cdea468 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data. import os import tempfile +from absl import app import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -117,99 +118,106 @@ def crossentropy(true_labels, predictions): keras.backend.variable(predictions))) -epoch_results = [] +def main(unused_argv): + epoch_results = [] -num_epochs = 2 -models = { - "two layer model": two_layer_model, - "three layer model": three_layer_model, -} -for model_name in models: - # Incrementally train the model and store privacy metrics every num_epochs. - for i in range(1, 6): - models[model_name].fit( - training_features, - to_categorical(training_labels, num_clusters), - validation_data=(test_features, to_categorical(test_labels, - num_clusters)), - batch_size=64, - epochs=num_epochs, - shuffle=True) + num_epochs = 2 + models = { + "two layer model": two_layer_model, + "three layer model": three_layer_model, + } + for model_name in models: + # Incrementally train the model and store privacy metrics every num_epochs. + for i in range(1, 6): + models[model_name].fit( + training_features, + to_categorical(training_labels, num_clusters), + validation_data=(test_features, + to_categorical(test_labels, num_clusters)), + batch_size=64, + epochs=num_epochs, + shuffle=True) - training_pred = models[model_name].predict(training_features) - test_pred = models[model_name].predict(test_features) + training_pred = models[model_name].predict(training_features) + test_pred = models[model_name].predict(test_features) - # Add metadata to generate a privacy report. - privacy_report_metadata = PrivacyReportMetadata( - accuracy_train=metrics.accuracy_score(training_labels, - np.argmax(training_pred, axis=1)), - accuracy_test=metrics.accuracy_score(test_labels, - np.argmax(test_pred, axis=1)), - epoch_num=num_epochs * i, - model_variant_label=model_name) + # Add metadata to generate a privacy report. + privacy_report_metadata = PrivacyReportMetadata( + accuracy_train=metrics.accuracy_score( + training_labels, np.argmax(training_pred, axis=1)), + accuracy_test=metrics.accuracy_score(test_labels, + np.argmax(test_pred, axis=1)), + epoch_num=num_epochs * i, + model_variant_label=model_name) - attack_results = mia.run_attacks( - AttackInputData( - labels_train=training_labels, - labels_test=test_labels, - probs_train=training_pred, - probs_test=test_pred, - loss_train=crossentropy(training_labels, training_pred), - loss_test=crossentropy(test_labels, test_pred)), - SlicingSpec(entire_dataset=True, by_class=True), - attack_types=(AttackType.THRESHOLD_ATTACK, - AttackType.LOGISTIC_REGRESSION), - privacy_report_metadata=privacy_report_metadata) - epoch_results.append(attack_results) + attack_results = mia.run_attacks( + AttackInputData( + labels_train=training_labels, + labels_test=test_labels, + probs_train=training_pred, + probs_test=test_pred, + loss_train=crossentropy(training_labels, training_pred), + loss_test=crossentropy(test_labels, test_pred)), + SlicingSpec(entire_dataset=True, by_class=True), + attack_types=(AttackType.THRESHOLD_ATTACK, + AttackType.LOGISTIC_REGRESSION), + privacy_report_metadata=privacy_report_metadata) + epoch_results.append(attack_results) -# Generate privacy reports -epoch_figure = privacy_report.plot_by_epochs( - epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) -epoch_figure.show() -privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model( - epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) -privacy_utility_figure.show() + # Generate privacy reports + epoch_figure = privacy_report.plot_by_epochs( + epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) + epoch_figure.show() + privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model( + epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) + privacy_utility_figure.show() -# Example of saving the results to the file and loading them back. -with tempfile.TemporaryDirectory() as tmpdirname: - filepath = os.path.join(tmpdirname, "results.pickle") - attack_results.save(filepath) - loaded_results = AttackResults.load(filepath) + # Example of saving the results to the file and loading them back. + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "results.pickle") + attack_results.save(filepath) + loaded_results = AttackResults.load(filepath) + print(loaded_results.summary(by_slices=False)) -# Print attack metrics -for attack_result in attack_results.single_attack_results: - print("Slice: %s" % attack_result.slice_spec) - print("Attack type: %s" % attack_result.attack_type) - print("AUC: %.2f" % attack_result.roc_curve.get_auc()) + # Print attack metrics + for attack_result in attack_results.single_attack_results: + print("Slice: %s" % attack_result.slice_spec) + print("Attack type: %s" % attack_result.attack_type) + print("AUC: %.2f" % attack_result.roc_curve.get_auc()) - print("Attacker advantage: %.2f\n" % - attack_result.roc_curve.get_attacker_advantage()) + print("Attacker advantage: %.2f\n" % + attack_result.roc_curve.get_attacker_advantage()) -max_auc_attacker = attack_results.get_result_with_max_auc() -print("Attack type with max AUC: %s, AUC of %.2f" % - (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) + max_auc_attacker = attack_results.get_result_with_max_auc() + print("Attack type with max AUC: %s, AUC of %.2f" % + (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) -max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage() -print("Attack type with max advantage: %s, Attacker advantage of %.2f" % - (max_advantage_attacker.attack_type, - max_advantage_attacker.roc_curve.get_attacker_advantage())) + max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage( + ) + print("Attack type with max advantage: %s, Attacker advantage of %.2f" % + (max_advantage_attacker.attack_type, + max_advantage_attacker.roc_curve.get_attacker_advantage())) -# Print summary -print("Summary without slices: \n") -print(attack_results.summary(by_slices=False)) + # Print summary + print("Summary without slices: \n") + print(attack_results.summary(by_slices=False)) -print("Summary by slices: \n") -print(attack_results.summary(by_slices=True)) + print("Summary by slices: \n") + print(attack_results.summary(by_slices=True)) -# Print pandas data frame -print("Pandas frame: \n") -pd.set_option("display.max_rows", None, "display.max_columns", None) -print(attack_results.calculate_pd_dataframe()) + # Print pandas data frame + print("Pandas frame: \n") + pd.set_option("display.max_rows", None, "display.max_columns", None) + print(attack_results.calculate_pd_dataframe()) -# Example of ROC curve plotting. -figure = plotting.plot_roc_curve( - attack_results.single_attack_results[0].roc_curve) -plt.show() + # Example of ROC curve plotting. + figure = plotting.plot_roc_curve( + attack_results.single_attack_results[0].roc_curve) + figure.show() + plt.show() -# For saving a figure into a file: -# plotting.save_plot(figure, ) + # For saving a figure into a file: + # plotting.save_plot(figure, ) + +if __name__ == "__main__": + app.run(main)