forked from 626_privacy/tensorflow_privacy
Trained attackers no longer fail when labels are missing.
PiperOrigin-RevId: 368598111
This commit is contained in:
parent
edd9c44269
commit
ca347b8995
3 changed files with 43 additions and 11 deletions
|
@ -249,8 +249,14 @@ class AttackInputData:
|
|||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the training set."""
|
||||
"""Calculates (if needed) cross-entropy losses for the training set.
|
||||
|
||||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
"""
|
||||
if self.loss_train is None:
|
||||
if self.labels_train is None:
|
||||
return None
|
||||
if self.logits_train is not None:
|
||||
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
||||
self.logits_train)
|
||||
|
@ -259,9 +265,15 @@ class AttackInputData:
|
|||
return self.loss_train
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the test set."""
|
||||
"""Calculates (if needed) cross-entropy losses for the test set.
|
||||
|
||||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
"""
|
||||
if self.loss_test is None:
|
||||
if self.logits_train is not None:
|
||||
if self.labels_test is None:
|
||||
return None
|
||||
if self.logits_test is not None:
|
||||
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
||||
self.logits_test)
|
||||
else:
|
||||
|
|
|
@ -29,8 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
|
||||
|
@ -93,11 +92,15 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
|
||||
|
||||
def _run_threshold_attack(attack_input: AttackInputData):
|
||||
"""Runs a threshold attack on loss."""
|
||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||
loss_train = attack_input.get_loss_train()
|
||||
loss_test = attack_input.get_loss_test()
|
||||
if loss_train is None or loss_test is None:
|
||||
raise ValueError('Not possible to run threshold attack without losses.')
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||
np.concatenate(
|
||||
(attack_input.get_loss_train(), attack_input.get_loss_test())))
|
||||
np.concatenate((loss_train, loss_test)))
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
|
@ -313,10 +316,12 @@ def _compute_missing_privacy_report_metadata(
|
|||
if metadata.accuracy_test is None:
|
||||
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
||||
attack_input.labels_test)
|
||||
if metadata.loss_train is None:
|
||||
metadata.loss_train = np.average(attack_input.get_loss_train())
|
||||
if metadata.loss_test is None:
|
||||
metadata.loss_test = np.average(attack_input.get_loss_test())
|
||||
loss_train = attack_input.get_loss_train()
|
||||
loss_test = attack_input.get_loss_test()
|
||||
if metadata.loss_train is None and loss_train is not None:
|
||||
metadata.loss_train = np.average(loss_train)
|
||||
if metadata.loss_test is None and loss_test is not None:
|
||||
metadata.loss_test = np.average(loss_test)
|
||||
return metadata
|
||||
|
||||
|
||||
|
|
|
@ -36,6 +36,14 @@ def get_test_input(n_train, n_test):
|
|||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||
|
||||
|
||||
def get_test_input_logits_only(n_train, n_test):
|
||||
"""Get example input logits for attacks."""
|
||||
rng = np.random.RandomState(4)
|
||||
return AttackInputData(
|
||||
logits_train=rng.randn(n_train, 5) + 0.2,
|
||||
logits_test=rng.randn(n_test, 5) + 0.2)
|
||||
|
||||
|
||||
class RunAttacksTest(absltest.TestCase):
|
||||
|
||||
def test_run_attacks_size(self):
|
||||
|
@ -45,6 +53,13 @@ class RunAttacksTest(absltest.TestCase):
|
|||
|
||||
self.assertLen(result.single_attack_results, 2)
|
||||
|
||||
def test_trained_attacks_logits_only_size(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input_logits_only(100, 100), SlicingSpec(),
|
||||
(AttackType.LOGISTIC_REGRESSION,))
|
||||
|
||||
self.assertLen(result.single_attack_results, 1)
|
||||
|
||||
def test_run_attack_trained_sets_attack_type(self):
|
||||
result = mia._run_attack(
|
||||
get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
||||
|
|
Loading…
Reference in a new issue