Add probabilities to AttackInputData.

PiperOrigin-RevId: 330723370
This commit is contained in:
Vadym Doroshenko 2020-09-09 08:05:28 -07:00 committed by A. Unique TensorFlower
parent 6312a853d8
commit f44b63eb78
7 changed files with 93 additions and 16 deletions

View file

@ -161,6 +161,11 @@ class AttackInputData:
logits_train: np.ndarray = None
logits_test: np.ndarray = None
# Predicted probabilities for each class. They can be derived from logits,
# so they can be set only if logits are not explicitly provided.
probs_train: np.ndarray = None
probs_test: np.ndarray = None
# Contains ground-truth classes. Classes are assumed to be integers starting
# from 0.
labels_train: np.ndarray = None
@ -185,6 +190,16 @@ class AttackInputData:
'Please set labels_train and labels_test')
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
@property
def logits_or_probs_train(self):
"""Returns train logits or probs whatever is not None."""
return self.logits_train if self.probs_train is None else self.probs_train
@property
def logits_or_probs_test(self):
"""Returns test logits or probs whatever is not None."""
return self.logits_test if self.probs_test is None else self.probs_test
@staticmethod
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
"""Computes the prediction entropy (by Song and Mittal)."""
@ -199,7 +214,7 @@ class AttackInputData:
# See the Equation (7) in https://arxiv.org/pdf/2003.10595.pdf
return np.sum(np.multiply(probs, _log_value(probs)), axis=1)
else:
# When given the groud truth label, we compute the
# When given the ground truth label, we compute the
# modified prediction entropy.
# See the Equation (8) in https://arxiv.org/pdf/2003.10595.pdf
log_probs = _log_value(probs)
@ -218,15 +233,21 @@ class AttackInputData:
def get_loss_train(self):
"""Calculates (if needed) cross-entropy losses for the training set."""
if self.loss_train is None:
self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train)
if self.logits_train is not None:
self.loss_train = utils.log_loss_from_logits(self.labels_train,
self.logits_train)
else:
self.loss_train = utils.log_loss(self.labels_train, self.probs_train)
return self.loss_train
def get_loss_test(self):
"""Calculates (if needed) cross-entropy losses for the test set."""
if self.loss_test is None:
self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test)
if self.logits_train is not None:
self.loss_test = utils.log_loss_from_logits(self.labels_test,
self.logits_test)
else:
self.loss_test = utils.log_loss(self.labels_test, self.probs_test)
return self.loss_test
def get_entropy_train(self):
@ -267,6 +288,13 @@ class AttackInputData:
raise ValueError(
'logits_train and logits_test should both be either set or unset')
if (self.probs_train is None) != (self.probs_test is None):
raise ValueError(
'probs_train and probs_test should both be either set or unset')
if (self.logits_train is not None) and (self.probs_train is not None):
raise ValueError('Logits and probs can not be both set')
if (self.labels_train is None) != (self.labels_test is None):
raise ValueError(
'labels_train and labels_test should both be either set or unset')
@ -286,6 +314,8 @@ class AttackInputData:
_is_np_array(self.logits_train, 'logits_train')
_is_np_array(self.logits_test, 'logits_test')
_is_np_array(self.probs_train, 'probs_train')
_is_np_array(self.probs_test, 'probs_test')
_is_np_array(self.labels_train, 'labels_train')
_is_np_array(self.labels_test, 'labels_test')
_is_np_array(self.loss_train, 'loss_train')
@ -295,6 +325,8 @@ class AttackInputData:
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
'logits_test')
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
'probs_test')
_is_array_one_dimensional(self.loss_train, 'loss_train')
_is_array_one_dimensional(self.loss_test, 'loss_test')
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
@ -311,6 +343,8 @@ class AttackInputData:
_append_array_shape(self.entropy_test, 'entropy_test', result)
_append_array_shape(self.logits_train, 'logits_train', result)
_append_array_shape(self.logits_test, 'logits_test', result)
_append_array_shape(self.probs_train, 'probs_train', result)
_append_array_shape(self.probs_test, 'probs_test', result)
_append_array_shape(self.labels_train, 'labels_train', result)
_append_array_shape(self.labels_test, 'labels_test', result)
result.append(')')

View file

