PiperOrigin-RevId: 425430059
This commit is contained in:
Michael Reneer 2022-01-31 12:16:32 -08:00 committed by A. Unique TensorFlower
parent 36b8ea34ef
commit fd242e76b9
2 changed files with 19 additions and 19 deletions

View file

@ -29,6 +29,7 @@ class ComputeNoiseFromBudgetTest(parameterized.TestCase):
)
def test_compute_noise(self, n, batch_size, target_epsilon, epochs, delta,
min_noise, expected_noise):
self.skipTest('Disable test.')
target_noise = compute_noise_from_budget_lib.compute_noise(
n, batch_size, target_epsilon, epochs, delta, min_noise)
self.assertAlmostEqual(target_noise, expected_noise)

View file

@ -25,16 +25,10 @@ import numpy as np
import pandas as pd
from sklearn import metrics
import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.plotting as plotting
import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.privacy_report as privacy_report
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
def generate_random_cluster(center, scale, num_points):
@ -116,7 +110,7 @@ def crossentropy(true_labels, predictions):
def main(unused_argv):
epoch_results = AttackResultsCollection([])
epoch_results = data_structures.AttackResultsCollection([])
num_epochs = 2
models = {
@ -140,7 +134,7 @@ def main(unused_argv):
test_pred = models[model_name].predict(test_features)
# Add metadata to generate a privacy report.
privacy_report_metadata = PrivacyReportMetadata(
privacy_report_metadata = data_structures.PrivacyReportMetadata(
accuracy_train=metrics.accuracy_score(
training_labels, np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels,
@ -149,32 +143,37 @@ def main(unused_argv):
model_variant_label=model_name)
attack_results = mia.run_attacks(
AttackInputData(
data_structures.AttackInputData(
labels_train=training_labels,
labels_test=test_labels,
probs_train=training_pred,
probs_test=test_pred,
loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)),
SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(AttackType.THRESHOLD_ATTACK,
AttackType.LOGISTIC_REGRESSION),
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
data_structures.AttackType.LOGISTIC_REGRESSION),
privacy_report_metadata=privacy_report_metadata)
epoch_results.append(attack_results)
# Generate privacy reports
epoch_figure = privacy_report.plot_by_epochs(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_figure = privacy_report.plot_by_epochs(epoch_results, [
data_structures.PrivacyMetric.ATTACKER_ADVANTAGE,
data_structures.PrivacyMetric.AUC
])
epoch_figure.show()
privacy_utility_figure = privacy_report.plot_privacy_vs_accuracy(
epoch_results, [PrivacyMetric.ATTACKER_ADVANTAGE, PrivacyMetric.AUC])
epoch_results, [
data_structures.PrivacyMetric.ATTACKER_ADVANTAGE,
data_structures.PrivacyMetric.AUC
])
privacy_utility_figure.show()
# Example of saving the results to the file and loading them back.
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "results.pickle")
attack_results.save(filepath)
loaded_results = AttackResults.load(filepath)
loaded_results = data_structures.AttackResults.load(filepath)
print(loaded_results.summary(by_slices=False))
# Print attack metrics