Add multi-label support for Tensorflow Privacy membership attacks.

PiperOrigin-RevId: 443176652
This commit is contained in:
A. Unique TensorFlower 2022-04-20 13:22:55 -07:00
parent e14618fe7c
commit ee35642b90
9 changed files with 676 additions and 36 deletions

View file

@ -17,6 +17,7 @@ import collections
import dataclasses
import enum
import glob
import logging
import os
import pickle
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
@ -115,6 +116,7 @@ class AttackType(enum.Enum):
K_NEAREST_NEIGHBORS = 'knn'
THRESHOLD_ATTACK = 'threshold'
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
TF_LOGISTIC_REGRESSION = 'tf_lr'
@property
def is_trained_attack(self):
@ -154,6 +156,19 @@ def _is_array_one_dimensional(arr, arr_name):
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
def _is_array_two_dimensional(arr, arr_name):
"""Checks whether the array is two dimensional."""
if arr is not None and len(arr.shape) != 2:
raise ValueError('%s should be a two dimensional numpy array.' % arr_name)
def _is_array_one_or_two_dimensional(arr, arr_name):
"""Checks whether the array is one or two dimensional."""
if arr is not None and len(arr.shape) not in [1, 2]:
raise ValueError(
('%s should be a one or two dimensional numpy array.' % arr_name))
def _is_np_array(arr, arr_name):
"""Checks whether array is a numpy array."""
if arr is not None and not isinstance(arr, np.ndarray):
@ -213,6 +228,14 @@ class AttackInputData:
# preferred when both are available.
loss_function_using_logits: Optional[bool] = None
# If the problem is a multilabel classification problem. If this is set then
# the loss function and attack data construction are adjusted accordingly. In
# this case the provided labels must be multi-hot encoded. That is, the labels
# are an array of shape (num_examples, num_classes) with 0s where the
# corresponding class is absent from the example, and 1s where the
# corresponding class is present.
multilabel_data: Optional[bool] = None
@property
def num_classes(self):
if self.labels_train is None or self.labels_test is None:
@ -266,12 +289,12 @@ class AttackInputData:
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
@staticmethod
def _get_loss(
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
LossFunction],
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
loss_function: Union[Callable[[np.ndarray, np.ndarray],
np.ndarray], LossFunction],
loss_function_using_logits: Optional[bool],
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
"""Calculates (if needed) losses.
Args:
@ -284,6 +307,7 @@ class AttackInputData:
is supposed to take in (label, logits / probs) as input.
loss_function_using_logits: if `loss_function` expects `logits` or
`probs`.
multilabel_data: if the data is from a multilabel classification problem.
Returns:
Loss (or None if neither the loss nor the labels are present).
@ -299,13 +323,23 @@ class AttackInputData:
predictions = logits if loss_function_using_logits else probs
if loss_function == LossFunction.CROSS_ENTROPY:
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
if multilabel_data:
loss = utils.multilabel_bce_loss(labels, predictions,
loss_function_using_logits)
else:
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
elif loss_function == LossFunction.SQUARED:
loss = utils.squared_loss(labels, predictions)
else:
loss = loss_function(labels, predictions)
return loss
def __post_init__(self):
"""Checks performed after instantiation of the AttackInputData dataclass."""
# Check if the data is multilabel.
_ = self.is_multilabel_data()
# The validate() check is called as needed, so is not called here.
def get_loss_train(self):
"""Calculates (if needed) cross-entropy losses for the training set.
@ -316,7 +350,7 @@ class AttackInputData:
self.loss_function_using_logits = (self.logits_train is not None)
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
self.probs_train, self.loss_function,
self.loss_function_using_logits)
self.loss_function_using_logits, self.multilabel_data)
def get_loss_test(self):
"""Calculates (if needed) cross-entropy losses for the test set.
@ -328,35 +362,143 @@ class AttackInputData:
self.loss_function_using_logits = bool(self.logits_test)
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
self.probs_test, self.loss_function,
self.loss_function_using_logits)
self.loss_function_using_logits, self.multilabel_data)
def get_entropy_train(self):
"""Calculates prediction entropy for the training set."""
if self.is_multilabel_data():
# Not implemented for multilabel data.
raise NotImplementedError('Computation of the entropy is not '
'applicable for multi-label data.')
if self.entropy_train is not None:
return self.entropy_train
return self._get_entropy(self.logits_train, self.labels_train)
def get_entropy_test(self):
"""Calculates prediction entropy for the test set."""
if self.is_multilabel_data():
# Not implemented for multilabel data.
raise NotImplementedError('Computation of the entropy is not '
'applicable for multi-label data.')
if self.entropy_test is not None:
return self.entropy_test
return self._get_entropy(self.logits_test, self.labels_test)
def get_train_size(self):
"""Returns size of the training set."""
def get_train_shape(self):
"""Returns the shape of the training set."""
if self.loss_train is not None:
return self.loss_train.size
return self.loss_train.shape
if self.entropy_train is not None:
return self.entropy_train.size
return self.logits_or_probs_train.shape[0]
return self.entropy_train.shape
return self.logits_or_probs_train.shape
def get_test_shape(self):
"""Returns the shape of the test set."""
if self.loss_test is not None:
return self.loss_test.shape
if self.entropy_test is not None:
return self.entropy_test.shape
return self.logits_or_probs_test.shape
def get_train_size(self):
"""Returns the number of examples of the training set."""
return self.get_train_shape()[0]
def get_test_size(self):
"""Returns size of the test set."""
if self.loss_test is not None:
return self.loss_test.size
if self.entropy_test is not None:
return self.entropy_test.size
return self.logits_or_probs_test.shape[0]
"""Returns the number of examples of the test set."""
return self.get_test_shape()[0]
def is_multihot_labels(self, arr, arr_name) -> bool:
"""Check if the 2D array is multihot, with values in [0, 1].
Array is multihot if the sum along the classes axis (axis=1) produces a
vector of at least one value > 1.
Args:
arr: Array to test.
arr_name: Name of the array.
Returns:
True if the array is a 2D multihot array.
Raises:
ValueError if the array is not 2D.
"""
if arr is None:
return False
elif len(arr.shape) > 2:
raise ValueError(f'Array {arr_name} is not 2D, cannot determine whether '
'it is multihot.')
elif len(arr.shape) == 1:
return False
summed_arr = np.sum(arr, axis=1)
return not ((summed_arr == 0) | (summed_arr == 1)).all()
def is_multilabel_data(self) -> bool:
"""Check if the provided data is for a multilabel classification problem.
Data is multilabel if all of the following are true:
1. Train and test sizes are 2-dimensional: (num_samples, num_classes)
2. Label size is 2-dimentionsl: (num_samples, num_classes)
3. The labels are multi-hot. The labels are multihot if sum{label_tensor}
along axis=1 (the classes axis) yields a vector of at least one
value > 1.
Returns:
Whether the provided data is multilabel.
Raises:
ValueError if the dimensionality of the train and test data are not equal.
"""
# If the data has already been checked for multihot encoded labels, then
# return the result of the evaluation.
if self.multilabel_data is not None:
return self.multilabel_data
# If one of probs or logits are not provided, or labels are not provided,
# this is not a multilabel problem
if (self.logits_or_probs_train is None or
self.logits_or_probs_test is None or self.labels_train is None or
self.labels_test is None):
self.multilabel_data = False
return self.multilabel_data
train_shape = self.get_train_shape()
test_shape = self.get_test_shape()
label_train_shape = self.labels_train.shape
label_test_shape = self.labels_test.shape
if len(train_shape) != len(test_shape):
raise ValueError('The number of dimensions of the train data '
f'({train_shape}) is not the same as that of the test '
f'data ({test_shape}).')
if len(train_shape) not in [1, 2] or len(test_shape) not in [1, 2]:
raise ValueError(('Train and test data shapes must be 1-D '
'(number of samples) or 2-D (number of samples, '
'number of classes).'))
if len(label_train_shape) != len(label_test_shape):
raise ValueError('The number of dimensions of the train labels '
f'({label_train_shape}) is not the same as that of the '
f'test labels ({label_test_shape}).')
if (len(label_train_shape) not in [1, 2] or
len(label_test_shape) not in [1, 2]):
raise ValueError('Train and test labels shapes must be 1-D '
'(number of samples) or 2-D (number of samples, '
'number of classes).')
data_is_2d = len(train_shape) == len(test_shape) == 2
if data_is_2d:
equal_feature_count = train_shape[1] == test_shape[1]
else:
equal_feature_count = False
labels_are_2d = len(label_train_shape) == len(label_test_shape) == 2
labels_train_are_multihot = self.is_multihot_labels(self.labels_train,
'labels_train')
labels_test_are_multihot = self.is_multihot_labels(self.labels_test,
'labels_test')
self.multilabel_data = (
data_is_2d and labels_are_2d and equal_feature_count and
labels_train_are_multihot and labels_test_are_multihot)
return self.multilabel_data
def validate(self):
"""Validates the inputs."""
@ -413,12 +555,30 @@ class AttackInputData:
'logits_test')
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
'probs_test')
_is_array_one_dimensional(self.loss_train, 'loss_train')
_is_array_one_dimensional(self.loss_test, 'loss_test')
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
_is_array_one_dimensional(self.entropy_test, 'entropy_test')
_is_array_one_dimensional(self.labels_train, 'labels_train')
_is_array_one_dimensional(self.labels_test, 'labels_test')
if self.is_multilabel_data():
# Validation for multi-label data.
# Verify that both logits and probabilities have the same number of
# classes.
_is_last_dim_equal(self.logits_train, 'logits_train', self.probs_train,
'probs_train')
_is_last_dim_equal(self.logits_test, 'logits_test', self.probs_test,
'probs_test')
# Check that losses, labels and entropies are 2D
# (num_samples, num_classes).
_is_array_two_dimensional(self.loss_train, 'loss_train')
_is_array_two_dimensional(self.loss_test, 'loss_test')
_is_array_two_dimensional(self.entropy_train, 'entropy_train')
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
_is_array_two_dimensional(self.labels_train, 'labels_train')
_is_array_two_dimensional(self.labels_test, 'labels_test')
else:
_is_array_one_dimensional(self.loss_train, 'loss_train')
_is_array_one_dimensional(self.loss_test, 'loss_test')
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
_is_array_one_dimensional(self.entropy_test, 'entropy_test')
_is_array_one_dimensional(self.labels_train, 'labels_train')
_is_array_one_dimensional(self.labels_test, 'labels_test')
def __str__(self):
"""Return the shapes of variables that are not None."""
@ -784,8 +944,8 @@ class AttackResults:
aucs = [result.get_auc() for result in self.single_attack_results]
if min(aucs) < 0.4:
print('Suspiciously low AUC detected: %.2f. ' +
'There might be a bug in the classifier' % min(aucs))
logging.info(('Suspiciously low AUC detected: %.2f. '
'There might be a bug in the classifier'), min(aucs))
return self.single_attack_results[np.argmax(aucs)]

View file

@ -277,6 +277,87 @@ class AttackInputDataTest(parameterized.TestCase):
probs_train=np.array([]),
probs_test=np.array([])).validate)
def test_multilabel_validator(self):
# Tests for multilabel data.
with self.assertRaises(
ValueError,
msg='Validation passes incorrectly when `logits_test` is not 1D/2D.'):
AttackInputData(
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
logits_test=np.array([[[0.01, 1.5], [0.5, -3]],
[[0.01, 1.5], [0.5, -3]]]),
labels_train=np.array([[0, 0], [0, 1], [1, 0]]),
labels_test=np.array([[1, 1], [1, 0]]),
).validate()
self.assertTrue(
AttackInputData(
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
labels_test=np.array([[1, 1], [1, 0]]),
).is_multilabel_data(),
msg='Multilabel data check fails even though conditions are met.')
def test_multihot_labels_check_on_null_array_returns_false(self):
self.assertFalse(
AttackInputData(
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
labels_test=np.array([[1, 1], [1, 0]]),
).is_multihot_labels(None, 'null_array'),
msg='Multilabel test on a null array should return False.')
self.assertFalse(
AttackInputData(
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
labels_test=np.array([[1, 1], [1, 0]]),
).is_multihot_labels(np.array([1.0, 2.0, 3.0]), '1d_array'),
msg='Multilabel test on a 1-D array should return False.')
def test_multilabel_get_bce_loss_from_probs(self):
attack_input = AttackInputData(
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
probs_test=np.array([[0.8, 0.7, 0.9]]),
labels_train=np.array([[0, 1, 1], [1, 1, 0]]),
labels_test=np.array([[1, 1, 0]]))
np.testing.assert_allclose(
attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748],
[0.22314343, 0.51082546, 2.30258409]],
atol=1e-6)
np.testing.assert_allclose(
attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]],
atol=1e-6)
def test_multilabel_get_bce_loss_from_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
labels_test=np.array([[1, 1], [1, 0]]))
np.testing.assert_allclose(
attack_input.get_loss_train(),
[[0.31326167, 0.126928], [0.69815966, 0.20141327],
[0.47407697, 3.04858714]],
atol=1e-6)
np.testing.assert_allclose(
attack_input.get_loss_test(),
[[0.68815966, 0.20141327], [0.47407697, 0.04858734]],
atol=1e-6)
def test_multilabel_get_loss_explicitly_provided(self):
attack_input = AttackInputData(
loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]))
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
class RocCurveTest(absltest.TestCase):

View file

@ -15,6 +15,7 @@
import collections
import copy
import logging
from typing import List, Optional
import numpy as np
@ -49,10 +50,21 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
# A slice has the same multilabel status as the original data. This is because
# of the way multilabel status is computed. A dataset is multilabel if at
# least 1 sample has a label that is multihot encoded with more than one
# positive class. A slice of this dataset could have only samples with labels
# that have only a single positive class, even if the original dataset were
# multilable. Therefore we ensure that the slice inherits the multilabel state
# of the original dataset.
result.multilabel_data = data.is_multilabel_data()
return result
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
if data.is_multilabel_data():
raise ValueError("Slicing by class not supported for multilabel data.")
idx_train = data.labels_train == class_value
idx_test = data.labels_test == class_value
return _slice_data_by_indices(data, idx_train, idx_test)
@ -65,6 +77,12 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
# Find from_percentile and to_percentile percentiles in losses.
loss_train = data.get_loss_train()
loss_test = data.get_loss_test()
if data.is_multilabel_data():
logging.info("For multilabel data, when slices by percentiles are "
"requested, losses are summed over the class axis before "
"slicing.")
loss_train = np.sum(loss_train, axis=1)
loss_test = np.sum(loss_test, axis=1)
losses = np.concatenate((loss_train, loss_test))
from_loss = np.percentile(losses, from_percentile)
to_loss = np.percentile(losses, to_percentile)
@ -82,6 +100,20 @@ def _indices_by_classification(logits_or_probs, labels, correctly_classified):
def _slice_by_classification_correctness(data: AttackInputData,
correctly_classified: bool):
"""Slices attack inputs by whether they were classified correctly.
Args:
data: Data to be used as input to the attack models.
correctly_classified: Whether to use the indices corresponding to the
correctly classified samples.
Returns:
AttackInputData object containing the sliced data.
"""
if data.is_multilabel_data():
raise ValueError("Slicing by classification correctness not supported for "
"multilabel data.")
idx_train = _indices_by_classification(data.logits_or_probs_train,
data.labels_train,
correctly_classified)

View file

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from absl.testing import absltest
from absl.testing.absltest import mock
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
@ -71,7 +73,7 @@ class SingleSliceSpecsTest(absltest.TestCase):
self.assertTrue(_are_all_fields_equal(output[0], expected0))
self.assertTrue(_are_all_fields_equal(output[5], expected5))
def test_slice_by_correcness(self):
def test_slice_by_correctness(self):
input_data = SlicingSpec(
entire_dataset=False, by_classification_correctness=True)
expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)
@ -199,5 +201,83 @@ class GetSliceTest(absltest.TestCase):
self.assertTrue((output.labels_test == [1, 2, 0]).all())
class GetSliceTestForMultilabelData(absltest.TestCase):
def __init__(self, methodname):
"""Initialize the test class."""
super().__init__(methodname)
# Create test data for 3 class multilabel classification task.
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0],
[0.3, 0.7, 0]])
probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0],
[0, 0, 1]])
labels_train = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [0, 1, 0]])
labels_test = np.array([[1, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
loss_train = np.array([[0, 0, 2], [3, 0, 0.3], [2, 0.5, 0], [1.5, 2, 3]])
loss_test = np.array([[1.5, 0, 1.0], [0.5, 0.8, 0], [0.5, 3, 0], [0, 0, 0]])
entropy_train = np.array([[0.2, 0.2, 5], [10, 0, 2], [7, 3, 0], [6, 8, 9]])
entropy_test = np.array([[3, 0, 2], [3, 6, 0], [2, 6, 0], [0, 0, 0]])
self.input_data = AttackInputData(
logits_train=logits_train,
logits_test=logits_test,
probs_train=probs_train,
probs_test=probs_test,
labels_train=labels_train,
labels_test=labels_test,
loss_train=loss_train,
loss_test=loss_test,
entropy_train=entropy_train,
entropy_test=entropy_test)
def test_slice_entire_dataset(self):
entire_dataset_slice = SingleSliceSpec()
output = get_slice(self.input_data, entire_dataset_slice)
expected = self.input_data
expected.slice_spec = entire_dataset_slice
self.assertTrue(_are_all_fields_equal(output, self.input_data))
def test_slice_by_class_fails(self):
class_index = 1
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
@mock.patch('logging.Logger.info', wraps=logging.Logger)
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
_ = get_slice(self.input_data, percentile_slice)
mock_logger.assert_called_with(
('For multilabel data, when slices by percentiles are '
'requested, losses are summed over the class axis before '
'slicing.'))
def test_slice_by_percentile(self):
# 50th percentile is the lower 50% of losses summed over the classes.
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
output = get_slice(self.input_data, percentile_slice)
# Check logits.
with self.subTest(msg='Check logits'):
self.assertLen(output.logits_train, 2)
self.assertLen(output.logits_test, 3)
self.assertTrue((output.logits_test[0] == [10, 0, 11]).all())
# Check labels.
with self.subTest(msg='Check labels'):
self.assertLen(output.labels_train, 2)
self.assertLen(output.labels_test, 3)
self.assertTrue((output.labels_train == [[0, 1, 1], [1, 1, 0]]).all())
self.assertTrue((output.labels_test == [[1, 0, 1], [0, 1, 0], [0, 0,
1]]).all())
def test_slice_by_correctness_fails(self):
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
False)
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
if __name__ == '__main__':
absltest.main()

View file

@ -17,9 +17,11 @@ This file belongs to the new API for membership inference attacks. This file
will be renamed to membership_inference_attack.py after the old API is removed.
"""
from typing import Iterable
import logging
from typing import Iterable, List, Union
import numpy as np
from scipy import special
from sklearn import metrics
from sklearn import model_selection
@ -39,6 +41,9 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.datase
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.dataset_slicing import get_slice
ArrayLike = Union[np.ndarray, List]
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
if hasattr(data, 'slice_spec'):
return data.slice_spec
@ -81,7 +86,8 @@ def _run_trained_attack(attack_input: AttackInputData,
attacker = models.create_attacker(attack_type)
attacker.train_model(features[train_indices], labels[train_indices])
scores[test_indices] = attacker.predict(features[test_indices])
predictions = attacker.predict(features[test_indices])
scores[test_indices] = predictions
# Predict the left out with the last attacker
if left_out_indices.size:
@ -110,6 +116,11 @@ def _run_threshold_attack(attack_input: AttackInputData):
loss_test = attack_input.get_loss_test()
if loss_train is None or loss_test is None:
raise ValueError('Not possible to run threshold attack without losses.')
if attack_input.is_multilabel_data():
logging.info(('For multilabel data, when a threshold attack is requested, '
'losses are summed over the class axis before slicing.'))
loss_train = np.sum(loss_train, axis=1)
loss_test = np.sum(loss_test, axis=1)
fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
np.concatenate((loss_train, loss_test)))
@ -126,6 +137,10 @@ def _run_threshold_attack(attack_input: AttackInputData):
def _run_threshold_entropy_attack(attack_input: AttackInputData):
"""Runs threshold entropy attack on single label data."""
if attack_input.is_multilabel_data():
raise NotImplementedError(('Entropy-based attacks are not implemented for '
'multilabel data.'))
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
@ -212,6 +227,7 @@ def run_attacks(attack_input: AttackInputData,
for single_slice_spec in input_slice_specs:
attack_input_slice = get_slice(attack_input, single_slice_spec)
for attack_type in attack_types:
logging.info('Running attack: %s', attack_type.name)
attack_result = _run_attack(attack_input_slice, attack_type,
balance_attacker_training, min_num_samples)
if attack_result is not None:
@ -318,15 +334,48 @@ def run_membership_probability_analysis(
def _compute_missing_privacy_report_metadata(
metadata: PrivacyReportMetadata,
attack_input: AttackInputData) -> PrivacyReportMetadata:
"""Populates metadata fields if they are missing."""
"""Populates metadata fields if they are missing.
Args:
metadata: Metadata that is used to create a privacy report based on the
attack results.
attack_input: The input data used to run a membership attack.
Returns:
A new or updated metadata object containing information to create the
privacy report.
"""
if metadata is None:
metadata = PrivacyReportMetadata()
if attack_input.is_multilabel_data():
accuracy_fn = _get_multilabel_accuracy
sigmoid_func = special.expit
# Multi label accuracy is calculated with the prediction probabilties and
# the labels. A threshold with a default of 0.5 is used to get predicted
# labels from the probabilities.
if (attack_input.probs_train is None and
attack_input.logits_train is not None):
logits_or_probs_train = sigmoid_func(attack_input.logits_train)
else:
logits_or_probs_train = attack_input.probs_train
if (attack_input.probs_test is None and
attack_input.logits_test is not None):
logits_or_probs_test = sigmoid_func(attack_input.logits_test)
else:
logits_or_probs_test = attack_input.probs_test
else:
accuracy_fn = _get_accuracy
# Single label accuracy is calculated with the argmax of the logits which
# is the same as the argmax of the probbilities.
logits_or_probs_train = attack_input.logits_or_probs_train
logits_or_probs_test = attack_input.logits_or_probs_test
if metadata.accuracy_train is None:
metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
attack_input.labels_train)
metadata.accuracy_train = accuracy_fn(logits_or_probs_train,
attack_input.labels_train)
if metadata.accuracy_test is None:
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
attack_input.labels_test)
metadata.accuracy_test = accuracy_fn(logits_or_probs_test,
attack_input.labels_test)
loss_train = attack_input.get_loss_train()
loss_test = attack_input.get_loss_test()
if metadata.loss_train is None and loss_train is not None:
@ -341,3 +390,28 @@ def _get_accuracy(logits, labels):
if logits is None or labels is None:
return None
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
def _get_numpy_binary_accuracy(preds: ArrayLike, labels: ArrayLike):
"""Computes the multilabel accuracy at threshold=0.5 using Numpy."""
return np.mean(np.equal(labels, np.round(preds)))
def _get_multilabel_accuracy(preds: ArrayLike, labels: ArrayLike):
"""Computes the accuracy over multilabel data if it is missing.
Compute multilabel binary accuracy. AUC is a better measure of model quality
for multilabel classification than accuracy, in particular when the classes
are imbalanced. For consistency with the single label classification case,
we compute and return the binary accuracy over the labels and predictions.
Args:
preds: Prediction probabilities.
labels: Ground truth multihot labels.
Returns:
The binary accuracy averaged across all labels.
"""
if preds is None or labels is None:
return None
return _get_numpy_binary_accuracy(preds, labels)

