Trained attackers no longer fail when labels are missing.

PiperOrigin-RevId: 368598111
This commit is contained in:
David Marn 2021-04-15 02:18:36 -07:00 committed by A. Unique TensorFlower
parent edd9c44269
commit ca347b8995
3 changed files with 43 additions and 11 deletions

View file

@ -249,8 +249,14 @@ class AttackInputData:
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1) return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
def get_loss_train(self): def get_loss_train(self):
"""Calculates (if needed) cross-entropy losses for the training set.""" """Calculates (if needed) cross-entropy losses for the training set.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_train is None: if self.loss_train is None:
if self.labels_train is None:
return None
if self.logits_train is not None: if self.logits_train is not None:
self.loss_train = utils.log_loss_from_logits(self.labels_train, self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train) self.logits_train)
@ -259,9 +265,15 @@ class AttackInputData:
return self.loss_train return self.loss_train
def get_loss_test(self): def get_loss_test(self):
"""Calculates (if needed) cross-entropy losses for the test set.""" """Calculates (if needed) cross-entropy losses for the test set.
Returns:
Loss (or None if neither the loss nor the labels are present).
"""
if self.loss_test is None: if self.loss_test is None:
if self.logits_train is not None: if self.labels_test is None:
return None
if self.logits_test is not None:
self.loss_test = utils.log_loss_from_logits(self.labels_test, self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test) self.logits_test)
else: else:

View file

@ -29,8 +29,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleMembershipProbabilityResult
@ -93,11 +92,15 @@ def _run_trained_attack(attack_input: AttackInputData,
def _run_threshold_attack(attack_input: AttackInputData): def _run_threshold_attack(attack_input: AttackInputData):
"""Runs a threshold attack on loss."""
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size() ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
loss_train = attack_input.get_loss_train()
loss_test = attack_input.get_loss_test()
if loss_train is None or loss_test is None:
raise ValueError('Not possible to run threshold attack without losses.')
fpr, tpr, thresholds = metrics.roc_curve( fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.zeros(ntrain), np.ones(ntest))), np.concatenate((np.zeros(ntrain), np.ones(ntest))),
np.concatenate( np.concatenate((loss_train, loss_test)))
(attack_input.get_loss_train(), attack_input.get_loss_test())))
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
@ -313,10 +316,12 @@ def _compute_missing_privacy_report_metadata(
if metadata.accuracy_test is None: if metadata.accuracy_test is None:
metadata.accuracy_test = _get_accuracy(attack_input.logits_test, metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
attack_input.labels_test) attack_input.labels_test)
if metadata.loss_train is None: loss_train = attack_input.get_loss_train()
metadata.loss_train = np.average(attack_input.get_loss_train()) loss_test = attack_input.get_loss_test()
if metadata.loss_test is None: if metadata.loss_train is None and loss_train is not None:
metadata.loss_test = np.average(attack_input.get_loss_test()) metadata.loss_train = np.average(loss_train)
if metadata.loss_test is None and loss_test is not None:
metadata.loss_test = np.average(loss_test)
return metadata return metadata

View file

@ -36,6 +36,14 @@ def get_test_input(n_train, n_test):
labels_test=np.array([i % 5 for i in range(n_test)])) labels_test=np.array([i % 5 for i in range(n_test)]))
def get_test_input_logits_only(n_train, n_test):
  """Builds attack input that carries logits only (no labels, no losses).

  The logits are drawn from a fixed-seed generator, so repeated calls with
  the same sizes produce the same arrays.

  Args:
    n_train: Number of rows of training logits to generate.
    n_test: Number of rows of test logits to generate.

  Returns:
    An AttackInputData whose logits_train and logits_test are random arrays
    with 5 columns each, shifted by 0.2; no other fields are supplied.
  """
  generator = np.random.RandomState(4)
  # Train logits are drawn before test logits; with a shared generator the
  # draw order determines the values each array receives.
  train_logits = generator.randn(n_train, 5) + 0.2
  test_logits = generator.randn(n_test, 5) + 0.2
  return AttackInputData(logits_train=train_logits, logits_test=test_logits)
class RunAttacksTest(absltest.TestCase): class RunAttacksTest(absltest.TestCase):
def test_run_attacks_size(self): def test_run_attacks_size(self):
@ -45,6 +53,13 @@ class RunAttacksTest(absltest.TestCase):
self.assertLen(result.single_attack_results, 2) self.assertLen(result.single_attack_results, 2)
def test_trained_attacks_logits_only_size(self):
  """Runs one trained attack on logits-only input and checks the result count."""
  logits_only_input = get_test_input_logits_only(100, 100)
  slicing_spec = SlicingSpec()
  # A single attack type is requested, so exactly one result is expected.
  attack_types = (AttackType.LOGISTIC_REGRESSION,)
  result = mia.run_attacks(logits_only_input, slicing_spec, attack_types)
  self.assertLen(result.single_attack_results, 1)
def test_run_attack_trained_sets_attack_type(self): def test_run_attack_trained_sets_attack_type(self):
result = mia._run_attack( result = mia._run_attack(
get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION) get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)