Internal change.
PiperOrigin-RevId: 335385162
This commit is contained in:
parent
9a56402c0d
commit
ab1090717c
1 changed files with 89 additions and 81 deletions
|
@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data.
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
from absl import app
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -117,99 +118,106 @@ def crossentropy(true_labels, predictions):
|
||||||
keras.backend.variable(predictions)))
|
keras.backend.variable(predictions)))
|
||||||
|
|
||||||
|
|
||||||
epoch_results = []
|
def main(unused_argv):
|
||||||
|
epoch_results = []
|
||||||
|
|
||||||
num_epochs = 2
|
num_epochs = 2
|
||||||
models = {
|
models = {
|
||||||
"two layer model": two_layer_model,
|
"two layer model": two_layer_model,
|
||||||
"three layer model": three_layer_model,
|
"three layer model": three_layer_model,
|
||||||
}
|
}
|
||||||
for model_name in models:
|
for model_name in models:
|
||||||
# Incrementally train the model and store privacy metrics every num_epochs.
|
# Incrementally train the model and store privacy metrics every num_epochs.
|
||||||
for i in range(1, 6):
|
for i in range(1, 6):
|
||||||
models[model_name].fit(
|
models[model_name].fit(
|
||||||
training_features,
|
training_features,
|
||||||
to_categorical(training_labels, num_clusters),
|
to_categorical(training_labels, num_clusters),
|
||||||
validation_data=(test_features, to_categorical(test_labels,
|
validation_data=(test_features,
|
||||||
num_clusters)),
|
to_categorical(test_labels, num_clusters)),
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
epochs=num_epochs,
|
epochs=num_epochs,
|
||||||
shuffle=True)
|
shuffle=True)
|
||||||
|
|
||||||
training_pred = models[model_name].predict(training_features)
|
training_pred = models[model_name].predict(training_features)
|
||||||
test_pred = models[model_name].predict(test_features)
|
test_pred = models[model_name].predict(test_features)
|
||||||
|
|
||||||
# Add metadata to generate a privacy report.
|
# Add metadata to generate a privacy report.
|
||||||
privacy_report_metadata = PrivacyReportMetadata(
|
privacy_report_metadata = PrivacyReportMetadata(
|
||||||
accuracy_train=metrics.accuracy_score(training_labels,
|
accuracy_train=metrics.accuracy_score(
|
||||||
np.argmax(training_pred, axis=1)),
|
training_labels, np.argmax(training_pred, axis=1)),
|
||||||
accuracy_test=metrics.accuracy_score(test_labels,
|
accuracy_test=metrics.accuracy_score(test_labels,
|
||||||
np.argmax(test_pred, axis=1)),
|
np.argmax(test_pred, axis=1)),
|
||||||
epoch_num=num_epochs * i,
|
epoch_num=num_epochs * i,
|
||||||
model_variant_label=model_name)
|
model_variant_label=model_name)
|
||||||
|
|
||||||
attack_results = mia.run_attacks(
|
attack_results = mia.run_attacks(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
labels_train=training_labels,
|
labels_train=training_labels,
|
||||||
labels_test=test_labels,
|
labels_test=test_labels,
|
||||||
probs_train=training_pred,
|
probs_train=training_pred,
|
||||||
probs_test=test_pred,
|
probs_test=test_pred,
|
||||||
loss_train=crossentropy(training_labels, training_pred),
|
loss_train=crossentropy(training_labels, training_pred),
|
||||||
loss_test=crossentropy(test_labels, test_pred)),
|
loss_test=crossentropy(test_labels, test_pred)),
|
||||||
SlicingSpec(entire_dataset=True, by_class=True),
|
SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_types=(AttackType.THRESHOLD_ATTACK,
|
attack_types=(AttackType.THRESHOLD_ATTACK,
|
||||||
AttackType.LOGISTIC_REGRESSION),
|
AttackType.LOGISTIC_REGRESSION),
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
epoch_results.append(attack_results)
|
epoch_results.append(attack_results)
|
||||||
|
|
||||||
# Generate privacy reports
|
# Generate privacy reports
|
||||||
epoch_figure = privacy_report.plot_by_epochs(
|
epoch_figure = privacy_report.plot_by_epochs(
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
epoch_figure.show()
|
epoch_figure.show()
|
||||||
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
|
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy_single_model(
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
||||||
privacy_utility_figure.show()
|
privacy_utility_figure.show()
|
||||||
|
|
||||||
# Example of saving the results to the file and loading them back.
|
# Example of saving the results to the file and loading them back.
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
filepath = os.path.join(tmpdirname, "results.pickle")
|
filepath = os.path.join(tmpdirname, "results.pickle")
|
||||||
attack_results.save(filepath)
|
attack_results.save(filepath)
|
||||||
loaded_results = AttackResults.load(filepath)
|
loaded_results = AttackResults.load(filepath)
|
||||||
|
print(loaded_results.summary(by_slices=False))
|
||||||
|
|
||||||
# Print attack metrics
|
# Print attack metrics
|
||||||
for attack_result in attack_results.single_attack_results:
|
for attack_result in attack_results.single_attack_results:
|
||||||
print("Slice: %s" % attack_result.slice_spec)
|
print("Slice: %s" % attack_result.slice_spec)
|
||||||
print("Attack type: %s" % attack_result.attack_type)
|
print("Attack type: %s" % attack_result.attack_type)
|
||||||
print("AUC: %.2f" % attack_result.roc_curve.get_auc())
|
print("AUC: %.2f" % attack_result.roc_curve.get_auc())
|
||||||
|
|
||||||
print("Attacker advantage: %.2f\n" %
|
print("Attacker advantage: %.2f\n" %
|
||||||
attack_result.roc_curve.get_attacker_advantage())
|
attack_result.roc_curve.get_attacker_advantage())
|
||||||
|
|
||||||
max_auc_attacker = attack_results.get_result_with_max_auc()
|
max_auc_attacker = attack_results.get_result_with_max_auc()
|
||||||
print("Attack type with max AUC: %s, AUC of %.2f" %
|
print("Attack type with max AUC: %s, AUC of %.2f" %
|
||||||
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
|
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
|
||||||
|
|
||||||
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage()
|
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage(
|
||||||
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
|
)
|
||||||
(max_advantage_attacker.attack_type,
|
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
|
||||||
max_advantage_attacker.roc_curve.get_attacker_advantage()))
|
(max_advantage_attacker.attack_type,
|
||||||
|
max_advantage_attacker.roc_curve.get_attacker_advantage()))
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
print("Summary without slices: \n")
|
print("Summary without slices: \n")
|
||||||
print(attack_results.summary(by_slices=False))
|
print(attack_results.summary(by_slices=False))
|
||||||
|
|
||||||
print("Summary by slices: \n")
|
print("Summary by slices: \n")
|
||||||
print(attack_results.summary(by_slices=True))
|
print(attack_results.summary(by_slices=True))
|
||||||
|
|
||||||
# Print pandas data frame
|
# Print pandas data frame
|
||||||
print("Pandas frame: \n")
|
print("Pandas frame: \n")
|
||||||
pd.set_option("display.max_rows", None, "display.max_columns", None)
|
pd.set_option("display.max_rows", None, "display.max_columns", None)
|
||||||
print(attack_results.calculate_pd_dataframe())
|
print(attack_results.calculate_pd_dataframe())
|
||||||
|
|
||||||
# Example of ROC curve plotting.
|
# Example of ROC curve plotting.
|
||||||
figure = plotting.plot_roc_curve(
|
figure = plotting.plot_roc_curve(
|
||||||
attack_results.single_attack_results[0].roc_curve)
|
attack_results.single_attack_results[0].roc_curve)
|
||||||
plt.show()
|
figure.show()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
# For saving a figure into a file:
|
# For saving a figure into a file:
|
||||||
# plotting.save_plot(figure, <file_path>)
|
# plotting.save_plot(figure, <file_path>)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(main)
|
||||||
|
|
Loading…
Reference in a new issue