View file

@ -34,6 +34,42 @@ def get_test_input(n_train, n_test):
labels_test=np.array([i % 5 for i in range(n_test)]))
def get_multihot_labels_for_test(num_samples: int,
num_classes: int) -> np.ndarray:
"""Generate a array of multihot labels.
Given an integer 'num_samples', generate a deterministic array of
'num_classes'multihot labels. Each multihot label is the list of bits (0/1) of
the corresponding row number in the array, upto 'num_classes'. If the value
of num_classes < num_samples, then the bit list repeats.
e.g. if num_samples=10 and num_classes=3, row=3 corresponds to the label
vector [0, 1, 1].
Args:
num_samples: Number of samples for which to generate test labels.
num_classes: Number of classes for which to generate test multihot labels.
Returns:
Numpy integer array with rows=number of samples, and columns=length of the
bit-representation of the number of samples.
"""
m = 2**num_classes # Number of unique labels given the number of classes.
bit_format = f'0{num_classes}b' # Bit representation format with leading 0s.
return np.asarray(
[list(format(i % m, bit_format)) for i in range(num_samples)]).astype(int)
def get_multilabel_test_input(n_train, n_test):
"""Get example multilabel inputs for attacks."""
rng = np.random.RandomState(4)
num_classes = max(n_train // 20, 5) # use at least 5 classes.
return AttackInputData(
logits_train=rng.randn(n_train, num_classes) + 0.2,
logits_test=rng.randn(n_test, num_classes) + 0.2,
labels_train=get_multihot_labels_for_test(n_train, num_classes),
labels_test=get_multihot_labels_for_test(n_test, num_classes))
def get_test_input_logits_only(n_train, n_test):
"""Get example input logits for attacks."""
rng = np.random.RandomState(4)
@ -172,5 +208,60 @@ class RunAttacksTest(absltest.TestCase):
DataSize(ntrain=20, ntest=16))
class RunAttacksTestOnMultilabelData(absltest.TestCase):
def test_run_attacks_size(self):
result = mia.run_attacks(
get_multilabel_test_input(100, 100), SlicingSpec(),
(AttackType.LOGISTIC_REGRESSION,))
self.assertLen(result.single_attack_results, 1)
def test_run_attack_trained_sets_attack_type(self):
result = mia._run_attack(
get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
def test_run_attack_threshold_sets_attack_type(self):
result = mia._run_attack(
get_multilabel_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
def test_run_attack_threshold_entropy_fails(self):
self.assertRaises(NotImplementedError, mia._run_threshold_entropy_attack,
get_multilabel_test_input(100, 100))
def test_run_attack_by_percentiles_slice(self):
result = mia.run_attacks(
get_multilabel_test_input(100, 100),
SlicingSpec(entire_dataset=True, by_class=False, by_percentiles=True),
(AttackType.THRESHOLD_ATTACK,))
# 1 attack on entire dataset, 1 attack each of 10 percentile ranges, total
# of 11.
self.assertLen(result.single_attack_results, 11)
expected_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (20, 30))
# First slice (Slice #0) is entire dataset. Hence Slice #3 is the 3rd
# percentile range 20-30.
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
def test_numpy_multilabel_accuracy(self):
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
# At a threshold=0.5, 5 of the total 9 lables are correct.
self.assertAlmostEqual(
mia._get_numpy_binary_accuracy(predictions, labels), 5 / 9, places=6)
def test_multilabel_accuracy(self):
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
# At a threshold=0.5, 5 of the total 9 lables are correct.
self.assertAlmostEqual(
mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6)
self.assertIsNone(mia._get_accuracy(None, labels))
if __name__ == '__main__':
absltest.main()

View file

@ -47,6 +47,25 @@ class TrainedAttackerTest(absltest.TestCase):
self.assertLen(attacker_data.fold_indices, 5)
self.assertEmpty(attacker_data.left_out_indices)
def test_multilabel_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
logits_test=np.array([[10, 11], [14, 15]]),
labels_train=np.array([[0, 1], [1, 1], [1, 0]]),
labels_test=np.array([[1, 0], [1, 1]]),
loss_train=np.array([[1, 3], [6, 7], [8, 9]]),
loss_test=np.array([[4, 2], [4, 6]]))
attacker_data = models.create_attacker_data(attack_input, balance=False)
self.assertLen(attacker_data.features_all, 5)
self.assertLen(attacker_data.fold_indices, 5)
self.assertEmpty(attacker_data.left_out_indices)
self.assertEqual(
attacker_data.features_all.shape[1],
attack_input.logits_train.shape[1] + attack_input.loss_train.shape[1])
self.assertTrue(
attack_input.is_multilabel_data(),
msg='Expected multilabel check to pass.')
def test_unbalanced_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),

