Internal change.

PiperOrigin-RevId: 335385162
This commit is contained in:
David Marn 2020-10-05 03:54:01 -07:00 committed by A. Unique TensorFlower
parent 9a56402c0d
commit ab1090717c

View file

@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data.
import os import os
import tempfile import tempfile
from absl import app
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -117,99 +118,106 @@ def crossentropy(true_labels, predictions):
keras.backend.variable(predictions))) keras.backend.variable(predictions)))
epoch_results = [] def main(unused_argv):
epoch_results = []
num_epochs = 2 num_epochs = 2
models = { models = {
"two layer model": two_layer_model, "two layer model": two_layer_model,
"three layer model": three_layer_model, "three layer model": three_layer_model,
} }
for model_name in models: for model_name in models:
# Incrementally train the model and store privacy metrics every num_epochs. # Incrementally train the model and store privacy metrics every num_epochs.
for i in range(1, 6): for i in range(1, 6):
models[model_name].fit( models[model_name].fit(
training_features, training_features,
to_categorical(training_labels, num_clusters), to_categorical(training_labels, num_clusters),
validation_data=(test_features, to_categorical(test_labels, validation_data=(test_features,
num_clusters)), to_categorical(test_labels, num_clusters)),
batch_size=64, batch_size=64,
epochs=num_epochs, epochs=num_epochs,
shuffle=True) shuffle=True)
training_pred = models[model_name].predict(training_features) training_pred = models[model_name].predict(training_features)
test_pred = models[model_name].predict(test_features) test_pred = models[model_name].predict(test_features)
# Add metadata to generate a privacy report. # Add metadata to generate a privacy report.
privacy_report_metadata = PrivacyReportMetadata( privacy_report_metadata = PrivacyReportMetadata(
accuracy_train=metrics.accuracy_score(training_labels, accuracy_train=metrics.accuracy_score(
np.argmax(training_pred, axis=1)), training_labels, np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels, accuracy_test=metrics.accuracy_score(test_labels,
np.argmax(test_pred, axis=1)), np.argmax(test_pred, axis=1)),
epoch_num=num_epochs * i, epoch_num=num_epochs * i,
model_variant_label=model_name) model_variant_label=model_name)
attack_results = mia.run_attacks( attack_results = mia.run_attacks(
AttackInputData( AttackInputData(
labels_train=training_labels, labels_train=training_labels,
labels_test=test_labels, labels_test=test_labels,
probs_train=training_pred, probs_train=training_pred,
probs_test=test_pred, probs_test=test_pred,
loss_train=crossentropy(training_labels, training_pred), loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)), loss_test=crossentropy(test_labels, test_pred)),
SlicingSpec(entire_dataset=True, by_class=True), SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(AttackType.THRESHOLD_ATTACK, attack_types=(AttackType.THRESHOLD_ATTACK,
AttackType.LOGISTIC_REGRESSION), AttackType.LOGISTIC_REGRESSION),
privacy_report_metadata=privacy_report_metadata) privacy_report_metadata=privacy_report_metadata)
epoch_results.append(attack_results) epoch_results.append(attack_results)
# Generate privacy reports # Generate privacy reports
epoch_figure = privacy_report.plot_by_epochs( epoch_figure = privacy_report.plot_by_epochs(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_figure.show() epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model( privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC]) epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
privacy_utility_figure.show() privacy_utility_figure.show()
# Example of saving the results to the file and loading them back. # Example of saving the results to the file and loading them back.
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "results.pickle") filepath = os.path.join(tmpdirname, "results.pickle")
attack_results.save(filepath) attack_results.save(filepath)
loaded_results = AttackResults.load(filepath) loaded_results = AttackResults.load(filepath)
print(loaded_results.summary(by_slices=False))
# Print attack metrics # Print attack metrics
for attack_result in attack_results.single_attack_results: for attack_result in attack_results.single_attack_results:
print("Slice: %s" % attack_result.slice_spec) print("Slice: %s" % attack_result.slice_spec)
print("Attack type: %s" % attack_result.attack_type) print("Attack type: %s" % attack_result.attack_type)
print("AUC: %.2f" % attack_result.roc_curve.get_auc()) print("AUC: %.2f" % attack_result.roc_curve.get_auc())
print("Attacker advantage: %.2f\n" % print("Attacker advantage: %.2f\n" %
attack_result.roc_curve.get_attacker_advantage()) attack_result.roc_curve.get_attacker_advantage())
max_auc_attacker = attack_results.get_result_with_max_auc() max_auc_attacker = attack_results.get_result_with_max_auc()
print("Attack type with max AUC: %s, AUC of %.2f" % print("Attack type with max AUC: %s, AUC of %.2f" %
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc())) (max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage() max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage(
print("Attack type with max advantage: %s, Attacker advantage of %.2f" % )
(max_advantage_attacker.attack_type, print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
max_advantage_attacker.roc_curve.get_attacker_advantage())) (max_advantage_attacker.attack_type,
max_advantage_attacker.roc_curve.get_attacker_advantage()))
# Print summary # Print summary
print("Summary without slices: \n") print("Summary without slices: \n")
print(attack_results.summary(by_slices=False)) print(attack_results.summary(by_slices=False))
print("Summary by slices: \n") print("Summary by slices: \n")
print(attack_results.summary(by_slices=True)) print(attack_results.summary(by_slices=True))
# Print pandas data frame # Print pandas data frame
print("Pandas frame: \n") print("Pandas frame: \n")
pd.set_option("display.max_rows", None, "display.max_columns", None) pd.set_option("display.max_rows", None, "display.max_columns", None)
print(attack_results.calculate_pd_dataframe()) print(attack_results.calculate_pd_dataframe())
# Example of ROC curve plotting. # Example of ROC curve plotting.
figure = plotting.plot_roc_curve( figure = plotting.plot_roc_curve(
attack_results.single_attack_results[0].roc_curve) attack_results.single_attack_results[0].roc_curve)
plt.show() figure.show()
plt.show()
# For saving a figure into a file: # For saving a figure into a file:
# plotting.save_plot(figure, <file_path>) # plotting.save_plot(figure, <file_path>)
if __name__ == "__main__":
app.run(main)