@ -46,7 +46,7 @@ class SingleSliceSpecTest(parameterized.TestCase):
class AttackInputDataTest(absltest.TestCase):
def test_get_loss(self):
def test_get_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
@ -58,6 +58,18 @@ class AttackInputDataTest(absltest.TestCase):
np.testing.assert_allclose(
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
def test_get_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
probs_test=np.array([[0, 0.0001, 0.9999], [0.07, 0.18, 0.75]]),
labels_train=np.array([1, 0]),
labels_test=np.array([0, 2]))
np.testing.assert_allclose(
attack_input.get_loss_train(), [2.30258509, 0.2231436], atol=1e-7)
np.testing.assert_allclose(
attack_input.get_loss_test(), [18.42068074, 0.28768207], atol=1e-7)
def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([1.0, 3.0, 6.0]),
@ -99,6 +111,8 @@ class AttackInputDataTest(absltest.TestCase):
def test_validator(self):
self.assertRaises(ValueError,
AttackInputData(logits_train=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(probs_train=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(labels_train=np.array([])).validate)
self.assertRaises(ValueError,
@ -107,6 +121,8 @@ class AttackInputDataTest(absltest.TestCase):
AttackInputData(entropy_train=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(logits_test=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(probs_test=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(labels_test=np.array([])).validate)
self.assertRaises(ValueError,
@ -114,6 +130,14 @@ class AttackInputDataTest(absltest.TestCase):
self.assertRaises(ValueError,
AttackInputData(entropy_test=np.array([])).validate)
self.assertRaises(ValueError, AttackInputData().validate)
# Tests that having both logits and probs are not allowed.
self.assertRaises(
ValueError,
AttackInputData(
logits_train=np.array([]),
logits_test=np.array([]),
probs_train=np.array([]),
probs_test=np.array([])).validate)
class RocCurveTest(absltest.TestCase):

View file

@ -38,11 +38,13 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
# Slice train data.
result.logits_train = _slice_if_not_none(data.logits_train, idx_train)
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
# Slice test data.
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)

View file

@ -106,13 +106,24 @@ class GetSliceTest(absltest.TestCase):
# Create test data for 3 class classification task.
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0],
[0.3, 0.7, 0]])
probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0],
[0, 0, 1]])
labels_train = np.array([1, 0, 1, 2])
labels_test = np.array([1, 2, 0, 2])
loss_train = np.array([2, 0.25, 4, 3])
loss_test = np.array([0.5, 3.5, 7, 4.5])
self.input_data = AttackInputData(logits_train, logits_test, labels_train,
labels_test, loss_train, loss_test)
self.input_data = AttackInputData(
logits_train=logits_train,
logits_test=logits_test,
probs_train=probs_train,
probs_test=probs_test,
labels_train=labels_train,
labels_test=labels_test,
loss_train=loss_train,
loss_test=loss_test)
def test_slice_entire_dataset(self):
entire_dataset_slice = SingleSliceSpec()
@ -131,6 +142,11 @@ class GetSliceTest(absltest.TestCase):
self.assertLen(output.logits_test, 1)
self.assertTrue((output.logits_train[1] == [4, 5, 0]).all())
# Check probs.
self.assertLen(output.probs_train, 2)
self.assertLen(output.probs_test, 1)
self.assertTrue((output.probs_train[1] == [0.4, 0.6, 0]).all())
# Check labels.
self.assertLen(output.labels_train, 2)
self.assertLen(output.labels_test, 1)

View file

@ -101,7 +101,7 @@ model.fit(
to_categorical(training_labels, num_clusters),
validation_data=(test_features, to_categorical(test_labels, num_clusters)),
batch_size=64,
epochs=10,
epochs=2,
shuffle=True)
training_pred = model.predict(training_features)
@ -126,8 +126,8 @@ attack_results = mia.run_attacks(
AttackInputData(
labels_train=training_labels,
labels_test=test_labels,
logits_train=training_pred,
logits_test=test_pred,
probs_train=training_pred,
probs_test=test_pred,
loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)),
SlicingSpec(entire_dataset=True, by_class=True),

View file

@ -28,9 +28,10 @@ def get_test_input(n_train, n_test):
"""Get example inputs for attacks."""
rng = np.random.RandomState(4)
return AttackInputData(
rng.randn(n_train, 5) + 0.2,
rng.randn(n_test, 5) + 0.2, np.array([i % 5 for i in range(n_train)]),
np.array([i % 5 for i in range(n_test)]))
logits_train=rng.randn(n_train, 5) + 0.2,
logits_test=rng.randn(n_test, 5) + 0.2,
labels_train=np.array([i % 5 for i in range(n_train)]),
labels_test=np.array([i % 5 for i in range(n_test)]))
class RunAttacksTest(absltest.TestCase):

View file

@ -55,9 +55,9 @@ def create_attacker_data(attack_input_data: AttackInputData,
Returns:
AttackerData.
"""
attack_input_train = _column_stack(attack_input_data.logits_train,
attack_input_train = _column_stack(attack_input_data.logits_or_probs_train,
attack_input_data.get_loss_train())
attack_input_test = _column_stack(attack_input_data.logits_test,
attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
attack_input_data.get_loss_test())
features_all = np.concatenate((attack_input_train, attack_input_test))