View file

@ -13,6 +13,7 @@
# limitations under the License.
"""Utility functions for membership inference attacks."""
import logging
import numpy as np
from scipy import special
@ -76,3 +77,48 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
the squared loss of each sample.
"""
return (y_true - y_pred)**2
def multilabel_bce_loss(labels: np.ndarray,
pred: np.ndarray,
from_logits=False,
small_value=1e-8) -> np.ndarray:
"""Computes the per-multi-label-example cross entropy loss.
Args:
labels: numpy array of shape (num_samples, num_classes). labels[i] is the
true multi-hot encoded label (vector) of the i-th sample and each element
of the vector is one of {0, 1}.
pred: numpy array of shape (num_samples, num_classes). pred[i] is the
logits or probability vector of the i-th sample.
from_logits: whether `pred` is logits or probability vector.
small_value: a scalar. np.log can become -inf if the probability is too
close to 0, so the probability is clipped below by small_value.
Returns:
the cross-entropy loss of each sample for each class.
"""
# Check arrays.
if labels.shape[0] != pred.shape[0]:
raise ValueError('labels and pred should have the same number of examples,',
f'but got {labels.shape[0]} and {pred.shape[0]}.')
if not ((labels == 0) | (labels == 1)).all():
raise ValueError(
'labels should be in {0, 1}. For multi-label classification the labels '
'should be multihot encoded.')
# Check if labels vectors are multi-label.
summed_labels = np.sum(labels, axis=1)
if ((summed_labels == 0) | (summed_labels == 1)).all():
logging.info(
('Labels are one-hot encoded single label. Every sample has at most one'
' positive label.'))
if not from_logits and ((pred < 0.0) | (pred > 1.0)).any():
raise ValueError(('Prediction probabilities are not in [0, 1] and '
'`from_logits` is set to False.'))
# Multi-class multi-label binary cross entropy loss
if from_logits:
pred = special.expit(pred)
bce = labels * np.log(pred + small_value)
bce += (1 - labels) * np.log(1 - pred + small_value)
return -bce

View file

@ -124,5 +124,62 @@ class TestSquaredLoss(parameterized.TestCase):
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
class TestMultilabelBCELoss(parameterized.TestCase):
@parameterized.named_parameters(
('probs_example1', np.array(
[[0, 1, 1], [1, 1, 0]]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
np.array([[0.22314343, 1.20397247, 0.3566748],
[0.22314343, 0.51082546, 2.30258409]]), False),
('probs_example2', np.array([[0, 1, 0], [1, 1, 0]]),
np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
np.array([[0.01005033, 3.91202251, 0.04082198],
[0.22314354, 0.35667493, 2.30258499]]), False),
('logits_example1', np.array([[0, 1, 1], [1, 1, 0]]),
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
np.array([[0.26328245, 0.85435522, 0.11551951],
[0.69314716, 0.47407697, 1.70141322]]), True),
('logits_example2', np.array([[0, 1, 0], [1, 1, 0]]),
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
np.array([[0.26328245, 0.85435522, 2.21551943],
[0.69314716, 0.47407697, 1.70141322]]), True),
)
def test_multilabel_bce_loss(self, label, pred, expected_losses, from_logits):
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
np.testing.assert_allclose(loss, expected_losses, atol=1e-6)
@parameterized.named_parameters(
('from_logits_true_and_incorrect_values_example1',
np.array([[0, 1, 1], [1, 1, 0]
]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
np.array([[0.22314343, 1.20397247, 0.3566748],
[0.22314343, 0.51082546, 2.30258409]]), True),
('from_logits_true_and_incorrect_values_example2',
np.array([[0, 1, 0], [1, 1, 0]
]), np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
np.array([[0.01005033, 3.91202251, 0.04082198],
[0.22314354, 0.35667493, 2.30258499]]), True),
)
def test_multilabel_bce_loss_incorrect_value(self, label, pred,
expected_losses, from_logits):
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
self.assertFalse(np.allclose(loss, expected_losses))
@parameterized.named_parameters(
('from_logits_false_and_pred_not_in_0to1_example1',
np.array([[0, 1, 1], [1, 1, 0]
]), np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
(r'Prediction probabilities are not in \[0, 1\] and '
'`from_logits` is set to False.')),
('labels_not_0_or_1', np.array([[0, 1, 0], [1, 2, 0]]),
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
('labels should be in {0, 1}. For multi-label classification the labels '
'should be multihot encoded.')),
)
def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex):
self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label,
pred, from_logits)
if __name__ == '__main__':
absltest.main()