forked from 626_privacy/tensorflow_privacy
Internal change.
PiperOrigin-RevId: 335385162
This commit is contained in:
parent
9a56402c0d
commit
ab1090717c
1 changed files with 89 additions and 81 deletions
|
@ -20,6 +20,7 @@ This is using a toy model based on classifying four spacial clusters of data.
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -117,6 +118,7 @@ def crossentropy(true_labels, predictions):
|
|||
keras.backend.variable(predictions)))
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
epoch_results = []
|
||||
|
||||
num_epochs = 2
|
||||
|
@ -130,8 +132,8 @@ for model_name in models:
|
|||
models[model_name].fit(
|
||||
training_features,
|
||||
to_categorical(training_labels, num_clusters),
|
||||
validation_data=(test_features, to_categorical(test_labels,
|
||||
num_clusters)),
|
||||
validation_data=(test_features,
|
||||
to_categorical(test_labels, num_clusters)),
|
||||
batch_size=64,
|
||||
epochs=num_epochs,
|
||||
shuffle=True)
|
||||
|
@ -141,8 +143,8 @@ for model_name in models:
|
|||
|
||||
# Add metadata to generate a privacy report.
|
||||
privacy_report_metadata = PrivacyReportMetadata(
|
||||
accuracy_train=metrics.accuracy_score(training_labels,
|
||||
np.argmax(training_pred, axis=1)),
|
||||
accuracy_train=metrics.accuracy_score(
|
||||
training_labels, np.argmax(training_pred, axis=1)),
|
||||
accuracy_test=metrics.accuracy_score(test_labels,
|
||||
np.argmax(test_pred, axis=1)),
|
||||
epoch_num=num_epochs * i,
|
||||
|
@ -175,6 +177,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|||
filepath = os.path.join(tmpdirname, "results.pickle")
|
||||
attack_results.save(filepath)
|
||||
loaded_results = AttackResults.load(filepath)
|
||||
print(loaded_results.summary(by_slices=False))
|
||||
|
||||
# Print attack metrics
|
||||
for attack_result in attack_results.single_attack_results:
|
||||
|
@ -189,7 +192,8 @@ max_auc_attacker = attack_results.get_result_with_max_auc()
|
|||
print("Attack type with max AUC: %s, AUC of %.2f" %
|
||||
(max_auc_attacker.attack_type, max_auc_attacker.roc_curve.get_auc()))
|
||||
|
||||
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage()
|
||||
max_advantage_attacker = attack_results.get_result_with_max_attacker_advantage(
|
||||
)
|
||||
print("Attack type with max advantage: %s, Attacker advantage of %.2f" %
|
||||
(max_advantage_attacker.attack_type,
|
||||
max_advantage_attacker.roc_curve.get_attacker_advantage()))
|
||||
|
@ -209,7 +213,11 @@ print(attack_results.calculate_pd_dataframe())
|
|||
# Example of ROC curve plotting.
|
||||
figure = plotting.plot_roc_curve(
|
||||
attack_results.single_attack_results[0].roc_curve)
|
||||
figure.show()
|
||||
plt.show()
|
||||
|
||||
# For saving a figure into a file:
|
||||
# plotting.save_plot(figure, <file_path>)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
|
Loading…
Reference in a new issue