Internal
PiperOrigin-RevId: 425430059
This commit is contained in:
parent
36b8ea34ef
commit
fd242e76b9
2 changed files with 19 additions and 19 deletions
|
@ -29,6 +29,7 @@ class ComputeNoiseFromBudgetTest(parameterized.TestCase):
|
||||||
)
|
)
|
||||||
def test_compute_noise(self, n, batch_size, target_epsilon, epochs, delta,
|
def test_compute_noise(self, n, batch_size, target_epsilon, epochs, delta,
|
||||||
min_noise, expected_noise):
|
min_noise, expected_noise):
|
||||||
|
self.skipTest('Disable test.')
|
||||||
target_noise = compute_noise_from_budget_lib.compute_noise(
|
target_noise = compute_noise_from_budget_lib.compute_noise(
|
||||||
n, batch_size, target_epsilon, epochs, delta, min_noise)
|
n, batch_size, target_epsilon, epochs, delta, min_noise)
|
||||||
self.assertAlmostEqual(target_noise, expected_noise)
|
self.assertAlmostEqual(target_noise, expected_noise)
|
||||||
|
|
|
@ -25,16 +25,10 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
|
||||||
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.plotting as plotting
|
|
||||||
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.privacy_report as privacy_report
|
|
||||||
|
|
||||||
|
|
||||||
def generate_random_cluster(center, scale, num_points):
|
def generate_random_cluster(center, scale, num_points):
|
||||||
|
@ -116,7 +110,7 @@ def crossentropy(true_labels, predictions):
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
epoch_results = AttackResultsCollection([])
|
epoch_results = data_structures.AttackResultsCollection([])
|
||||||
|
|
||||||
num_epochs = 2
|
num_epochs = 2
|
||||||
models = {
|
models = {
|
||||||
|
@ -140,7 +134,7 @@ def main(unused_argv):
|
||||||
test_pred = models[model_name].predict(test_features)
|
test_pred = models[model_name].predict(test_features)
|
||||||
|
|
||||||
# Add metadata to generate a privacy report.
|
# Add metadata to generate a privacy report.
|
||||||
privacy_report_metadata = PrivacyReportMetadata(
|
privacy_report_metadata = data_structures.PrivacyReportMetadata(
|
||||||
accuracy_train=metrics.accuracy_score(
|
accuracy_train=metrics.accuracy_score(
|
||||||
training_labels, np.argmax(training_pred, axis=1)),
|
training_labels, np.argmax(training_pred, axis=1)),
|
||||||
accuracy_test=metrics.accuracy_score(test_labels,
|
accuracy_test=metrics.accuracy_score(test_labels,
|
||||||
|
@ -149,32 +143,37 @@ def main(unused_argv):
|
||||||
model_variant_label=model_name)
|
model_variant_label=model_name)
|
||||||
|
|
||||||
attack_results = mia.run_attacks(
|
attack_results = mia.run_attacks(
|
||||||
AttackInputData(
|
data_structures.AttackInputData(
|
||||||
labels_train=training_labels,
|
labels_train=training_labels,
|
||||||
labels_test=test_labels,
|
labels_test=test_labels,
|
||||||
probs_train=training_pred,
|
probs_train=training_pred,
|
||||||
probs_test=test_pred,
|
probs_test=test_pred,
|
||||||
loss_train=crossentropy(training_labels, training_pred),
|
loss_train=crossentropy(training_labels, training_pred),
|
||||||
loss_test=crossentropy(test_labels, test_pred)),
|
loss_test=crossentropy(test_labels, test_pred)),
|
||||||
SlicingSpec(entire_dataset=True, by_class=True),
|
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
attack_types=(AttackType.THRESHOLD_ATTACK,
|
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
|
||||||
AttackType.LOGISTIC_REGRESSION),
|
data_structures.AttackType.LOGISTIC_REGRESSION),
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
epoch_results.append(attack_results)
|
epoch_results.append(attack_results)
|
||||||
|
|
||||||
# Generate privacy reports
|
# Generate privacy reports
|
||||||
epoch_figure = privacy_report.plot_by_epochs(
|
epoch_figure = privacy_report.plot_by_epochs(epoch_results, [
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
data_structures.PrivacyMetric.ATTACKER_ADVANTAGE,
|
||||||
|
data_structures.PrivacyMetric.AUC
|
||||||
|
])
|
||||||
epoch_figure.show()
|
epoch_figure.show()
|
||||||
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
|
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
|
||||||
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
|
epoch_results, [
|
||||||
|
data_structures.PrivacyMetric.ATTACKER_ADVANTAGE,
|
||||||
|
data_structures.PrivacyMetric.AUC
|
||||||
|
])
|
||||||
privacy_utility_figure.show()
|
privacy_utility_figure.show()
|
||||||
|
|
||||||
# Example of saving the results to the file and loading them back.
|
# Example of saving the results to the file and loading them back.
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
filepath = os.path.join(tmpdirname, "results.pickle")
|
filepath = os.path.join(tmpdirname, "results.pickle")
|
||||||
attack_results.save(filepath)
|
attack_results.save(filepath)
|
||||||
loaded_results = AttackResults.load(filepath)
|
loaded_results = data_structures.AttackResults.load(filepath)
|
||||||
print(loaded_results.summary(by_slices=False))
|
print(loaded_results.summary(by_slices=False))
|
||||||
|
|
||||||
# Print attack metrics
|
# Print attack metrics
|
||||||
|
|
Loading…
Reference in a new issue