forked from 626_privacy/tensorflow_privacy
Add probabilities to AttackInputData.
PiperOrigin-RevId: 330723370
This commit is contained in:
parent
6312a853d8
commit
f44b63eb78
7 changed files with 93 additions and 16 deletions
|
@ -161,6 +161,11 @@ class AttackInputData:
|
|||
logits_train: np.ndarray = None
|
||||
logits_test: np.ndarray = None
|
||||
|
||||
# Predicted probabilities for each class. They can be derived from logits,
|
||||
# so they can be set only if logits are not explicitly provided.
|
||||
probs_train: np.ndarray = None
|
||||
probs_test: np.ndarray = None
|
||||
|
||||
# Contains ground-truth classes. Classes are assumed to be integers starting
|
||||
# from 0.
|
||||
labels_train: np.ndarray = None
|
||||
|
@ -185,6 +190,16 @@ class AttackInputData:
|
|||
'Please set labels_train and labels_test')
|
||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||
|
||||
@property
|
||||
def logits_or_probs_train(self):
|
||||
"""Returns train logits or probs whatever is not None."""
|
||||
return self.logits_train if self.probs_train is None else self.probs_train
|
||||
|
||||
@property
|
||||
def logits_or_probs_test(self):
|
||||
"""Returns test logits or probs whatever is not None."""
|
||||
return self.logits_test if self.probs_test is None else self.probs_test
|
||||
|
||||
@staticmethod
|
||||
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
||||
"""Computes the prediction entropy (by Song and Mittal)."""
|
||||
|
@ -199,7 +214,7 @@ class AttackInputData:
|
|||
# See the Equation (7) in https://arxiv.org/pdf/2003.10595.pdf
|
||||
return np.sum(np.multiply(probs, _log_value(probs)), axis=1)
|
||||
else:
|
||||
# When given the groud truth label, we compute the
|
||||
# When given the ground truth label, we compute the
|
||||
# modified prediction entropy.
|
||||
# See the Equation (8) in https://arxiv.org/pdf/2003.10595.pdf
|
||||
log_probs = _log_value(probs)
|
||||
|
@ -218,15 +233,21 @@ class AttackInputData:
|
|||
def get_loss_train(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the training set."""
|
||||
if self.loss_train is None:
|
||||
if self.logits_train is not None:
|
||||
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
||||
self.logits_train)
|
||||
else:
|
||||
self.loss_train = utils.log_loss(self.labels_train, self.probs_train)
|
||||
return self.loss_train
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the test set."""
|
||||
if self.loss_test is None:
|
||||
if self.logits_train is not None:
|
||||
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
||||
self.logits_test)
|
||||
else:
|
||||
self.loss_test = utils.log_loss(self.labels_test, self.probs_test)
|
||||
return self.loss_test
|
||||
|
||||
def get_entropy_train(self):
|
||||
|
@ -267,6 +288,13 @@ class AttackInputData:
|
|||
raise ValueError(
|
||||
'logits_train and logits_test should both be either set or unset')
|
||||
|
||||
if (self.probs_train is None) != (self.probs_test is None):
|
||||
raise ValueError(
|
||||
'probs_train and probs_test should both be either set or unset')
|
||||
|
||||
if (self.logits_train is not None) and (self.probs_train is not None):
|
||||
raise ValueError('Logits and probs can not be both set')
|
||||
|
||||
if (self.labels_train is None) != (self.labels_test is None):
|
||||
raise ValueError(
|
||||
'labels_train and labels_test should both be either set or unset')
|
||||
|
@ -286,6 +314,8 @@ class AttackInputData:
|
|||
|
||||
_is_np_array(self.logits_train, 'logits_train')
|
||||
_is_np_array(self.logits_test, 'logits_test')
|
||||
_is_np_array(self.probs_train, 'probs_train')
|
||||
_is_np_array(self.probs_test, 'probs_test')
|
||||
_is_np_array(self.labels_train, 'labels_train')
|
||||
_is_np_array(self.labels_test, 'labels_test')
|
||||
_is_np_array(self.loss_train, 'loss_train')
|
||||
|
@ -295,6 +325,8 @@ class AttackInputData:
|
|||
|
||||
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
|
||||
'logits_test')
|
||||
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
|
||||
'probs_test')
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
||||
|
@ -311,6 +343,8 @@ class AttackInputData:
|
|||
_append_array_shape(self.entropy_test, 'entropy_test', result)
|
||||
_append_array_shape(self.logits_train, 'logits_train', result)
|
||||
_append_array_shape(self.logits_test, 'logits_test', result)
|
||||
_append_array_shape(self.probs_train, 'probs_train', result)
|
||||
_append_array_shape(self.probs_test, 'probs_test', result)
|
||||
_append_array_shape(self.labels_train, 'labels_train', result)
|
||||
_append_array_shape(self.labels_test, 'labels_test', result)
|
||||
result.append(')')
|
||||
|
|
|
@ -46,7 +46,7 @@ class SingleSliceSpecTest(parameterized.TestCase):
|
|||
|
||||
class AttackInputDataTest(absltest.TestCase):
|
||||
|
||||
def test_get_loss(self):
|
||||
def test_get_loss_from_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
|
||||
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
|
||||
|
@ -58,6 +58,18 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
|
||||
|
||||
def test_get_loss_from_probs(self):
|
||||
attack_input = AttackInputData(
|
||||
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
|
||||
probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
|
||||
labels_train=np.array([1, 0]),
|
||||
labels_test=np.array([0, 2]))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(), [2.30258509, 0.2231436], atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)
|
||||
|
||||
def test_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1.0, 3.0, 6.0]),
|
||||
|
@ -99,6 +111,8 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
def test_validator(self):
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(logits_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(probs_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(labels_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
|
@ -107,6 +121,8 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
AttackInputData(entropy_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(logits_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(probs_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(labels_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
|
@ -114,6 +130,14 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
self.assertRaises(ValueError,
|
||||
AttackInputData(entropy_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError, AttackInputData().validate)
|
||||
# Tests that having both logits and probs are not allowed.
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
AttackInputData(
|
||||
logits_train=np.array([]),
|
||||
logits_test=np.array([]),
|
||||
probs_train=np.array([]),
|
||||
probs_test=np.array([])).validate)
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
|
|
@ -38,11 +38,13 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
|||
|
||||
# Slice train data.
|
||||
result.logits_train = _slice_if_not_none(data.logits_train, idx_train)
|
||||
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
|
||||
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
||||
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
||||
|
||||
# Slice test data.
|
||||
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
||||
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
|
||||
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||
|
||||
|
|
|
@ -106,13 +106,24 @@ class GetSliceTest(absltest.TestCase):
|
|||
# Create test data for 3 class classification task.
|
||||
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
|
||||
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
|
||||
probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0],
|
||||
[0.3, 0.7, 0]])
|
||||
probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0],
|
||||
[0, 0, 1]])
|
||||
labels_train = np.array([1, 0, 1, 2])
|
||||
labels_test = np.array([1, 2, 0, 2])
|
||||
loss_train = np.array([2, 0.25, 4, 3])
|
||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||
|
||||
self.input_data = AttackInputData(logits_train, logits_test, labels_train,
|
||||
labels_test, loss_train, loss_test)
|
||||
self.input_data = AttackInputData(
|
||||
logits_train=logits_train,
|
||||
logits_test=logits_test,
|
||||
probs_train=probs_train,
|
||||
probs_test=probs_test,
|
||||
labels_train=labels_train,
|
||||
labels_test=labels_test,
|
||||
loss_train=loss_train,
|
||||
loss_test=loss_test)
|
||||
|
||||
def test_slice_entire_dataset(self):
|
||||
entire_dataset_slice = SingleSliceSpec()
|
||||
|
@ -131,6 +142,11 @@ class GetSliceTest(absltest.TestCase):
|
|||
self.assertLen(output.logits_test, 1)
|
||||
self.assertTrue((output.logits_train[1] == [4, 5, 0]).all())
|
||||
|
||||
# Check probs.
|
||||
self.assertLen(output.probs_train, 2)
|
||||
self.assertLen(output.probs_test, 1)
|
||||
self.assertTrue((output.probs_train[1] == [0.4, 0.6, 0]).all())
|
||||
|
||||
# Check labels.
|
||||
self.assertLen(output.labels_train, 2)
|
||||
self.assertLen(output.labels_test, 1)
|
||||
|
|
|
@ -101,7 +101,7 @@ model.fit(
|
|||
to_categorical(training_labels, num_clusters),
|
||||
validation_data=(test_features, to_categorical(test_labels, num_clusters)),
|
||||
batch_size=64,
|
||||
epochs=10,
|
||||
epochs=2,
|
||||
shuffle=True)
|
||||
|
||||
training_pred = model.predict(training_features)
|
||||
|
@ -126,8 +126,8 @@ attack_results = mia.run_attacks(
|
|||
AttackInputData(
|
||||
labels_train=training_labels,
|
||||
labels_test=test_labels,
|
||||
logits_train=training_pred,
|
||||
logits_test=test_pred,
|
||||
probs_train=training_pred,
|
||||
probs_test=test_pred,
|
||||
loss_train=crossentropy(training_labels, training_pred),
|
||||
loss_test=crossentropy(test_labels, test_pred)),
|
||||
SlicingSpec(entire_dataset=True, by_class=True),
|
||||
|
|
|
@ -28,9 +28,10 @@ def get_test_input(n_train, n_test):
|
|||
"""Get example inputs for attacks."""
|
||||
rng = np.random.RandomState(4)
|
||||
return AttackInputData(
|
||||
rng.randn(n_train, 5) + 0.2,
|
||||
rng.randn(n_test, 5) + 0.2, np.array([i % 5 for i in range(n_train)]),
|
||||
np.array([i % 5 for i in range(n_test)]))
|
||||
logits_train=rng.randn(n_train, 5) + 0.2,
|
||||
logits_test=rng.randn(n_test, 5) + 0.2,
|
||||
labels_train=np.array([i % 5 for i in range(n_train)]),
|
||||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||
|
||||
|
||||
class RunAttacksTest(absltest.TestCase):
|
||||
|
|
|
@ -55,9 +55,9 @@ def create_attacker_data(attack_input_data: AttackInputData,
|
|||
Returns:
|
||||
AttackerData.
|
||||
"""
|
||||
attack_input_train = _column_stack(attack_input_data.logits_train,
|
||||
attack_input_train = _column_stack(attack_input_data.logits_or_probs_train,
|
||||
attack_input_data.get_loss_train())
|
||||
attack_input_test = _column_stack(attack_input_data.logits_test,
|
||||
attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
|
||||
attack_input_data.get_loss_test())
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
|
|
Loading…
Reference in a new issue