forked from 626_privacy/tensorflow_privacy
Trained attackers no longer fail when labels are missing.
PiperOrigin-RevId: 368598111
This commit is contained in:
parent
edd9c44269
commit
ca347b8995
3 changed files with 43 additions and 11 deletions
|
@ -249,8 +249,14 @@ class AttackInputData:
|
||||||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||||
|
|
||||||
def get_loss_train(self):
|
def get_loss_train(self):
|
||||||
"""Calculates (if needed) cross-entropy losses for the training set."""
|
"""Calculates (if needed) cross-entropy losses for the training set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss (or None if neither the loss nor the labels are present).
|
||||||
|
"""
|
||||||
if self.loss_train is None:
|
if self.loss_train is None:
|
||||||
|
if self.labels_train is None:
|
||||||
|
return None
|
||||||
if self.logits_train is not None:
|
if self.logits_train is not None:
|
||||||
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
||||||
self.logits_train)
|
self.logits_train)
|
||||||
|
@ -259,9 +265,15 @@ class AttackInputData:
|
||||||
return self.loss_train
|
return self.loss_train
|
||||||
|
|
||||||
def get_loss_test(self):
|
def get_loss_test(self):
|
||||||
"""Calculates (if needed) cross-entropy losses for the test set."""
|
"""Calculates (if needed) cross-entropy losses for the test set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss (or None if neither the loss nor the labels are present).
|
||||||
|
"""
|
||||||
if self.loss_test is None:
|
if self.loss_test is None:
|
||||||
if self.logits_train is not None:
|
if self.labels_test is None:
|
||||||
|
return None
|
||||||
|
if self.logits_test is not None:
|
||||||
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
||||||
self.logits_test)
|
self.logits_test)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -29,8 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
PrivacyReportMetadata
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
|
||||||
|
@ -93,11 +92,15 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
|
|
||||||
|
|
||||||
def _run_threshold_attack(attack_input: AttackInputData):
|
def _run_threshold_attack(attack_input: AttackInputData):
|
||||||
|
"""Runs a threshold attack on loss."""
|
||||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||||
|
loss_train = attack_input.get_loss_train()
|
||||||
|
loss_test = attack_input.get_loss_test()
|
||||||
|
if loss_train is None or loss_test is None:
|
||||||
|
raise ValueError('Not possible to run threshold attack without losses.')
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(
|
fpr, tpr, thresholds = metrics.roc_curve(
|
||||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||||
np.concatenate(
|
np.concatenate((loss_train, loss_test)))
|
||||||
(attack_input.get_loss_train(), attack_input.get_loss_test())))
|
|
||||||
|
|
||||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||||
|
|
||||||
|
@ -313,10 +316,12 @@ def _compute_missing_privacy_report_metadata(
|
||||||
if metadata.accuracy_test is None:
|
if metadata.accuracy_test is None:
|
||||||
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
||||||
attack_input.labels_test)
|
attack_input.labels_test)
|
||||||
if metadata.loss_train is None:
|
loss_train = attack_input.get_loss_train()
|
||||||
metadata.loss_train = np.average(attack_input.get_loss_train())
|
loss_test = attack_input.get_loss_test()
|
||||||
if metadata.loss_test is None:
|
if metadata.loss_train is None and loss_train is not None:
|
||||||
metadata.loss_test = np.average(attack_input.get_loss_test())
|
metadata.loss_train = np.average(loss_train)
|
||||||
|
if metadata.loss_test is None and loss_test is not None:
|
||||||
|
metadata.loss_test = np.average(loss_test)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,14 @@ def get_test_input(n_train, n_test):
|
||||||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_input_logits_only(n_train, n_test):
|
||||||
|
"""Get example input logits for attacks."""
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
return AttackInputData(
|
||||||
|
logits_train=rng.randn(n_train, 5) + 0.2,
|
||||||
|
logits_test=rng.randn(n_test, 5) + 0.2)
|
||||||
|
|
||||||
|
|
||||||
class RunAttacksTest(absltest.TestCase):
|
class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_run_attacks_size(self):
|
def test_run_attacks_size(self):
|
||||||
|
@ -45,6 +53,13 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
self.assertLen(result.single_attack_results, 2)
|
self.assertLen(result.single_attack_results, 2)
|
||||||
|
|
||||||
|
def test_trained_attacks_logits_only_size(self):
|
||||||
|
result = mia.run_attacks(
|
||||||
|
get_test_input_logits_only(100, 100), SlicingSpec(),
|
||||||
|
(AttackType.LOGISTIC_REGRESSION,))
|
||||||
|
|
||||||
|
self.assertLen(result.single_attack_results, 1)
|
||||||
|
|
||||||
def test_run_attack_trained_sets_attack_type(self):
|
def test_run_attack_trained_sets_attack_type(self):
|
||||||
result = mia._run_attack(
|
result = mia._run_attack(
|
||||||
get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
||||||
|
|
Loading…
Reference in a new issue