forked from 626_privacy/tensorflow_privacy
Add multi-label support for Tensorflow Privacy membership attacks.
PiperOrigin-RevId: 443176652
This commit is contained in:
parent
e14618fe7c
commit
ee35642b90
9 changed files with 676 additions and 36 deletions
|
@ -17,6 +17,7 @@ import collections
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import glob
|
import glob
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
|
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
|
||||||
|
@ -115,6 +116,7 @@ class AttackType(enum.Enum):
|
||||||
K_NEAREST_NEIGHBORS = 'knn'
|
K_NEAREST_NEIGHBORS = 'knn'
|
||||||
THRESHOLD_ATTACK = 'threshold'
|
THRESHOLD_ATTACK = 'threshold'
|
||||||
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
|
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
|
||||||
|
TF_LOGISTIC_REGRESSION = 'tf_lr'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_trained_attack(self):
|
def is_trained_attack(self):
|
||||||
|
@ -154,6 +156,19 @@ def _is_array_one_dimensional(arr, arr_name):
|
||||||
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
|
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_array_two_dimensional(arr, arr_name):
|
||||||
|
"""Checks whether the array is two dimensional."""
|
||||||
|
if arr is not None and len(arr.shape) != 2:
|
||||||
|
raise ValueError('%s should be a two dimensional numpy array.' % arr_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_array_one_or_two_dimensional(arr, arr_name):
|
||||||
|
"""Checks whether the array is one or two dimensional."""
|
||||||
|
if arr is not None and len(arr.shape) not in [1, 2]:
|
||||||
|
raise ValueError(
|
||||||
|
('%s should be a one or two dimensional numpy array.' % arr_name))
|
||||||
|
|
||||||
|
|
||||||
def _is_np_array(arr, arr_name):
|
def _is_np_array(arr, arr_name):
|
||||||
"""Checks whether array is a numpy array."""
|
"""Checks whether array is a numpy array."""
|
||||||
if arr is not None and not isinstance(arr, np.ndarray):
|
if arr is not None and not isinstance(arr, np.ndarray):
|
||||||
|
@ -213,6 +228,14 @@ class AttackInputData:
|
||||||
# preferred when both are available.
|
# preferred when both are available.
|
||||||
loss_function_using_logits: Optional[bool] = None
|
loss_function_using_logits: Optional[bool] = None
|
||||||
|
|
||||||
|
# If the problem is a multilabel classification problem. If this is set then
|
||||||
|
# the loss function and attack data construction are adjusted accordingly. In
|
||||||
|
# this case the provided labels must be multi-hot encoded. That is, the labels
|
||||||
|
# are an array of shape (num_examples, num_classes) with 0s where the
|
||||||
|
# corresponding class is absent from the example, and 1s where the
|
||||||
|
# corresponding class is present.
|
||||||
|
multilabel_data: Optional[bool] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
if self.labels_train is None or self.labels_test is None:
|
if self.labels_train is None or self.labels_test is None:
|
||||||
|
@ -266,12 +289,12 @@ class AttackInputData:
|
||||||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_loss(
|
def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||||
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
|
||||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
||||||
LossFunction],
|
np.ndarray], LossFunction],
|
||||||
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
|
loss_function_using_logits: Optional[bool],
|
||||||
|
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
||||||
"""Calculates (if needed) losses.
|
"""Calculates (if needed) losses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -284,6 +307,7 @@ class AttackInputData:
|
||||||
is supposed to take in (label, logits / probs) as input.
|
is supposed to take in (label, logits / probs) as input.
|
||||||
loss_function_using_logits: if `loss_function` expects `logits` or
|
loss_function_using_logits: if `loss_function` expects `logits` or
|
||||||
`probs`.
|
`probs`.
|
||||||
|
multilabel_data: if the data is from a multilabel classification problem.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loss (or None if neither the loss nor the labels are present).
|
Loss (or None if neither the loss nor the labels are present).
|
||||||
|
@ -299,6 +323,10 @@ class AttackInputData:
|
||||||
|
|
||||||
predictions = logits if loss_function_using_logits else probs
|
predictions = logits if loss_function_using_logits else probs
|
||||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||||
|
if multilabel_data:
|
||||||
|
loss = utils.multilabel_bce_loss(labels, predictions,
|
||||||
|
loss_function_using_logits)
|
||||||
|
else:
|
||||||
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
||||||
elif loss_function == LossFunction.SQUARED:
|
elif loss_function == LossFunction.SQUARED:
|
||||||
loss = utils.squared_loss(labels, predictions)
|
loss = utils.squared_loss(labels, predictions)
|
||||||
|
@ -306,6 +334,12 @@ class AttackInputData:
|
||||||
loss = loss_function(labels, predictions)
|
loss = loss_function(labels, predictions)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Checks performed after instantiation of the AttackInputData dataclass."""
|
||||||
|
# Check if the data is multilabel.
|
||||||
|
_ = self.is_multilabel_data()
|
||||||
|
# The validate() check is called as needed, so is not called here.
|
||||||
|
|
||||||
def get_loss_train(self):
|
def get_loss_train(self):
|
||||||
"""Calculates (if needed) cross-entropy losses for the training set.
|
"""Calculates (if needed) cross-entropy losses for the training set.
|
||||||
|
|
||||||
|
@ -316,7 +350,7 @@ class AttackInputData:
|
||||||
self.loss_function_using_logits = (self.logits_train is not None)
|
self.loss_function_using_logits = (self.logits_train is not None)
|
||||||
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
|
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
|
||||||
self.probs_train, self.loss_function,
|
self.probs_train, self.loss_function,
|
||||||
self.loss_function_using_logits)
|
self.loss_function_using_logits, self.multilabel_data)
|
||||||
|
|
||||||
def get_loss_test(self):
|
def get_loss_test(self):
|
||||||
"""Calculates (if needed) cross-entropy losses for the test set.
|
"""Calculates (if needed) cross-entropy losses for the test set.
|
||||||
|
@ -328,35 +362,143 @@ class AttackInputData:
|
||||||
self.loss_function_using_logits = bool(self.logits_test)
|
self.loss_function_using_logits = bool(self.logits_test)
|
||||||
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
|
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
|
||||||
self.probs_test, self.loss_function,
|
self.probs_test, self.loss_function,
|
||||||
self.loss_function_using_logits)
|
self.loss_function_using_logits, self.multilabel_data)
|
||||||
|
|
||||||
def get_entropy_train(self):
|
def get_entropy_train(self):
|
||||||
"""Calculates prediction entropy for the training set."""
|
"""Calculates prediction entropy for the training set."""
|
||||||
|
if self.is_multilabel_data():
|
||||||
|
# Not implemented for multilabel data.
|
||||||
|
raise NotImplementedError('Computation of the entropy is not '
|
||||||
|
'applicable for multi-label data.')
|
||||||
if self.entropy_train is not None:
|
if self.entropy_train is not None:
|
||||||
return self.entropy_train
|
return self.entropy_train
|
||||||
return self._get_entropy(self.logits_train, self.labels_train)
|
return self._get_entropy(self.logits_train, self.labels_train)
|
||||||
|
|
||||||
def get_entropy_test(self):
|
def get_entropy_test(self):
|
||||||
"""Calculates prediction entropy for the test set."""
|
"""Calculates prediction entropy for the test set."""
|
||||||
|
if self.is_multilabel_data():
|
||||||
|
# Not implemented for multilabel data.
|
||||||
|
raise NotImplementedError('Computation of the entropy is not '
|
||||||
|
'applicable for multi-label data.')
|
||||||
if self.entropy_test is not None:
|
if self.entropy_test is not None:
|
||||||
return self.entropy_test
|
return self.entropy_test
|
||||||
return self._get_entropy(self.logits_test, self.labels_test)
|
return self._get_entropy(self.logits_test, self.labels_test)
|
||||||
|
|
||||||
def get_train_size(self):
|
def get_train_shape(self):
|
||||||
"""Returns size of the training set."""
|
"""Returns the shape of the training set."""
|
||||||
if self.loss_train is not None:
|
if self.loss_train is not None:
|
||||||
return self.loss_train.size
|
return self.loss_train.shape
|
||||||
if self.entropy_train is not None:
|
if self.entropy_train is not None:
|
||||||
return self.entropy_train.size
|
return self.entropy_train.shape
|
||||||
return self.logits_or_probs_train.shape[0]
|
return self.logits_or_probs_train.shape
|
||||||
|
|
||||||
|
def get_test_shape(self):
|
||||||
|
"""Returns the shape of the test set."""
|
||||||
|
if self.loss_test is not None:
|
||||||
|
return self.loss_test.shape
|
||||||
|
if self.entropy_test is not None:
|
||||||
|
return self.entropy_test.shape
|
||||||
|
return self.logits_or_probs_test.shape
|
||||||
|
|
||||||
|
def get_train_size(self):
|
||||||
|
"""Returns the number of examples of the training set."""
|
||||||
|
return self.get_train_shape()[0]
|
||||||
|
|
||||||
def get_test_size(self):
|
def get_test_size(self):
|
||||||
"""Returns size of the test set."""
|
"""Returns the number of examples of the test set."""
|
||||||
if self.loss_test is not None:
|
return self.get_test_shape()[0]
|
||||||
return self.loss_test.size
|
|
||||||
if self.entropy_test is not None:
|
def is_multihot_labels(self, arr, arr_name) -> bool:
|
||||||
return self.entropy_test.size
|
"""Check if the 2D array is multihot, with values in [0, 1].
|
||||||
return self.logits_or_probs_test.shape[0]
|
|
||||||
|
Array is multihot if the sum along the classes axis (axis=1) produces a
|
||||||
|
vector of at least one value > 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arr: Array to test.
|
||||||
|
arr_name: Name of the array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the array is a 2D multihot array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the array is not 2D.
|
||||||
|
"""
|
||||||
|
if arr is None:
|
||||||
|
return False
|
||||||
|
elif len(arr.shape) > 2:
|
||||||
|
raise ValueError(f'Array {arr_name} is not 2D, cannot determine whether '
|
||||||
|
'it is multihot.')
|
||||||
|
elif len(arr.shape) == 1:
|
||||||
|
return False
|
||||||
|
summed_arr = np.sum(arr, axis=1)
|
||||||
|
return not ((summed_arr == 0) | (summed_arr == 1)).all()
|
||||||
|
|
||||||
|
def is_multilabel_data(self) -> bool:
|
||||||
|
"""Check if the provided data is for a multilabel classification problem.
|
||||||
|
|
||||||
|
Data is multilabel if all of the following are true:
|
||||||
|
1. Train and test sizes are 2-dimensional: (num_samples, num_classes)
|
||||||
|
2. Label size is 2-dimentionsl: (num_samples, num_classes)
|
||||||
|
3. The labels are multi-hot. The labels are multihot if sum{label_tensor}
|
||||||
|
along axis=1 (the classes axis) yields a vector of at least one
|
||||||
|
value > 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the provided data is multilabel.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the dimensionality of the train and test data are not equal.
|
||||||
|
"""
|
||||||
|
# If the data has already been checked for multihot encoded labels, then
|
||||||
|
# return the result of the evaluation.
|
||||||
|
if self.multilabel_data is not None:
|
||||||
|
return self.multilabel_data
|
||||||
|
|
||||||
|
# If one of probs or logits are not provided, or labels are not provided,
|
||||||
|
# this is not a multilabel problem
|
||||||
|
if (self.logits_or_probs_train is None or
|
||||||
|
self.logits_or_probs_test is None or self.labels_train is None or
|
||||||
|
self.labels_test is None):
|
||||||
|
self.multilabel_data = False
|
||||||
|
return self.multilabel_data
|
||||||
|
|
||||||
|
train_shape = self.get_train_shape()
|
||||||
|
test_shape = self.get_test_shape()
|
||||||
|
label_train_shape = self.labels_train.shape
|
||||||
|
label_test_shape = self.labels_test.shape
|
||||||
|
|
||||||
|
if len(train_shape) != len(test_shape):
|
||||||
|
raise ValueError('The number of dimensions of the train data '
|
||||||
|
f'({train_shape}) is not the same as that of the test '
|
||||||
|
f'data ({test_shape}).')
|
||||||
|
if len(train_shape) not in [1, 2] or len(test_shape) not in [1, 2]:
|
||||||
|
raise ValueError(('Train and test data shapes must be 1-D '
|
||||||
|
'(number of samples) or 2-D (number of samples, '
|
||||||
|
'number of classes).'))
|
||||||
|
if len(label_train_shape) != len(label_test_shape):
|
||||||
|
raise ValueError('The number of dimensions of the train labels '
|
||||||
|
f'({label_train_shape}) is not the same as that of the '
|
||||||
|
f'test labels ({label_test_shape}).')
|
||||||
|
if (len(label_train_shape) not in [1, 2] or
|
||||||
|
len(label_test_shape) not in [1, 2]):
|
||||||
|
raise ValueError('Train and test labels shapes must be 1-D '
|
||||||
|
'(number of samples) or 2-D (number of samples, '
|
||||||
|
'number of classes).')
|
||||||
|
data_is_2d = len(train_shape) == len(test_shape) == 2
|
||||||
|
if data_is_2d:
|
||||||
|
equal_feature_count = train_shape[1] == test_shape[1]
|
||||||
|
else:
|
||||||
|
equal_feature_count = False
|
||||||
|
labels_are_2d = len(label_train_shape) == len(label_test_shape) == 2
|
||||||
|
labels_train_are_multihot = self.is_multihot_labels(self.labels_train,
|
||||||
|
'labels_train')
|
||||||
|
labels_test_are_multihot = self.is_multihot_labels(self.labels_test,
|
||||||
|
'labels_test')
|
||||||
|
self.multilabel_data = (
|
||||||
|
data_is_2d and labels_are_2d and equal_feature_count and
|
||||||
|
labels_train_are_multihot and labels_test_are_multihot)
|
||||||
|
return self.multilabel_data
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validates the inputs."""
|
"""Validates the inputs."""
|
||||||
|
@ -413,6 +555,24 @@ class AttackInputData:
|
||||||
'logits_test')
|
'logits_test')
|
||||||
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
|
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
|
||||||
'probs_test')
|
'probs_test')
|
||||||
|
|
||||||
|
if self.is_multilabel_data():
|
||||||
|
# Validation for multi-label data.
|
||||||
|
# Verify that both logits and probabilities have the same number of
|
||||||
|
# classes.
|
||||||
|
_is_last_dim_equal(self.logits_train, 'logits_train', self.probs_train,
|
||||||
|
'probs_train')
|
||||||
|
_is_last_dim_equal(self.logits_test, 'logits_test', self.probs_test,
|
||||||
|
'probs_test')
|
||||||
|
# Check that losses, labels and entropies are 2D
|
||||||
|
# (num_samples, num_classes).
|
||||||
|
_is_array_two_dimensional(self.loss_train, 'loss_train')
|
||||||
|
_is_array_two_dimensional(self.loss_test, 'loss_test')
|
||||||
|
_is_array_two_dimensional(self.entropy_train, 'entropy_train')
|
||||||
|
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
|
||||||
|
_is_array_two_dimensional(self.labels_train, 'labels_train')
|
||||||
|
_is_array_two_dimensional(self.labels_test, 'labels_test')
|
||||||
|
else:
|
||||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||||
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
||||||
|
@ -784,8 +944,8 @@ class AttackResults:
|
||||||
aucs = [result.get_auc() for result in self.single_attack_results]
|
aucs = [result.get_auc() for result in self.single_attack_results]
|
||||||
|
|
||||||
if min(aucs) < 0.4:
|
if min(aucs) < 0.4:
|
||||||
print('Suspiciously low AUC detected: %.2f. ' +
|
logging.info(('Suspiciously low AUC detected: %.2f. '
|
||||||
'There might be a bug in the classifier' % min(aucs))
|
'There might be a bug in the classifier'), min(aucs))
|
||||||
|
|
||||||
return self.single_attack_results[np.argmax(aucs)]
|
return self.single_attack_results[np.argmax(aucs)]
|
||||||
|
|
||||||
|
|
|
@ -277,6 +277,87 @@ class AttackInputDataTest(parameterized.TestCase):
|
||||||
probs_train=np.array([]),
|
probs_train=np.array([]),
|
||||||
probs_test=np.array([])).validate)
|
probs_test=np.array([])).validate)
|
||||||
|
|
||||||
|
def test_multilabel_validator(self):
|
||||||
|
# Tests for multilabel data.
|
||||||
|
with self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
msg='Validation passes incorrectly when `logits_test` is not 1D/2D.'):
|
||||||
|
AttackInputData(
|
||||||
|
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||||
|
logits_test=np.array([[[0.01, 1.5], [0.5, -3]],
|
||||||
|
[[0.01, 1.5], [0.5, -3]]]),
|
||||||
|
labels_train=np.array([[0, 0], [0, 1], [1, 0]]),
|
||||||
|
labels_test=np.array([[1, 1], [1, 0]]),
|
||||||
|
).validate()
|
||||||
|
self.assertTrue(
|
||||||
|
AttackInputData(
|
||||||
|
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||||
|
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||||
|
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||||
|
labels_test=np.array([[1, 1], [1, 0]]),
|
||||||
|
).is_multilabel_data(),
|
||||||
|
msg='Multilabel data check fails even though conditions are met.')
|
||||||
|
|
||||||
|
def test_multihot_labels_check_on_null_array_returns_false(self):
|
||||||
|
self.assertFalse(
|
||||||
|
AttackInputData(
|
||||||
|
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||||
|
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||||
|
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||||
|
labels_test=np.array([[1, 1], [1, 0]]),
|
||||||
|
).is_multihot_labels(None, 'null_array'),
|
||||||
|
msg='Multilabel test on a null array should return False.')
|
||||||
|
self.assertFalse(
|
||||||
|
AttackInputData(
|
||||||
|
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||||
|
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||||
|
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||||
|
labels_test=np.array([[1, 1], [1, 0]]),
|
||||||
|
).is_multihot_labels(np.array([1.0, 2.0, 3.0]), '1d_array'),
|
||||||
|
msg='Multilabel test on a 1-D array should return False.')
|
||||||
|
|
||||||
|
def test_multilabel_get_bce_loss_from_probs(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
probs_test=np.array([[0.8, 0.7, 0.9]]),
|
||||||
|
labels_train=np.array([[0, 1, 1], [1, 1, 0]]),
|
||||||
|
labels_test=np.array([[1, 1, 0]]))
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748],
|
||||||
|
[0.22314343, 0.51082546, 2.30258409]],
|
||||||
|
atol=1e-6)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]],
|
||||||
|
atol=1e-6)
|
||||||
|
|
||||||
|
def test_multilabel_get_bce_loss_from_logits(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||||
|
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||||
|
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||||
|
labels_test=np.array([[1, 1], [1, 0]]))
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
attack_input.get_loss_train(),
|
||||||
|
[[0.31326167, 0.126928], [0.69815966, 0.20141327],
|
||||||
|
[0.47407697, 3.04858714]],
|
||||||
|
atol=1e-6)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
attack_input.get_loss_test(),
|
||||||
|
[[0.68815966, 0.20141327], [0.47407697, 0.04858734]],
|
||||||
|
atol=1e-6)
|
||||||
|
|
||||||
|
def test_multilabel_get_loss_explicitly_provided(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
|
||||||
|
loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
||||||
|
|
||||||
|
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
|
||||||
|
np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]))
|
||||||
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||||
|
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
||||||
|
|
||||||
|
|
||||||
class RocCurveTest(absltest.TestCase):
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -49,10 +50,21 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||||
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
||||||
|
|
||||||
|
# A slice has the same multilabel status as the original data. This is because
|
||||||
|
# of the way multilabel status is computed. A dataset is multilabel if at
|
||||||
|
# least 1 sample has a label that is multihot encoded with more than one
|
||||||
|
# positive class. A slice of this dataset could have only samples with labels
|
||||||
|
# that have only a single positive class, even if the original dataset were
|
||||||
|
# multilable. Therefore we ensure that the slice inherits the multilabel state
|
||||||
|
# of the original dataset.
|
||||||
|
result.multilabel_data = data.is_multilabel_data()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||||
|
if data.is_multilabel_data():
|
||||||
|
raise ValueError("Slicing by class not supported for multilabel data.")
|
||||||
idx_train = data.labels_train == class_value
|
idx_train = data.labels_train == class_value
|
||||||
idx_test = data.labels_test == class_value
|
idx_test = data.labels_test == class_value
|
||||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
@ -65,6 +77,12 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
||||||
# Find from_percentile and to_percentile percentiles in losses.
|
# Find from_percentile and to_percentile percentiles in losses.
|
||||||
loss_train = data.get_loss_train()
|
loss_train = data.get_loss_train()
|
||||||
loss_test = data.get_loss_test()
|
loss_test = data.get_loss_test()
|
||||||
|
if data.is_multilabel_data():
|
||||||
|
logging.info("For multilabel data, when slices by percentiles are "
|
||||||
|
"requested, losses are summed over the class axis before "
|
||||||
|
"slicing.")
|
||||||
|
loss_train = np.sum(loss_train, axis=1)
|
||||||
|
loss_test = np.sum(loss_test, axis=1)
|
||||||
losses = np.concatenate((loss_train, loss_test))
|
losses = np.concatenate((loss_train, loss_test))
|
||||||
from_loss = np.percentile(losses, from_percentile)
|
from_loss = np.percentile(losses, from_percentile)
|
||||||
to_loss = np.percentile(losses, to_percentile)
|
to_loss = np.percentile(losses, to_percentile)
|
||||||
|
@ -82,6 +100,20 @@ def _indices_by_classification(logits_or_probs, labels, correctly_classified):
|
||||||
|
|
||||||
def _slice_by_classification_correctness(data: AttackInputData,
|
def _slice_by_classification_correctness(data: AttackInputData,
|
||||||
correctly_classified: bool):
|
correctly_classified: bool):
|
||||||
|
"""Slices attack inputs by whether they were classified correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Data to be used as input to the attack models.
|
||||||
|
correctly_classified: Whether to use the indices corresponding to the
|
||||||
|
correctly classified samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AttackInputData object containing the sliced data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if data.is_multilabel_data():
|
||||||
|
raise ValueError("Slicing by classification correctness not supported for "
|
||||||
|
"multilabel data.")
|
||||||
idx_train = _indices_by_classification(data.logits_or_probs_train,
|
idx_train = _indices_by_classification(data.logits_or_probs_train,
|
||||||
data.labels_train,
|
data.labels_train,
|
||||||
correctly_classified)
|
correctly_classified)
|
||||||
|
|
|
@ -12,7 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
from absl.testing.absltest import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
@ -71,7 +73,7 @@ class SingleSliceSpecsTest(absltest.TestCase):
|
||||||
self.assertTrue(_are_all_fields_equal(output[0], expected0))
|
self.assertTrue(_are_all_fields_equal(output[0], expected0))
|
||||||
self.assertTrue(_are_all_fields_equal(output[5], expected5))
|
self.assertTrue(_are_all_fields_equal(output[5], expected5))
|
||||||
|
|
||||||
def test_slice_by_correcness(self):
|
def test_slice_by_correctness(self):
|
||||||
input_data = SlicingSpec(
|
input_data = SlicingSpec(
|
||||||
entire_dataset=False, by_classification_correctness=True)
|
entire_dataset=False, by_classification_correctness=True)
|
||||||
expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)
|
expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)
|
||||||
|
@ -199,5 +201,83 @@ class GetSliceTest(absltest.TestCase):
|
||||||
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
||||||
|
|
||||||
|
|
||||||
|
class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, methodname):
|
||||||
|
"""Initialize the test class."""
|
||||||
|
super().__init__(methodname)
|
||||||
|
|
||||||
|
# Create test data for 3 class multilabel classification task.
|
||||||
|
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
|
||||||
|
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
|
||||||
|
probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0],
|
||||||
|
[0.3, 0.7, 0]])
|
||||||
|
probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0],
|
||||||
|
[0, 0, 1]])
|
||||||
|
labels_train = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [0, 1, 0]])
|
||||||
|
labels_test = np.array([[1, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
|
||||||
|
loss_train = np.array([[0, 0, 2], [3, 0, 0.3], [2, 0.5, 0], [1.5, 2, 3]])
|
||||||
|
loss_test = np.array([[1.5, 0, 1.0], [0.5, 0.8, 0], [0.5, 3, 0], [0, 0, 0]])
|
||||||
|
entropy_train = np.array([[0.2, 0.2, 5], [10, 0, 2], [7, 3, 0], [6, 8, 9]])
|
||||||
|
entropy_test = np.array([[3, 0, 2], [3, 6, 0], [2, 6, 0], [0, 0, 0]])
|
||||||
|
|
||||||
|
self.input_data = AttackInputData(
|
||||||
|
logits_train=logits_train,
|
||||||
|
logits_test=logits_test,
|
||||||
|
probs_train=probs_train,
|
||||||
|
probs_test=probs_test,
|
||||||
|
labels_train=labels_train,
|
||||||
|
labels_test=labels_test,
|
||||||
|
loss_train=loss_train,
|
||||||
|
loss_test=loss_test,
|
||||||
|
entropy_train=entropy_train,
|
||||||
|
entropy_test=entropy_test)
|
||||||
|
|
||||||
|
def test_slice_entire_dataset(self):
|
||||||
|
entire_dataset_slice = SingleSliceSpec()
|
||||||
|
output = get_slice(self.input_data, entire_dataset_slice)
|
||||||
|
expected = self.input_data
|
||||||
|
expected.slice_spec = entire_dataset_slice
|
||||||
|
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
||||||
|
|
||||||
|
def test_slice_by_class_fails(self):
|
||||||
|
class_index = 1
|
||||||
|
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
||||||
|
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
|
||||||
|
|
||||||
|
@mock.patch('logging.Logger.info', wraps=logging.Logger)
|
||||||
|
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
|
||||||
|
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||||
|
_ = get_slice(self.input_data, percentile_slice)
|
||||||
|
mock_logger.assert_called_with(
|
||||||
|
('For multilabel data, when slices by percentiles are '
|
||||||
|
'requested, losses are summed over the class axis before '
|
||||||
|
'slicing.'))
|
||||||
|
|
||||||
|
def test_slice_by_percentile(self):
|
||||||
|
# 50th percentile is the lower 50% of losses summed over the classes.
|
||||||
|
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||||
|
output = get_slice(self.input_data, percentile_slice)
|
||||||
|
|
||||||
|
# Check logits.
|
||||||
|
with self.subTest(msg='Check logits'):
|
||||||
|
self.assertLen(output.logits_train, 2)
|
||||||
|
self.assertLen(output.logits_test, 3)
|
||||||
|
self.assertTrue((output.logits_test[0] == [10, 0, 11]).all())
|
||||||
|
|
||||||
|
# Check labels.
|
||||||
|
with self.subTest(msg='Check labels'):
|
||||||
|
self.assertLen(output.labels_train, 2)
|
||||||
|
self.assertLen(output.labels_test, 3)
|
||||||
|
self.assertTrue((output.labels_train == [[0, 1, 1], [1, 1, 0]]).all())
|
||||||
|
self.assertTrue((output.labels_test == [[1, 0, 1], [0, 1, 0], [0, 0,
|
||||||
|
1]]).all())
|
||||||
|
|
||||||
|
def test_slice_by_correctness_fails(self):
|
||||||
|
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
|
||||||
|
False)
|
||||||
|
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -17,9 +17,11 @@ This file belongs to the new API for membership inference attacks. This file
|
||||||
will be renamed to membership_inference_attack.py after the old API is removed.
|
will be renamed to membership_inference_attack.py after the old API is removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Iterable
|
import logging
|
||||||
|
from typing import Iterable, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy import special
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from sklearn import model_selection
|
from sklearn import model_selection
|
||||||
|
|
||||||
|
@ -39,6 +41,9 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.datase
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.dataset_slicing import get_slice
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.dataset_slicing import get_slice
|
||||||
|
|
||||||
|
|
||||||
|
ArrayLike = Union[np.ndarray, List]
|
||||||
|
|
||||||
|
|
||||||
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
||||||
if hasattr(data, 'slice_spec'):
|
if hasattr(data, 'slice_spec'):
|
||||||
return data.slice_spec
|
return data.slice_spec
|
||||||
|
@ -81,7 +86,8 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
|
|
||||||
attacker = models.create_attacker(attack_type)
|
attacker = models.create_attacker(attack_type)
|
||||||
attacker.train_model(features[train_indices], labels[train_indices])
|
attacker.train_model(features[train_indices], labels[train_indices])
|
||||||
scores[test_indices] = attacker.predict(features[test_indices])
|
predictions = attacker.predict(features[test_indices])
|
||||||
|
scores[test_indices] = predictions
|
||||||
|
|
||||||
# Predict the left out with the last attacker
|
# Predict the left out with the last attacker
|
||||||
if left_out_indices.size:
|
if left_out_indices.size:
|
||||||
|
@ -110,6 +116,11 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
||||||
loss_test = attack_input.get_loss_test()
|
loss_test = attack_input.get_loss_test()
|
||||||
if loss_train is None or loss_test is None:
|
if loss_train is None or loss_test is None:
|
||||||
raise ValueError('Not possible to run threshold attack without losses.')
|
raise ValueError('Not possible to run threshold attack without losses.')
|
||||||
|
if attack_input.is_multilabel_data():
|
||||||
|
logging.info(('For multilabel data, when a threshold attack is requested, '
|
||||||
|
'losses are summed over the class axis before slicing.'))
|
||||||
|
loss_train = np.sum(loss_train, axis=1)
|
||||||
|
loss_test = np.sum(loss_test, axis=1)
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(
|
fpr, tpr, thresholds = metrics.roc_curve(
|
||||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||||
np.concatenate((loss_train, loss_test)))
|
np.concatenate((loss_train, loss_test)))
|
||||||
|
@ -126,6 +137,10 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
||||||
|
|
||||||
|
|
||||||
def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
||||||
|
"""Runs threshold entropy attack on single label data."""
|
||||||
|
if attack_input.is_multilabel_data():
|
||||||
|
raise NotImplementedError(('Entropy-based attacks are not implemented for '
|
||||||
|
'multilabel data.'))
|
||||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(
|
fpr, tpr, thresholds = metrics.roc_curve(
|
||||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||||
|
@ -212,6 +227,7 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
for single_slice_spec in input_slice_specs:
|
for single_slice_spec in input_slice_specs:
|
||||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||||
for attack_type in attack_types:
|
for attack_type in attack_types:
|
||||||
|
logging.info('Running attack: %s', attack_type.name)
|
||||||
attack_result = _run_attack(attack_input_slice, attack_type,
|
attack_result = _run_attack(attack_input_slice, attack_type,
|
||||||
balance_attacker_training, min_num_samples)
|
balance_attacker_training, min_num_samples)
|
||||||
if attack_result is not None:
|
if attack_result is not None:
|
||||||
|
@ -318,14 +334,47 @@ def run_membership_probability_analysis(
|
||||||
def _compute_missing_privacy_report_metadata(
|
def _compute_missing_privacy_report_metadata(
|
||||||
metadata: PrivacyReportMetadata,
|
metadata: PrivacyReportMetadata,
|
||||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||||
"""Populates metadata fields if they are missing."""
|
"""Populates metadata fields if they are missing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Metadata that is used to create a privacy report based on the
|
||||||
|
attack results.
|
||||||
|
attack_input: The input data used to run a membership attack.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new or updated metadata object containing information to create the
|
||||||
|
privacy report.
|
||||||
|
"""
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = PrivacyReportMetadata()
|
metadata = PrivacyReportMetadata()
|
||||||
|
if attack_input.is_multilabel_data():
|
||||||
|
accuracy_fn = _get_multilabel_accuracy
|
||||||
|
sigmoid_func = special.expit
|
||||||
|
# Multi label accuracy is calculated with the prediction probabilties and
|
||||||
|
# the labels. A threshold with a default of 0.5 is used to get predicted
|
||||||
|
# labels from the probabilities.
|
||||||
|
if (attack_input.probs_train is None and
|
||||||
|
attack_input.logits_train is not None):
|
||||||
|
logits_or_probs_train = sigmoid_func(attack_input.logits_train)
|
||||||
|
else:
|
||||||
|
logits_or_probs_train = attack_input.probs_train
|
||||||
|
if (attack_input.probs_test is None and
|
||||||
|
attack_input.logits_test is not None):
|
||||||
|
logits_or_probs_test = sigmoid_func(attack_input.logits_test)
|
||||||
|
else:
|
||||||
|
logits_or_probs_test = attack_input.probs_test
|
||||||
|
else:
|
||||||
|
accuracy_fn = _get_accuracy
|
||||||
|
# Single label accuracy is calculated with the argmax of the logits which
|
||||||
|
# is the same as the argmax of the probbilities.
|
||||||
|
logits_or_probs_train = attack_input.logits_or_probs_train
|
||||||
|
logits_or_probs_test = attack_input.logits_or_probs_test
|
||||||
if metadata.accuracy_train is None:
|
if metadata.accuracy_train is None:
|
||||||
metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
|
metadata.accuracy_train = accuracy_fn(logits_or_probs_train,
|
||||||
attack_input.labels_train)
|
attack_input.labels_train)
|
||||||
if metadata.accuracy_test is None:
|
if metadata.accuracy_test is None:
|
||||||
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
metadata.accuracy_test = accuracy_fn(logits_or_probs_test,
|
||||||
attack_input.labels_test)
|
attack_input.labels_test)
|
||||||
loss_train = attack_input.get_loss_train()
|
loss_train = attack_input.get_loss_train()
|
||||||
loss_test = attack_input.get_loss_test()
|
loss_test = attack_input.get_loss_test()
|
||||||
|
@ -341,3 +390,28 @@ def _get_accuracy(logits, labels):
|
||||||
if logits is None or labels is None:
|
if logits is None or labels is None:
|
||||||
return None
|
return None
|
||||||
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
|
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_numpy_binary_accuracy(preds: ArrayLike, labels: ArrayLike):
|
||||||
|
"""Computes the multilabel accuracy at threshold=0.5 using Numpy."""
|
||||||
|
return np.mean(np.equal(labels, np.round(preds)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_multilabel_accuracy(preds: ArrayLike, labels: ArrayLike):
|
||||||
|
"""Computes the accuracy over multilabel data if it is missing.
|
||||||
|
|
||||||
|
Compute multilabel binary accuracy. AUC is a better measure of model quality
|
||||||
|
for multilabel classification than accuracy, in particular when the classes
|
||||||
|
are imbalanced. For consistency with the single label classification case,
|
||||||
|
we compute and return the binary accuracy over the labels and predictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preds: Prediction probabilities.
|
||||||
|
labels: Ground truth multihot labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The binary accuracy averaged across all labels.
|
||||||
|
"""
|
||||||
|
if preds is None or labels is None:
|
||||||
|
return None
|
||||||
|
return _get_numpy_binary_accuracy(preds, labels)
|
||||||
|
|
|
@ -34,6 +34,42 @@ def get_test_input(n_train, n_test):
|
||||||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_multihot_labels_for_test(num_samples: int,
|
||||||
|
num_classes: int) -> np.ndarray:
|
||||||
|
"""Generate a array of multihot labels.
|
||||||
|
|
||||||
|
Given an integer 'num_samples', generate a deterministic array of
|
||||||
|
'num_classes'multihot labels. Each multihot label is the list of bits (0/1) of
|
||||||
|
the corresponding row number in the array, upto 'num_classes'. If the value
|
||||||
|
of num_classes < num_samples, then the bit list repeats.
|
||||||
|
e.g. if num_samples=10 and num_classes=3, row=3 corresponds to the label
|
||||||
|
vector [0, 1, 1].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_samples: Number of samples for which to generate test labels.
|
||||||
|
num_classes: Number of classes for which to generate test multihot labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy integer array with rows=number of samples, and columns=length of the
|
||||||
|
bit-representation of the number of samples.
|
||||||
|
"""
|
||||||
|
m = 2**num_classes # Number of unique labels given the number of classes.
|
||||||
|
bit_format = f'0{num_classes}b' # Bit representation format with leading 0s.
|
||||||
|
return np.asarray(
|
||||||
|
[list(format(i % m, bit_format)) for i in range(num_samples)]).astype(int)
|
||||||
|
|
||||||
|
|
||||||
|
def get_multilabel_test_input(n_train, n_test):
|
||||||
|
"""Get example multilabel inputs for attacks."""
|
||||||
|
rng = np.random.RandomState(4)
|
||||||
|
num_classes = max(n_train // 20, 5) # use at least 5 classes.
|
||||||
|
return AttackInputData(
|
||||||
|
logits_train=rng.randn(n_train, num_classes) + 0.2,
|
||||||
|
logits_test=rng.randn(n_test, num_classes) + 0.2,
|
||||||
|
labels_train=get_multihot_labels_for_test(n_train, num_classes),
|
||||||
|
labels_test=get_multihot_labels_for_test(n_test, num_classes))
|
||||||
|
|
||||||
|
|
||||||
def get_test_input_logits_only(n_train, n_test):
|
def get_test_input_logits_only(n_train, n_test):
|
||||||
"""Get example input logits for attacks."""
|
"""Get example input logits for attacks."""
|
||||||
rng = np.random.RandomState(4)
|
rng = np.random.RandomState(4)
|
||||||
|
@ -172,5 +208,60 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
DataSize(ntrain=20, ntest=16))
|
DataSize(ntrain=20, ntest=16))
|
||||||
|
|
||||||
|
|
||||||
|
class RunAttacksTestOnMultilabelData(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_run_attacks_size(self):
|
||||||
|
result = mia.run_attacks(
|
||||||
|
get_multilabel_test_input(100, 100), SlicingSpec(),
|
||||||
|
(AttackType.LOGISTIC_REGRESSION,))
|
||||||
|
|
||||||
|
self.assertLen(result.single_attack_results, 1)
|
||||||
|
|
||||||
|
def test_run_attack_trained_sets_attack_type(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
||||||
|
|
||||||
|
self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_sets_attack_type(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
get_multilabel_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
|
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_fails(self):
|
||||||
|
self.assertRaises(NotImplementedError, mia._run_threshold_entropy_attack,
|
||||||
|
get_multilabel_test_input(100, 100))
|
||||||
|
|
||||||
|
def test_run_attack_by_percentiles_slice(self):
|
||||||
|
result = mia.run_attacks(
|
||||||
|
get_multilabel_test_input(100, 100),
|
||||||
|
SlicingSpec(entire_dataset=True, by_class=False, by_percentiles=True),
|
||||||
|
(AttackType.THRESHOLD_ATTACK,))
|
||||||
|
|
||||||
|
# 1 attack on entire dataset, 1 attack each of 10 percentile ranges, total
|
||||||
|
# of 11.
|
||||||
|
self.assertLen(result.single_attack_results, 11)
|
||||||
|
expected_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (20, 30))
|
||||||
|
# First slice (Slice #0) is entire dataset. Hence Slice #3 is the 3rd
|
||||||
|
# percentile range 20-30.
|
||||||
|
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
|
||||||
|
|
||||||
|
def test_numpy_multilabel_accuracy(self):
|
||||||
|
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
|
||||||
|
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
|
||||||
|
# At a threshold=0.5, 5 of the total 9 lables are correct.
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
mia._get_numpy_binary_accuracy(predictions, labels), 5 / 9, places=6)
|
||||||
|
|
||||||
|
def test_multilabel_accuracy(self):
|
||||||
|
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
|
||||||
|
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
|
||||||
|
# At a threshold=0.5, 5 of the total 9 lables are correct.
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6)
|
||||||
|
self.assertIsNone(mia._get_accuracy(None, labels))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -47,6 +47,25 @@ class TrainedAttackerTest(absltest.TestCase):
|
||||||
self.assertLen(attacker_data.fold_indices, 5)
|
self.assertLen(attacker_data.fold_indices, 5)
|
||||||
self.assertEmpty(attacker_data.left_out_indices)
|
self.assertEmpty(attacker_data.left_out_indices)
|
||||||
|
|
||||||
|
def test_multilabel_create_attacker_data_loss_and_logits(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||||
|
logits_test=np.array([[10, 11], [14, 15]]),
|
||||||
|
labels_train=np.array([[0, 1], [1, 1], [1, 0]]),
|
||||||
|
labels_test=np.array([[1, 0], [1, 1]]),
|
||||||
|
loss_train=np.array([[1, 3], [6, 7], [8, 9]]),
|
||||||
|
loss_test=np.array([[4, 2], [4, 6]]))
|
||||||
|
attacker_data = models.create_attacker_data(attack_input, balance=False)
|
||||||
|
self.assertLen(attacker_data.features_all, 5)
|
||||||
|
self.assertLen(attacker_data.fold_indices, 5)
|
||||||
|
self.assertEmpty(attacker_data.left_out_indices)
|
||||||
|
self.assertEqual(
|
||||||
|
attacker_data.features_all.shape[1],
|
||||||
|
attack_input.logits_train.shape[1] + attack_input.loss_train.shape[1])
|
||||||
|
self.assertTrue(
|
||||||
|
attack_input.is_multilabel_data(),
|
||||||
|
msg='Expected multilabel check to pass.')
|
||||||
|
|
||||||
def test_unbalanced_create_attacker_data_loss_and_logits(self):
|
def test_unbalanced_create_attacker_data_loss_and_logits(self):
|
||||||
attack_input = AttackInputData(
|
attack_input = AttackInputData(
|
||||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions for membership inference attacks."""
|
"""Utility functions for membership inference attacks."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
|
|
||||||
|
@ -76,3 +77,48 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
|
||||||
the squared loss of each sample.
|
the squared loss of each sample.
|
||||||
"""
|
"""
|
||||||
return (y_true - y_pred)**2
|
return (y_true - y_pred)**2
|
||||||
|
|
||||||
|
|
||||||
|
def multilabel_bce_loss(labels: np.ndarray,
|
||||||
|
pred: np.ndarray,
|
||||||
|
from_logits=False,
|
||||||
|
small_value=1e-8) -> np.ndarray:
|
||||||
|
"""Computes the per-multi-label-example cross entropy loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
labels: numpy array of shape (num_samples, num_classes). labels[i] is the
|
||||||
|
true multi-hot encoded label (vector) of the i-th sample and each element
|
||||||
|
of the vector is one of {0, 1}.
|
||||||
|
pred: numpy array of shape (num_samples, num_classes). pred[i] is the
|
||||||
|
logits or probability vector of the i-th sample.
|
||||||
|
from_logits: whether `pred` is logits or probability vector.
|
||||||
|
small_value: a scalar. np.log can become -inf if the probability is too
|
||||||
|
close to 0, so the probability is clipped below by small_value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the cross-entropy loss of each sample for each class.
|
||||||
|
"""
|
||||||
|
# Check arrays.
|
||||||
|
if labels.shape[0] != pred.shape[0]:
|
||||||
|
raise ValueError('labels and pred should have the same number of examples,',
|
||||||
|
f'but got {labels.shape[0]} and {pred.shape[0]}.')
|
||||||
|
if not ((labels == 0) | (labels == 1)).all():
|
||||||
|
raise ValueError(
|
||||||
|
'labels should be in {0, 1}. For multi-label classification the labels '
|
||||||
|
'should be multihot encoded.')
|
||||||
|
# Check if labels vectors are multi-label.
|
||||||
|
summed_labels = np.sum(labels, axis=1)
|
||||||
|
if ((summed_labels == 0) | (summed_labels == 1)).all():
|
||||||
|
logging.info(
|
||||||
|
('Labels are one-hot encoded single label. Every sample has at most one'
|
||||||
|
' positive label.'))
|
||||||
|
if not from_logits and ((pred < 0.0) | (pred > 1.0)).any():
|
||||||
|
raise ValueError(('Prediction probabilities are not in [0, 1] and '
|
||||||
|
'`from_logits` is set to False.'))
|
||||||
|
|
||||||
|
# Multi-class multi-label binary cross entropy loss
|
||||||
|
if from_logits:
|
||||||
|
pred = special.expit(pred)
|
||||||
|
bce = labels * np.log(pred + small_value)
|
||||||
|
bce += (1 - labels) * np.log(1 - pred + small_value)
|
||||||
|
return -bce
|
||||||
|
|
|
@ -124,5 +124,62 @@ class TestSquaredLoss(parameterized.TestCase):
|
||||||
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultilabelBCELoss(parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('probs_example1', np.array(
|
||||||
|
[[0, 1, 1], [1, 1, 0]]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
np.array([[0.22314343, 1.20397247, 0.3566748],
|
||||||
|
[0.22314343, 0.51082546, 2.30258409]]), False),
|
||||||
|
('probs_example2', np.array([[0, 1, 0], [1, 1, 0]]),
|
||||||
|
np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
|
||||||
|
np.array([[0.01005033, 3.91202251, 0.04082198],
|
||||||
|
[0.22314354, 0.35667493, 2.30258499]]), False),
|
||||||
|
('logits_example1', np.array([[0, 1, 1], [1, 1, 0]]),
|
||||||
|
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
|
||||||
|
np.array([[0.26328245, 0.85435522, 0.11551951],
|
||||||
|
[0.69314716, 0.47407697, 1.70141322]]), True),
|
||||||
|
('logits_example2', np.array([[0, 1, 0], [1, 1, 0]]),
|
||||||
|
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
|
||||||
|
np.array([[0.26328245, 0.85435522, 2.21551943],
|
||||||
|
[0.69314716, 0.47407697, 1.70141322]]), True),
|
||||||
|
)
|
||||||
|
def test_multilabel_bce_loss(self, label, pred, expected_losses, from_logits):
|
||||||
|
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
|
||||||
|
np.testing.assert_allclose(loss, expected_losses, atol=1e-6)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('from_logits_true_and_incorrect_values_example1',
|
||||||
|
np.array([[0, 1, 1], [1, 1, 0]
|
||||||
|
]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
np.array([[0.22314343, 1.20397247, 0.3566748],
|
||||||
|
[0.22314343, 0.51082546, 2.30258409]]), True),
|
||||||
|
('from_logits_true_and_incorrect_values_example2',
|
||||||
|
np.array([[0, 1, 0], [1, 1, 0]
|
||||||
|
]), np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
|
||||||
|
np.array([[0.01005033, 3.91202251, 0.04082198],
|
||||||
|
[0.22314354, 0.35667493, 2.30258499]]), True),
|
||||||
|
)
|
||||||
|
def test_multilabel_bce_loss_incorrect_value(self, label, pred,
|
||||||
|
expected_losses, from_logits):
|
||||||
|
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
|
||||||
|
self.assertFalse(np.allclose(loss, expected_losses))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('from_logits_false_and_pred_not_in_0to1_example1',
|
||||||
|
np.array([[0, 1, 1], [1, 1, 0]
|
||||||
|
]), np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
|
||||||
|
(r'Prediction probabilities are not in \[0, 1\] and '
|
||||||
|
'`from_logits` is set to False.')),
|
||||||
|
('labels_not_0_or_1', np.array([[0, 1, 0], [1, 2, 0]]),
|
||||||
|
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
|
||||||
|
('labels should be in {0, 1}. For multi-label classification the labels '
|
||||||
|
'should be multihot encoded.')),
|
||||||
|
)
|
||||||
|
def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex):
|
||||||
|
self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label,
|
||||||
|
pred, from_logits)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue