From ee35642b90e65a38f9d393aa8146380f1c35780d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 20 Apr 2022 13:22:55 -0700 Subject: [PATCH] Add multi-label support for Tensorflow Privacy membership attacks. PiperOrigin-RevId: 443176652 --- .../data_structures.py | 216 +++++++++++++++--- .../data_structures_test.py | 81 +++++++ .../dataset_slicing.py | 32 +++ .../dataset_slicing_test.py | 82 ++++++- .../membership_inference_attack.py | 88 ++++++- .../membership_inference_attack_test.py | 91 ++++++++ .../models_test.py | 19 ++ .../membership_inference_attack/utils.py | 46 ++++ .../membership_inference_attack/utils_test.py | 57 +++++ 9 files changed, 676 insertions(+), 36 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index dda0d79..36812d5 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -17,6 +17,7 @@ import collections import dataclasses import enum import glob +import logging import os import pickle from typing import Any, Callable, Iterable, MutableSequence, Optional, Union @@ -115,6 +116,7 @@ class AttackType(enum.Enum): K_NEAREST_NEIGHBORS = 'knn' THRESHOLD_ATTACK = 'threshold' THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy' + TF_LOGISTIC_REGRESSION = 'tf_lr' @property def is_trained_attack(self): @@ -154,6 +156,19 @@ def _is_array_one_dimensional(arr, arr_name): raise ValueError('%s should be a one dimensional numpy array.' % arr_name) +def _is_array_two_dimensional(arr, arr_name): + """Checks whether the array is two dimensional.""" + if arr is not None and len(arr.shape) != 2: + raise ValueError('%s should be a two dimensional numpy array.' % arr_name) + + +def _is_array_one_or_two_dimensional(arr, arr_name): + """Checks whether the array is one or two dimensional.""" + if arr is not None and len(arr.shape) not in [1, 2]: + raise ValueError( + ('%s should be a one or two dimensional numpy array.' % arr_name)) + + def _is_np_array(arr, arr_name): """Checks whether array is a numpy array.""" if arr is not None and not isinstance(arr, np.ndarray): @@ -213,6 +228,14 @@ class AttackInputData: # preferred when both are available. loss_function_using_logits: Optional[bool] = None + # If the problem is a multilabel classification problem. If this is set then + # the loss function and attack data construction are adjusted accordingly. In + # this case the provided labels must be multi-hot encoded. That is, the labels + # are an array of shape (num_examples, num_classes) with 0s where the + # corresponding class is absent from the example, and 1s where the + # corresponding class is present. 
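+  # For example (an illustrative 4-class case), the multi-hot labels
+  #   [[1, 0, 1, 0],
+  #    [0, 1, 0, 0]]
+  # mark the first example as belonging to classes 0 and 2, and the second
+  # example as belonging to class 1 only.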
+ multilabel_data: Optional[bool] = None + @property def num_classes(self): if self.labels_train is None or self.labels_test is None: @@ -266,12 +289,12 @@ class AttackInputData: return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1) @staticmethod - def _get_loss( - loss: Optional[np.ndarray], labels: Optional[np.ndarray], - logits: Optional[np.ndarray], probs: Optional[np.ndarray], - loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], - LossFunction], - loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]: + def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], + logits: Optional[np.ndarray], probs: Optional[np.ndarray], + loss_function: Union[Callable[[np.ndarray, np.ndarray], + np.ndarray], LossFunction], + loss_function_using_logits: Optional[bool], + multilabel_data: Optional[bool]) -> Optional[np.ndarray]: """Calculates (if needed) losses. Args: @@ -284,6 +307,7 @@ class AttackInputData: is supposed to take in (label, logits / probs) as input. loss_function_using_logits: if `loss_function` expects `logits` or `probs`. + multilabel_data: if the data is from a multilabel classification problem. Returns: Loss (or None if neither the loss nor the labels are present). @@ -299,13 +323,23 @@ class AttackInputData: predictions = logits if loss_function_using_logits else probs if loss_function == LossFunction.CROSS_ENTROPY: - loss = utils.log_loss(labels, predictions, loss_function_using_logits) + if multilabel_data: + loss = utils.multilabel_bce_loss(labels, predictions, + loss_function_using_logits) + else: + loss = utils.log_loss(labels, predictions, loss_function_using_logits) elif loss_function == LossFunction.SQUARED: loss = utils.squared_loss(labels, predictions) else: loss = loss_function(labels, predictions) return loss + def __post_init__(self): + """Checks performed after instantiation of the AttackInputData dataclass.""" + # Check if the data is multilabel. + _ = self.is_multilabel_data() + # The validate() check is called as needed, so is not called here. + def get_loss_train(self): """Calculates (if needed) cross-entropy losses for the training set. @@ -316,7 +350,7 @@ class AttackInputData: self.loss_function_using_logits = (self.logits_train is not None) return self._get_loss(self.loss_train, self.labels_train, self.logits_train, self.probs_train, self.loss_function, - self.loss_function_using_logits) + self.loss_function_using_logits, self.multilabel_data) def get_loss_test(self): """Calculates (if needed) cross-entropy losses for the test set. @@ -328,35 +362,143 @@ class AttackInputData: self.loss_function_using_logits = bool(self.logits_test) return self._get_loss(self.loss_test, self.labels_test, self.logits_test, self.probs_test, self.loss_function, - self.loss_function_using_logits) + self.loss_function_using_logits, self.multilabel_data) def get_entropy_train(self): """Calculates prediction entropy for the training set.""" + if self.is_multilabel_data(): + # Not implemented for multilabel data. + raise NotImplementedError('Computation of the entropy is not ' + 'applicable for multi-label data.') if self.entropy_train is not None: return self.entropy_train return self._get_entropy(self.logits_train, self.labels_train) def get_entropy_test(self): """Calculates prediction entropy for the test set.""" + if self.is_multilabel_data(): + # Not implemented for multilabel data. 
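+      # (The entropy computation below treats the predictions as a single
+      # softmax distribution over mutually exclusive classes; with per-class
+      # binary labels there is no such distribution to take the entropy of.)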
+ raise NotImplementedError('Computation of the entropy is not ' + 'applicable for multi-label data.') if self.entropy_test is not None: return self.entropy_test return self._get_entropy(self.logits_test, self.labels_test) - def get_train_size(self): - """Returns size of the training set.""" + def get_train_shape(self): + """Returns the shape of the training set.""" if self.loss_train is not None: - return self.loss_train.size + return self.loss_train.shape if self.entropy_train is not None: - return self.entropy_train.size - return self.logits_or_probs_train.shape[0] + return self.entropy_train.shape + return self.logits_or_probs_train.shape + + def get_test_shape(self): + """Returns the shape of the test set.""" + if self.loss_test is not None: + return self.loss_test.shape + if self.entropy_test is not None: + return self.entropy_test.shape + return self.logits_or_probs_test.shape + + def get_train_size(self): + """Returns the number of examples of the training set.""" + return self.get_train_shape()[0] def get_test_size(self): - """Returns size of the test set.""" - if self.loss_test is not None: - return self.loss_test.size - if self.entropy_test is not None: - return self.entropy_test.size - return self.logits_or_probs_test.shape[0] + """Returns the number of examples of the test set.""" + return self.get_test_shape()[0] + + def is_multihot_labels(self, arr, arr_name) -> bool: + """Check if the 2D array is multihot, with values in [0, 1]. + + Array is multihot if the sum along the classes axis (axis=1) produces a + vector of at least one value > 1. + + Args: + arr: Array to test. + arr_name: Name of the array. + + Returns: + True if the array is a 2D multihot array. + + Raises: + ValueError if the array is not 2D. + """ + if arr is None: + return False + elif len(arr.shape) > 2: + raise ValueError(f'Array {arr_name} is not 2D, cannot determine whether ' + 'it is multihot.') + elif len(arr.shape) == 1: + return False + summed_arr = np.sum(arr, axis=1) + return not ((summed_arr == 0) | (summed_arr == 1)).all() + + def is_multilabel_data(self) -> bool: + """Check if the provided data is for a multilabel classification problem. + + Data is multilabel if all of the following are true: + 1. Train and test sizes are 2-dimensional: (num_samples, num_classes) + 2. Label size is 2-dimentionsl: (num_samples, num_classes) + 3. The labels are multi-hot. The labels are multihot if sum{label_tensor} + along axis=1 (the classes axis) yields a vector of at least one + value > 1. + + Returns: + Whether the provided data is multilabel. + + Raises: + ValueError if the dimensionality of the train and test data are not equal. + """ + # If the data has already been checked for multihot encoded labels, then + # return the result of the evaluation. 
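As an aside, a minimal self-contained sketch of the multi-hot test implemented by is_multihot_labels above (plain NumPy, hypothetical arrays):

import numpy as np

# Labels count as multi-hot when at least one row has more than one positive
# class; summing over the class axis (axis=1) makes this a one-line check.
one_hot = np.array([[1, 0, 0], [0, 1, 0]])
multi_hot = np.array([[1, 0, 1], [0, 1, 0]])

def looks_multihot(labels):
  sums = labels.sum(axis=1)
  return not np.isin(sums, (0, 1)).all()

print(looks_multihot(one_hot), looks_multihot(multi_hot))  # False True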
+ if self.multilabel_data is not None: + return self.multilabel_data + + # If one of probs or logits are not provided, or labels are not provided, + # this is not a multilabel problem + if (self.logits_or_probs_train is None or + self.logits_or_probs_test is None or self.labels_train is None or + self.labels_test is None): + self.multilabel_data = False + return self.multilabel_data + + train_shape = self.get_train_shape() + test_shape = self.get_test_shape() + label_train_shape = self.labels_train.shape + label_test_shape = self.labels_test.shape + + if len(train_shape) != len(test_shape): + raise ValueError('The number of dimensions of the train data ' + f'({train_shape}) is not the same as that of the test ' + f'data ({test_shape}).') + if len(train_shape) not in [1, 2] or len(test_shape) not in [1, 2]: + raise ValueError(('Train and test data shapes must be 1-D ' + '(number of samples) or 2-D (number of samples, ' + 'number of classes).')) + if len(label_train_shape) != len(label_test_shape): + raise ValueError('The number of dimensions of the train labels ' + f'({label_train_shape}) is not the same as that of the ' + f'test labels ({label_test_shape}).') + if (len(label_train_shape) not in [1, 2] or + len(label_test_shape) not in [1, 2]): + raise ValueError('Train and test labels shapes must be 1-D ' + '(number of samples) or 2-D (number of samples, ' + 'number of classes).') + data_is_2d = len(train_shape) == len(test_shape) == 2 + if data_is_2d: + equal_feature_count = train_shape[1] == test_shape[1] + else: + equal_feature_count = False + labels_are_2d = len(label_train_shape) == len(label_test_shape) == 2 + labels_train_are_multihot = self.is_multihot_labels(self.labels_train, + 'labels_train') + labels_test_are_multihot = self.is_multihot_labels(self.labels_test, + 'labels_test') + self.multilabel_data = ( + data_is_2d and labels_are_2d and equal_feature_count and + labels_train_are_multihot and labels_test_are_multihot) + return self.multilabel_data def validate(self): """Validates the inputs.""" @@ -413,12 +555,30 @@ class AttackInputData: 'logits_test') _is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test, 'probs_test') - _is_array_one_dimensional(self.loss_train, 'loss_train') - _is_array_one_dimensional(self.loss_test, 'loss_test') - _is_array_one_dimensional(self.entropy_train, 'entropy_train') - _is_array_one_dimensional(self.entropy_test, 'entropy_test') - _is_array_one_dimensional(self.labels_train, 'labels_train') - _is_array_one_dimensional(self.labels_test, 'labels_test') + + if self.is_multilabel_data(): + # Validation for multi-label data. + # Verify that both logits and probabilities have the same number of + # classes. + _is_last_dim_equal(self.logits_train, 'logits_train', self.probs_train, + 'probs_train') + _is_last_dim_equal(self.logits_test, 'logits_test', self.probs_test, + 'probs_test') + # Check that losses, labels and entropies are 2D + # (num_samples, num_classes). 
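A usage sketch of the detection logic above, mirroring the shapes exercised by the tests added later in this patch (a minimal sketch, not itself part of the patch):

import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# 2-D logits plus multi-hot labels on both splits => detected as multilabel.
multilabel = AttackInputData(
    logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3.0]]),
    logits_test=np.array([[0.01, 1.5], [0.5, -3.0]]),
    labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
    labels_test=np.array([[1, 1], [1, 0]]))
assert multilabel.is_multilabel_data()

# 1-D integer class labels => the single-label path is kept.
single_label = AttackInputData(
    logits_train=np.array([[0.2, 0.8], [0.9, 0.1]]),
    logits_test=np.array([[0.6, 0.4]]),
    labels_train=np.array([1, 0]),
    labels_test=np.array([0]))
assert not single_label.is_multilabel_data()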
+ _is_array_two_dimensional(self.loss_train, 'loss_train') + _is_array_two_dimensional(self.loss_test, 'loss_test') + _is_array_two_dimensional(self.entropy_train, 'entropy_train') + _is_array_two_dimensional(self.entropy_test, 'entropy_test') + _is_array_two_dimensional(self.labels_train, 'labels_train') + _is_array_two_dimensional(self.labels_test, 'labels_test') + else: + _is_array_one_dimensional(self.loss_train, 'loss_train') + _is_array_one_dimensional(self.loss_test, 'loss_test') + _is_array_one_dimensional(self.entropy_train, 'entropy_train') + _is_array_one_dimensional(self.entropy_test, 'entropy_test') + _is_array_one_dimensional(self.labels_train, 'labels_train') + _is_array_one_dimensional(self.labels_test, 'labels_test') def __str__(self): """Return the shapes of variables that are not None.""" @@ -784,8 +944,8 @@ class AttackResults: aucs = [result.get_auc() for result in self.single_attack_results] if min(aucs) < 0.4: - print('Suspiciously low AUC detected: %.2f. ' + - 'There might be a bug in the classifier' % min(aucs)) + logging.info(('Suspiciously low AUC detected: %.2f. ' + 'There might be a bug in the classifier'), min(aucs)) return self.single_attack_results[np.argmax(aucs)] diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index 7dc06e0..80d7a74 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -277,6 +277,87 @@ class AttackInputDataTest(parameterized.TestCase): probs_train=np.array([]), probs_test=np.array([])).validate) + def test_multilabel_validator(self): + # Tests for multilabel data. 
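For orientation, the expected values in the BCE-from-probabilities test further below come directly from the per-class binary cross entropy; a quick check of the first training example (label [0, 1, 1], probabilities [0.2, 0.3, 0.7]):

import numpy as np

print(-np.log(1 - 0.2))  # ~0.22314, class absent with probability 0.2
print(-np.log(0.3))      # ~1.20397, class present with probability 0.3
print(-np.log(0.7))      # ~0.35667, class present with probability 0.7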
+ with self.assertRaises( + ValueError, + msg='Validation passes incorrectly when `logits_test` is not 1D/2D.'): + AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[[0.01, 1.5], [0.5, -3]], + [[0.01, 1.5], [0.5, -3]]]), + labels_train=np.array([[0, 0], [0, 1], [1, 0]]), + labels_test=np.array([[1, 1], [1, 0]]), + ).validate() + self.assertTrue( + AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[0.01, 1.5], [0.5, -3]]), + labels_train=np.array([[0, 0], [0, 1], [1, 1]]), + labels_test=np.array([[1, 1], [1, 0]]), + ).is_multilabel_data(), + msg='Multilabel data check fails even though conditions are met.') + + def test_multihot_labels_check_on_null_array_returns_false(self): + self.assertFalse( + AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[0.01, 1.5], [0.5, -3]]), + labels_train=np.array([[0, 0], [0, 1], [1, 1]]), + labels_test=np.array([[1, 1], [1, 0]]), + ).is_multihot_labels(None, 'null_array'), + msg='Multilabel test on a null array should return False.') + self.assertFalse( + AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[0.01, 1.5], [0.5, -3]]), + labels_train=np.array([[0, 0], [0, 1], [1, 1]]), + labels_test=np.array([[1, 1], [1, 0]]), + ).is_multihot_labels(np.array([1.0, 2.0, 3.0]), '1d_array'), + msg='Multilabel test on a 1-D array should return False.') + + def test_multilabel_get_bce_loss_from_probs(self): + attack_input = AttackInputData( + probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + probs_test=np.array([[0.8, 0.7, 0.9]]), + labels_train=np.array([[0, 1, 1], [1, 1, 0]]), + labels_test=np.array([[1, 1, 0]])) + + np.testing.assert_allclose( + attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748], + [0.22314343, 0.51082546, 2.30258409]], + atol=1e-6) + np.testing.assert_allclose( + attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]], + atol=1e-6) + + def test_multilabel_get_bce_loss_from_logits(self): + attack_input = AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[0.01, 1.5], [0.5, -3]]), + labels_train=np.array([[0, 0], [0, 1], [1, 1]]), + labels_test=np.array([[1, 1], [1, 0]])) + + np.testing.assert_allclose( + attack_input.get_loss_train(), + [[0.31326167, 0.126928], [0.69815966, 0.20141327], + [0.47407697, 3.04858714]], + atol=1e-6) + np.testing.assert_allclose( + attack_input.get_loss_test(), + [[0.68815966, 0.20141327], [0.47407697, 0.04858734]], + atol=1e-6) + + def test_multilabel_get_loss_explicitly_provided(self): + attack_input = AttackInputData( + loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]), + loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]])) + + np.testing.assert_equal(attack_input.get_loss_train().tolist(), + np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]])) + np.testing.assert_equal(attack_input.get_loss_test().tolist(), + np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]])) + class RocCurveTest(absltest.TestCase): diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index 4129f7b..70747cf 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ 
-15,6 +15,7 @@ import collections import copy +import logging from typing import List, Optional import numpy as np @@ -49,10 +50,21 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.loss_test = _slice_if_not_none(data.loss_test, idx_test) result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test) + # A slice has the same multilabel status as the original data. This is because + # of the way multilabel status is computed. A dataset is multilabel if at + # least 1 sample has a label that is multihot encoded with more than one + # positive class. A slice of this dataset could have only samples with labels + # that have only a single positive class, even if the original dataset were + # multilable. Therefore we ensure that the slice inherits the multilabel state + # of the original dataset. + result.multilabel_data = data.is_multilabel_data() + return result def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData: + if data.is_multilabel_data(): + raise ValueError("Slicing by class not supported for multilabel data.") idx_train = data.labels_train == class_value idx_test = data.labels_test == class_value return _slice_data_by_indices(data, idx_train, idx_test) @@ -65,6 +77,12 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float, # Find from_percentile and to_percentile percentiles in losses. loss_train = data.get_loss_train() loss_test = data.get_loss_test() + if data.is_multilabel_data(): + logging.info("For multilabel data, when slices by percentiles are " + "requested, losses are summed over the class axis before " + "slicing.") + loss_train = np.sum(loss_train, axis=1) + loss_test = np.sum(loss_test, axis=1) losses = np.concatenate((loss_train, loss_test)) from_loss = np.percentile(losses, from_percentile) to_loss = np.percentile(losses, to_percentile) @@ -82,6 +100,20 @@ def _indices_by_classification(logits_or_probs, labels, correctly_classified): def _slice_by_classification_correctness(data: AttackInputData, correctly_classified: bool): + """Slices attack inputs by whether they were classified correctly. + + Args: + data: Data to be used as input to the attack models. + correctly_classified: Whether to use the indices corresponding to the + correctly classified samples. + + Returns: + AttackInputData object containing the sliced data. + """ + + if data.is_multilabel_data(): + raise ValueError("Slicing by classification correctness not supported for " + "multilabel data.") idx_train = _indices_by_classification(data.logits_or_probs_train, data.labels_train, correctly_classified) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py index 2ce8438..3f12908 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
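To motivate the inheritance of multilabel_data noted in _slice_data_by_indices above, a small self-contained illustration (plain NumPy, made-up labels):

import numpy as np

# The full label set is multilabel (row 0 has two positive classes), but a
# slice that happens to keep only rows 1 and 2 contains one-hot rows only, so
# re-deriving the flag from the sliced labels would wrongly report
# single-label data.
labels = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])
slice_rows = labels[[1, 2]]
print((labels.sum(axis=1) > 1).any())      # True
print((slice_rows.sum(axis=1) > 1).any())  # False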
+import logging from absl.testing import absltest +from absl.testing.absltest import mock import numpy as np from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData @@ -71,7 +73,7 @@ class SingleSliceSpecsTest(absltest.TestCase): self.assertTrue(_are_all_fields_equal(output[0], expected0)) self.assertTrue(_are_all_fields_equal(output[5], expected5)) - def test_slice_by_correcness(self): + def test_slice_by_correctness(self): input_data = SlicingSpec( entire_dataset=False, by_classification_correctness=True) expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True) @@ -199,5 +201,83 @@ class GetSliceTest(absltest.TestCase): self.assertTrue((output.labels_test == [1, 2, 0]).all()) +class GetSliceTestForMultilabelData(absltest.TestCase): + + def __init__(self, methodname): + """Initialize the test class.""" + super().__init__(methodname) + + # Create test data for 3 class multilabel classification task. + logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]]) + logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]]) + probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0], + [0.3, 0.7, 0]]) + probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0], + [0, 0, 1]]) + labels_train = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [0, 1, 0]]) + labels_test = np.array([[1, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1]]) + loss_train = np.array([[0, 0, 2], [3, 0, 0.3], [2, 0.5, 0], [1.5, 2, 3]]) + loss_test = np.array([[1.5, 0, 1.0], [0.5, 0.8, 0], [0.5, 3, 0], [0, 0, 0]]) + entropy_train = np.array([[0.2, 0.2, 5], [10, 0, 2], [7, 3, 0], [6, 8, 9]]) + entropy_test = np.array([[3, 0, 2], [3, 6, 0], [2, 6, 0], [0, 0, 0]]) + + self.input_data = AttackInputData( + logits_train=logits_train, + logits_test=logits_test, + probs_train=probs_train, + probs_test=probs_test, + labels_train=labels_train, + labels_test=labels_test, + loss_train=loss_train, + loss_test=loss_test, + entropy_train=entropy_train, + entropy_test=entropy_test) + + def test_slice_entire_dataset(self): + entire_dataset_slice = SingleSliceSpec() + output = get_slice(self.input_data, entire_dataset_slice) + expected = self.input_data + expected.slice_spec = entire_dataset_slice + self.assertTrue(_are_all_fields_equal(output, self.input_data)) + + def test_slice_by_class_fails(self): + class_index = 1 + class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index) + self.assertRaises(ValueError, get_slice, self.input_data, class_slice) + + @mock.patch('logging.Logger.info', wraps=logging.Logger) + def test_slice_by_percentile_logs_multilabel_data(self, mock_logger): + percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) + _ = get_slice(self.input_data, percentile_slice) + mock_logger.assert_called_with( + ('For multilabel data, when slices by percentiles are ' + 'requested, losses are summed over the class axis before ' + 'slicing.')) + + def test_slice_by_percentile(self): + # 50th percentile is the lower 50% of losses summed over the classes. + percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) + output = get_slice(self.input_data, percentile_slice) + + # Check logits. + with self.subTest(msg='Check logits'): + self.assertLen(output.logits_train, 2) + self.assertLen(output.logits_test, 3) + self.assertTrue((output.logits_test[0] == [10, 0, 11]).all()) + + # Check labels. 
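+    # (Why these sizes: the per-example summed train losses are
+    # [2, 3.3, 2.5, 6.5] and the summed test losses are [2.5, 1.3, 3.5, 0];
+    # the 50th percentile over all eight sums is 2.5, so train rows 0 and 2
+    # and test rows 0, 1 and 3 fall in the 0-50 slice.)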
+ with self.subTest(msg='Check labels'): + self.assertLen(output.labels_train, 2) + self.assertLen(output.labels_test, 3) + self.assertTrue((output.labels_train == [[0, 1, 1], [1, 1, 0]]).all()) + self.assertTrue((output.labels_test == [[1, 0, 1], [0, 1, 0], [0, 0, + 1]]).all()) + + def test_slice_by_correctness_fails(self): + percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, + False) + self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice) + + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index 1bc7199..5fe1149 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -17,9 +17,11 @@ This file belongs to the new API for membership inference attacks. This file will be renamed to membership_inference_attack.py after the old API is removed. """ -from typing import Iterable +import logging +from typing import Iterable, List, Union import numpy as np +from scipy import special from sklearn import metrics from sklearn import model_selection @@ -39,6 +41,9 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.datase from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.dataset_slicing import get_slice +ArrayLike = Union[np.ndarray, List] + + def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: if hasattr(data, 'slice_spec'): return data.slice_spec @@ -81,7 +86,8 @@ def _run_trained_attack(attack_input: AttackInputData, attacker = models.create_attacker(attack_type) attacker.train_model(features[train_indices], labels[train_indices]) - scores[test_indices] = attacker.predict(features[test_indices]) + predictions = attacker.predict(features[test_indices]) + scores[test_indices] = predictions # Predict the left out with the last attacker if left_out_indices.size: @@ -110,6 +116,11 @@ def _run_threshold_attack(attack_input: AttackInputData): loss_test = attack_input.get_loss_test() if loss_train is None or loss_test is None: raise ValueError('Not possible to run threshold attack without losses.') + if attack_input.is_multilabel_data(): + logging.info(('For multilabel data, when a threshold attack is requested, ' + 'losses are summed over the class axis before slicing.')) + loss_train = np.sum(loss_train, axis=1) + loss_test = np.sum(loss_test, axis=1) fpr, tpr, thresholds = metrics.roc_curve( np.concatenate((np.zeros(ntrain), np.ones(ntest))), np.concatenate((loss_train, loss_test))) @@ -126,6 +137,10 @@ def _run_threshold_attack(attack_input: AttackInputData): def _run_threshold_entropy_attack(attack_input: AttackInputData): + """Runs threshold entropy attack on single label data.""" + if attack_input.is_multilabel_data(): + raise NotImplementedError(('Entropy-based attacks are not implemented for ' + 'multilabel data.')) ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size() fpr, tpr, thresholds = metrics.roc_curve( np.concatenate((np.zeros(ntrain), np.ones(ntest))), @@ -212,6 +227,7 @@ def run_attacks(attack_input: AttackInputData, for single_slice_spec in input_slice_specs: attack_input_slice = get_slice(attack_input, single_slice_spec) for attack_type in attack_types: + logging.info('Running attack: %s', 
attack_type.name) attack_result = _run_attack(attack_input_slice, attack_type, balance_attacker_training, min_num_samples) if attack_result is not None: @@ -318,15 +334,48 @@ def run_membership_probability_analysis( def _compute_missing_privacy_report_metadata( metadata: PrivacyReportMetadata, attack_input: AttackInputData) -> PrivacyReportMetadata: - """Populates metadata fields if they are missing.""" + """Populates metadata fields if they are missing. + + Args: + metadata: Metadata that is used to create a privacy report based on the + attack results. + attack_input: The input data used to run a membership attack. + + Returns: + A new or updated metadata object containing information to create the + privacy report. + """ + if metadata is None: metadata = PrivacyReportMetadata() + if attack_input.is_multilabel_data(): + accuracy_fn = _get_multilabel_accuracy + sigmoid_func = special.expit + # Multi label accuracy is calculated with the prediction probabilties and + # the labels. A threshold with a default of 0.5 is used to get predicted + # labels from the probabilities. + if (attack_input.probs_train is None and + attack_input.logits_train is not None): + logits_or_probs_train = sigmoid_func(attack_input.logits_train) + else: + logits_or_probs_train = attack_input.probs_train + if (attack_input.probs_test is None and + attack_input.logits_test is not None): + logits_or_probs_test = sigmoid_func(attack_input.logits_test) + else: + logits_or_probs_test = attack_input.probs_test + else: + accuracy_fn = _get_accuracy + # Single label accuracy is calculated with the argmax of the logits which + # is the same as the argmax of the probbilities. + logits_or_probs_train = attack_input.logits_or_probs_train + logits_or_probs_test = attack_input.logits_or_probs_test if metadata.accuracy_train is None: - metadata.accuracy_train = _get_accuracy(attack_input.logits_train, - attack_input.labels_train) + metadata.accuracy_train = accuracy_fn(logits_or_probs_train, + attack_input.labels_train) if metadata.accuracy_test is None: - metadata.accuracy_test = _get_accuracy(attack_input.logits_test, - attack_input.labels_test) + metadata.accuracy_test = accuracy_fn(logits_or_probs_test, + attack_input.labels_test) loss_train = attack_input.get_loss_train() loss_test = attack_input.get_loss_test() if metadata.loss_train is None and loss_train is not None: @@ -341,3 +390,28 @@ def _get_accuracy(logits, labels): if logits is None or labels is None: return None return metrics.accuracy_score(labels, np.argmax(logits, axis=1)) + + +def _get_numpy_binary_accuracy(preds: ArrayLike, labels: ArrayLike): + """Computes the multilabel accuracy at threshold=0.5 using Numpy.""" + return np.mean(np.equal(labels, np.round(preds))) + + +def _get_multilabel_accuracy(preds: ArrayLike, labels: ArrayLike): + """Computes the accuracy over multilabel data if it is missing. + + Compute multilabel binary accuracy. AUC is a better measure of model quality + for multilabel classification than accuracy, in particular when the classes + are imbalanced. For consistency with the single label classification case, + we compute and return the binary accuracy over the labels and predictions. + + Args: + preds: Prediction probabilities. + labels: Ground truth multihot labels. + + Returns: + The binary accuracy averaged across all labels. 
+ """ + if preds is None or labels is None: + return None + return _get_numpy_binary_accuracy(preds, labels) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py index a524b7d..282478b 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py @@ -34,6 +34,42 @@ def get_test_input(n_train, n_test): labels_test=np.array([i % 5 for i in range(n_test)])) +def get_multihot_labels_for_test(num_samples: int, + num_classes: int) -> np.ndarray: + """Generate a array of multihot labels. + + Given an integer 'num_samples', generate a deterministic array of + 'num_classes'multihot labels. Each multihot label is the list of bits (0/1) of + the corresponding row number in the array, upto 'num_classes'. If the value + of num_classes < num_samples, then the bit list repeats. + e.g. if num_samples=10 and num_classes=3, row=3 corresponds to the label + vector [0, 1, 1]. + + Args: + num_samples: Number of samples for which to generate test labels. + num_classes: Number of classes for which to generate test multihot labels. + + Returns: + Numpy integer array with rows=number of samples, and columns=length of the + bit-representation of the number of samples. + """ + m = 2**num_classes # Number of unique labels given the number of classes. + bit_format = f'0{num_classes}b' # Bit representation format with leading 0s. + return np.asarray( + [list(format(i % m, bit_format)) for i in range(num_samples)]).astype(int) + + +def get_multilabel_test_input(n_train, n_test): + """Get example multilabel inputs for attacks.""" + rng = np.random.RandomState(4) + num_classes = max(n_train // 20, 5) # use at least 5 classes. 
+  return AttackInputData(
+      logits_train=rng.randn(n_train, num_classes) + 0.2,
+      logits_test=rng.randn(n_test, num_classes) + 0.2,
+      labels_train=get_multihot_labels_for_test(n_train, num_classes),
+      labels_test=get_multihot_labels_for_test(n_test, num_classes))
+
+
 def get_test_input_logits_only(n_train, n_test):
   """Get example input logits for attacks."""
   rng = np.random.RandomState(4)
@@ -172,5 +208,60 @@ class RunAttacksTest(absltest.TestCase):
                      DataSize(ntrain=20, ntest=16))
 
 
+class RunAttacksTestOnMultilabelData(absltest.TestCase):
+
+  def test_run_attacks_size(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input(100, 100), SlicingSpec(),
+        (AttackType.LOGISTIC_REGRESSION,))
+
+    self.assertLen(result.single_attack_results, 1)
+
+  def test_run_attack_trained_sets_attack_type(self):
+    result = mia._run_attack(
+        get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
+
+    self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
+
+  def test_run_attack_threshold_sets_attack_type(self):
+    result = mia._run_attack(
+        get_multilabel_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
+
+    self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
+
+  def test_run_attack_threshold_entropy_fails(self):
+    self.assertRaises(NotImplementedError, mia._run_threshold_entropy_attack,
+                      get_multilabel_test_input(100, 100))
+
+  def test_run_attack_by_percentiles_slice(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input(100, 100),
+        SlicingSpec(entire_dataset=True, by_class=False, by_percentiles=True),
+        (AttackType.THRESHOLD_ATTACK,))
+
+    # 1 attack on the entire dataset, plus 1 attack for each of the 10
+    # percentile ranges, for a total of 11.
+    self.assertLen(result.single_attack_results, 11)
+    expected_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (20, 30))
+    # The first slice (slice #0) is the entire dataset, so slice #3 is the
+    # third percentile range, 20-30.
+    self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
+
+  def test_numpy_multilabel_accuracy(self):
+    predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
+    labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
+    # At a threshold=0.5, 5 of the total 9 labels are correct.
+    self.assertAlmostEqual(
+        mia._get_numpy_binary_accuracy(predictions, labels), 5 / 9, places=6)
+
+  def test_multilabel_accuracy(self):
+    predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
+    labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
+    # At a threshold=0.5, 5 of the total 9 labels are correct.
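+    # (np.round maps 0.5 to 0 under round-half-to-even, so the thresholded
+    # predictions are [[0, 0, 0], [0, 1, 0], [0, 0, 0]], matching 2 + 2 + 1
+    # of the 9 labels.)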
+ self.assertAlmostEqual( + mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6) + self.assertIsNone(mia._get_accuracy(None, labels)) + + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py index 70355bf..4ace899 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py @@ -47,6 +47,25 @@ class TrainedAttackerTest(absltest.TestCase): self.assertLen(attacker_data.fold_indices, 5) self.assertEmpty(attacker_data.left_out_indices) + def test_multilabel_create_attacker_data_loss_and_logits(self): + attack_input = AttackInputData( + logits_train=np.array([[1, 2], [5, 6], [8, 9]]), + logits_test=np.array([[10, 11], [14, 15]]), + labels_train=np.array([[0, 1], [1, 1], [1, 0]]), + labels_test=np.array([[1, 0], [1, 1]]), + loss_train=np.array([[1, 3], [6, 7], [8, 9]]), + loss_test=np.array([[4, 2], [4, 6]])) + attacker_data = models.create_attacker_data(attack_input, balance=False) + self.assertLen(attacker_data.features_all, 5) + self.assertLen(attacker_data.fold_indices, 5) + self.assertEmpty(attacker_data.left_out_indices) + self.assertEqual( + attacker_data.features_all.shape[1], + attack_input.logits_train.shape[1] + attack_input.loss_train.shape[1]) + self.assertTrue( + attack_input.is_multilabel_data(), + msg='Expected multilabel check to pass.') + def test_unbalanced_create_attacker_data_loss_and_logits(self): attack_input = AttackInputData( logits_train=np.array([[1, 2], [5, 6], [8, 9]]), diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py index 0ee3ddd..9d02d83 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Utility functions for membership inference attacks.""" +import logging import numpy as np from scipy import special @@ -76,3 +77,48 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: the squared loss of each sample. """ return (y_true - y_pred)**2 + + +def multilabel_bce_loss(labels: np.ndarray, + pred: np.ndarray, + from_logits=False, + small_value=1e-8) -> np.ndarray: + """Computes the per-multi-label-example cross entropy loss. + + Args: + labels: numpy array of shape (num_samples, num_classes). labels[i] is the + true multi-hot encoded label (vector) of the i-th sample and each element + of the vector is one of {0, 1}. + pred: numpy array of shape (num_samples, num_classes). pred[i] is the + logits or probability vector of the i-th sample. + from_logits: whether `pred` is logits or probability vector. + small_value: a scalar. np.log can become -inf if the probability is too + close to 0, so the probability is clipped below by small_value. + + Returns: + the cross-entropy loss of each sample for each class. + """ + # Check arrays. + if labels.shape[0] != pred.shape[0]: + raise ValueError('labels and pred should have the same number of examples,', + f'but got {labels.shape[0]} and {pred.shape[0]}.') + if not ((labels == 0) | (labels == 1)).all(): + raise ValueError( + 'labels should be in {0, 1}. 
For multi-label classification the labels ' + 'should be multihot encoded.') + # Check if labels vectors are multi-label. + summed_labels = np.sum(labels, axis=1) + if ((summed_labels == 0) | (summed_labels == 1)).all(): + logging.info( + ('Labels are one-hot encoded single label. Every sample has at most one' + ' positive label.')) + if not from_logits and ((pred < 0.0) | (pred > 1.0)).any(): + raise ValueError(('Prediction probabilities are not in [0, 1] and ' + '`from_logits` is set to False.')) + + # Multi-class multi-label binary cross entropy loss + if from_logits: + pred = special.expit(pred) + bce = labels * np.log(pred + small_value) + bce += (1 - labels) * np.log(1 - pred + small_value) + return -bce diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py index 15f639b..725aeb8 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py @@ -124,5 +124,62 @@ class TestSquaredLoss(parameterized.TestCase): np.testing.assert_allclose(loss, expected_loss, atol=1e-7) +class TestMultilabelBCELoss(parameterized.TestCase): + + @parameterized.named_parameters( + ('probs_example1', np.array( + [[0, 1, 1], [1, 1, 0]]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + np.array([[0.22314343, 1.20397247, 0.3566748], + [0.22314343, 0.51082546, 2.30258409]]), False), + ('probs_example2', np.array([[0, 1, 0], [1, 1, 0]]), + np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]), + np.array([[0.01005033, 3.91202251, 0.04082198], + [0.22314354, 0.35667493, 2.30258499]]), False), + ('logits_example1', np.array([[0, 1, 1], [1, 1, 0]]), + np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), + np.array([[0.26328245, 0.85435522, 0.11551951], + [0.69314716, 0.47407697, 1.70141322]]), True), + ('logits_example2', np.array([[0, 1, 0], [1, 1, 0]]), + np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), + np.array([[0.26328245, 0.85435522, 2.21551943], + [0.69314716, 0.47407697, 1.70141322]]), True), + ) + def test_multilabel_bce_loss(self, label, pred, expected_losses, from_logits): + loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits) + np.testing.assert_allclose(loss, expected_losses, atol=1e-6) + + @parameterized.named_parameters( + ('from_logits_true_and_incorrect_values_example1', + np.array([[0, 1, 1], [1, 1, 0] + ]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + np.array([[0.22314343, 1.20397247, 0.3566748], + [0.22314343, 0.51082546, 2.30258409]]), True), + ('from_logits_true_and_incorrect_values_example2', + np.array([[0, 1, 0], [1, 1, 0] + ]), np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]), + np.array([[0.01005033, 3.91202251, 0.04082198], + [0.22314354, 0.35667493, 2.30258499]]), True), + ) + def test_multilabel_bce_loss_incorrect_value(self, label, pred, + expected_losses, from_logits): + loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits) + self.assertFalse(np.allclose(loss, expected_losses)) + + @parameterized.named_parameters( + ('from_logits_false_and_pred_not_in_0to1_example1', + np.array([[0, 1, 1], [1, 1, 0] + ]), np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False, + (r'Prediction probabilities are not in \[0, 1\] and ' + '`from_logits` is set to False.')), + ('labels_not_0_or_1', np.array([[0, 1, 0], [1, 2, 0]]), + np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False, + ('labels should be in {0, 1}. 
For multi-label classification the labels ' + 'should be multihot encoded.')), + ) + def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex): + self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label, + pred, from_logits) + + if __name__ == '__main__': absltest.main()
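For reference, the expected values in the logits-based cases above follow from applying the sigmoid before the binary cross entropy; a quick check of two entries of 'logits_example1' (label 0 with logit -1.2, and label 1 with logit 2.1):

import numpy as np
from scipy import special

p = special.expit(-1.2)             # ~0.23148
print(-np.log(1 - p))               # ~0.26328, matches expected_losses[0][0]
print(-np.log(special.expit(2.1)))  # ~0.11552, matches expected_losses[0][2]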