forked from 626_privacy/tensorflow_privacy
Add multi-label support for Tensorflow Privacy membership attacks.
PiperOrigin-RevId: 443176652
This commit is contained in:
parent
e14618fe7c
commit
ee35642b90
9 changed files with 676 additions and 36 deletions
|
@ -17,6 +17,7 @@ import collections
|
|||
import dataclasses
|
||||
import enum
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
|
||||
|
@ -115,6 +116,7 @@ class AttackType(enum.Enum):
|
|||
K_NEAREST_NEIGHBORS = 'knn'
|
||||
THRESHOLD_ATTACK = 'threshold'
|
||||
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
|
||||
TF_LOGISTIC_REGRESSION = 'tf_lr'
|
||||
|
||||
@property
|
||||
def is_trained_attack(self):
|
||||
|
@ -154,6 +156,19 @@ def _is_array_one_dimensional(arr, arr_name):
|
|||
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
|
||||
|
||||
|
||||
def _is_array_two_dimensional(arr, arr_name):
|
||||
"""Checks whether the array is two dimensional."""
|
||||
if arr is not None and len(arr.shape) != 2:
|
||||
raise ValueError('%s should be a two dimensional numpy array.' % arr_name)
|
||||
|
||||
|
||||
def _is_array_one_or_two_dimensional(arr, arr_name):
|
||||
"""Checks whether the array is one or two dimensional."""
|
||||
if arr is not None and len(arr.shape) not in [1, 2]:
|
||||
raise ValueError(
|
||||
('%s should be a one or two dimensional numpy array.' % arr_name))
|
||||
|
||||
|
||||
def _is_np_array(arr, arr_name):
|
||||
"""Checks whether array is a numpy array."""
|
||||
if arr is not None and not isinstance(arr, np.ndarray):
|
||||
|
@ -213,6 +228,14 @@ class AttackInputData:
|
|||
# preferred when both are available.
|
||||
loss_function_using_logits: Optional[bool] = None
|
||||
|
||||
# If the problem is a multilabel classification problem. If this is set then
|
||||
# the loss function and attack data construction are adjusted accordingly. In
|
||||
# this case the provided labels must be multi-hot encoded. That is, the labels
|
||||
# are an array of shape (num_examples, num_classes) with 0s where the
|
||||
# corresponding class is absent from the example, and 1s where the
|
||||
# corresponding class is present.
|
||||
multilabel_data: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
if self.labels_train is None or self.labels_test is None:
|
||||
|
@ -266,12 +289,12 @@ class AttackInputData:
|
|||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(
|
||||
loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||
LossFunction],
|
||||
loss_function_using_logits: Optional[bool]) -> Optional[np.ndarray]:
|
||||
def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
||||
np.ndarray], LossFunction],
|
||||
loss_function_using_logits: Optional[bool],
|
||||
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
||||
"""Calculates (if needed) losses.
|
||||
|
||||
Args:
|
||||
|
@ -284,6 +307,7 @@ class AttackInputData:
|
|||
is supposed to take in (label, logits / probs) as input.
|
||||
loss_function_using_logits: if `loss_function` expects `logits` or
|
||||
`probs`.
|
||||
multilabel_data: if the data is from a multilabel classification problem.
|
||||
|
||||
Returns:
|
||||
Loss (or None if neither the loss nor the labels are present).
|
||||
|
@ -299,13 +323,23 @@ class AttackInputData:
|
|||
|
||||
predictions = logits if loss_function_using_logits else probs
|
||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
||||
if multilabel_data:
|
||||
loss = utils.multilabel_bce_loss(labels, predictions,
|
||||
loss_function_using_logits)
|
||||
else:
|
||||
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
||||
elif loss_function == LossFunction.SQUARED:
|
||||
loss = utils.squared_loss(labels, predictions)
|
||||
else:
|
||||
loss = loss_function(labels, predictions)
|
||||
return loss
|
||||
|
||||
def __post_init__(self):
|
||||
"""Checks performed after instantiation of the AttackInputData dataclass."""
|
||||
# Check if the data is multilabel.
|
||||
_ = self.is_multilabel_data()
|
||||
# The validate() check is called as needed, so is not called here.
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the training set.
|
||||
|
||||
|
@ -316,7 +350,7 @@ class AttackInputData:
|
|||
self.loss_function_using_logits = (self.logits_train is not None)
|
||||
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
|
||||
self.probs_train, self.loss_function,
|
||||
self.loss_function_using_logits)
|
||||
self.loss_function_using_logits, self.multilabel_data)
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates (if needed) cross-entropy losses for the test set.
|
||||
|
@ -328,35 +362,143 @@ class AttackInputData:
|
|||
self.loss_function_using_logits = bool(self.logits_test)
|
||||
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
|
||||
self.probs_test, self.loss_function,
|
||||
self.loss_function_using_logits)
|
||||
self.loss_function_using_logits, self.multilabel_data)
|
||||
|
||||
def get_entropy_train(self):
|
||||
"""Calculates prediction entropy for the training set."""
|
||||
if self.is_multilabel_data():
|
||||
# Not implemented for multilabel data.
|
||||
raise NotImplementedError('Computation of the entropy is not '
|
||||
'applicable for multi-label data.')
|
||||
if self.entropy_train is not None:
|
||||
return self.entropy_train
|
||||
return self._get_entropy(self.logits_train, self.labels_train)
|
||||
|
||||
def get_entropy_test(self):
|
||||
"""Calculates prediction entropy for the test set."""
|
||||
if self.is_multilabel_data():
|
||||
# Not implemented for multilabel data.
|
||||
raise NotImplementedError('Computation of the entropy is not '
|
||||
'applicable for multi-label data.')
|
||||
if self.entropy_test is not None:
|
||||
return self.entropy_test
|
||||
return self._get_entropy(self.logits_test, self.labels_test)
|
||||
|
||||
def get_train_size(self):
|
||||
"""Returns size of the training set."""
|
||||
def get_train_shape(self):
|
||||
"""Returns the shape of the training set."""
|
||||
if self.loss_train is not None:
|
||||
return self.loss_train.size
|
||||
return self.loss_train.shape
|
||||
if self.entropy_train is not None:
|
||||
return self.entropy_train.size
|
||||
return self.logits_or_probs_train.shape[0]
|
||||
return self.entropy_train.shape
|
||||
return self.logits_or_probs_train.shape
|
||||
|
||||
def get_test_shape(self):
|
||||
"""Returns the shape of the test set."""
|
||||
if self.loss_test is not None:
|
||||
return self.loss_test.shape
|
||||
if self.entropy_test is not None:
|
||||
return self.entropy_test.shape
|
||||
return self.logits_or_probs_test.shape
|
||||
|
||||
def get_train_size(self):
|
||||
"""Returns the number of examples of the training set."""
|
||||
return self.get_train_shape()[0]
|
||||
|
||||
def get_test_size(self):
|
||||
"""Returns size of the test set."""
|
||||
if self.loss_test is not None:
|
||||
return self.loss_test.size
|
||||
if self.entropy_test is not None:
|
||||
return self.entropy_test.size
|
||||
return self.logits_or_probs_test.shape[0]
|
||||
"""Returns the number of examples of the test set."""
|
||||
return self.get_test_shape()[0]
|
||||
|
||||
def is_multihot_labels(self, arr, arr_name) -> bool:
|
||||
"""Check if the 2D array is multihot, with values in [0, 1].
|
||||
|
||||
Array is multihot if the sum along the classes axis (axis=1) produces a
|
||||
vector of at least one value > 1.
|
||||
|
||||
Args:
|
||||
arr: Array to test.
|
||||
arr_name: Name of the array.
|
||||
|
||||
Returns:
|
||||
True if the array is a 2D multihot array.
|
||||
|
||||
Raises:
|
||||
ValueError if the array is not 2D.
|
||||
"""
|
||||
if arr is None:
|
||||
return False
|
||||
elif len(arr.shape) > 2:
|
||||
raise ValueError(f'Array {arr_name} is not 2D, cannot determine whether '
|
||||
'it is multihot.')
|
||||
elif len(arr.shape) == 1:
|
||||
return False
|
||||
summed_arr = np.sum(arr, axis=1)
|
||||
return not ((summed_arr == 0) | (summed_arr == 1)).all()
|
||||
|
||||
def is_multilabel_data(self) -> bool:
|
||||
"""Check if the provided data is for a multilabel classification problem.
|
||||
|
||||
Data is multilabel if all of the following are true:
|
||||
1. Train and test sizes are 2-dimensional: (num_samples, num_classes)
|
||||
2. Label size is 2-dimentionsl: (num_samples, num_classes)
|
||||
3. The labels are multi-hot. The labels are multihot if sum{label_tensor}
|
||||
along axis=1 (the classes axis) yields a vector of at least one
|
||||
value > 1.
|
||||
|
||||
Returns:
|
||||
Whether the provided data is multilabel.
|
||||
|
||||
Raises:
|
||||
ValueError if the dimensionality of the train and test data are not equal.
|
||||
"""
|
||||
# If the data has already been checked for multihot encoded labels, then
|
||||
# return the result of the evaluation.
|
||||
if self.multilabel_data is not None:
|
||||
return self.multilabel_data
|
||||
|
||||
# If one of probs or logits are not provided, or labels are not provided,
|
||||
# this is not a multilabel problem
|
||||
if (self.logits_or_probs_train is None or
|
||||
self.logits_or_probs_test is None or self.labels_train is None or
|
||||
self.labels_test is None):
|
||||
self.multilabel_data = False
|
||||
return self.multilabel_data
|
||||
|
||||
train_shape = self.get_train_shape()
|
||||
test_shape = self.get_test_shape()
|
||||
label_train_shape = self.labels_train.shape
|
||||
label_test_shape = self.labels_test.shape
|
||||
|
||||
if len(train_shape) != len(test_shape):
|
||||
raise ValueError('The number of dimensions of the train data '
|
||||
f'({train_shape}) is not the same as that of the test '
|
||||
f'data ({test_shape}).')
|
||||
if len(train_shape) not in [1, 2] or len(test_shape) not in [1, 2]:
|
||||
raise ValueError(('Train and test data shapes must be 1-D '
|
||||
'(number of samples) or 2-D (number of samples, '
|
||||
'number of classes).'))
|
||||
if len(label_train_shape) != len(label_test_shape):
|
||||
raise ValueError('The number of dimensions of the train labels '
|
||||
f'({label_train_shape}) is not the same as that of the '
|
||||
f'test labels ({label_test_shape}).')
|
||||
if (len(label_train_shape) not in [1, 2] or
|
||||
len(label_test_shape) not in [1, 2]):
|
||||
raise ValueError('Train and test labels shapes must be 1-D '
|
||||
'(number of samples) or 2-D (number of samples, '
|
||||
'number of classes).')
|
||||
data_is_2d = len(train_shape) == len(test_shape) == 2
|
||||
if data_is_2d:
|
||||
equal_feature_count = train_shape[1] == test_shape[1]
|
||||
else:
|
||||
equal_feature_count = False
|
||||
labels_are_2d = len(label_train_shape) == len(label_test_shape) == 2
|
||||
labels_train_are_multihot = self.is_multihot_labels(self.labels_train,
|
||||
'labels_train')
|
||||
labels_test_are_multihot = self.is_multihot_labels(self.labels_test,
|
||||
'labels_test')
|
||||
self.multilabel_data = (
|
||||
data_is_2d and labels_are_2d and equal_feature_count and
|
||||
labels_train_are_multihot and labels_test_are_multihot)
|
||||
return self.multilabel_data
|
||||
|
||||
def validate(self):
|
||||
"""Validates the inputs."""
|
||||
|
@ -413,12 +555,30 @@ class AttackInputData:
|
|||
'logits_test')
|
||||
_is_last_dim_equal(self.probs_train, 'probs_train', self.probs_test,
|
||||
'probs_test')
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
||||
_is_array_one_dimensional(self.entropy_test, 'entropy_test')
|
||||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||
|
||||
if self.is_multilabel_data():
|
||||
# Validation for multi-label data.
|
||||
# Verify that both logits and probabilities have the same number of
|
||||
# classes.
|
||||
_is_last_dim_equal(self.logits_train, 'logits_train', self.probs_train,
|
||||
'probs_train')
|
||||
_is_last_dim_equal(self.logits_test, 'logits_test', self.probs_test,
|
||||
'probs_test')
|
||||
# Check that losses, labels and entropies are 2D
|
||||
# (num_samples, num_classes).
|
||||
_is_array_two_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_two_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_two_dimensional(self.entropy_train, 'entropy_train')
|
||||
_is_array_two_dimensional(self.entropy_test, 'entropy_test')
|
||||
_is_array_two_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_two_dimensional(self.labels_test, 'labels_test')
|
||||
else:
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
||||
_is_array_one_dimensional(self.entropy_test, 'entropy_test')
|
||||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||
|
||||
def __str__(self):
|
||||
"""Return the shapes of variables that are not None."""
|
||||
|
@ -784,8 +944,8 @@ class AttackResults:
|
|||
aucs = [result.get_auc() for result in self.single_attack_results]
|
||||
|
||||
if min(aucs) < 0.4:
|
||||
print('Suspiciously low AUC detected: %.2f. ' +
|
||||
'There might be a bug in the classifier' % min(aucs))
|
||||
logging.info(('Suspiciously low AUC detected: %.2f. '
|
||||
'There might be a bug in the classifier'), min(aucs))
|
||||
|
||||
return self.single_attack_results[np.argmax(aucs)]
|
||||
|
||||
|
|
|
@ -277,6 +277,87 @@ class AttackInputDataTest(parameterized.TestCase):
|
|||
probs_train=np.array([]),
|
||||
probs_test=np.array([])).validate)
|
||||
|
||||
def test_multilabel_validator(self):
|
||||
# Tests for multilabel data.
|
||||
with self.assertRaises(
|
||||
ValueError,
|
||||
msg='Validation passes incorrectly when `logits_test` is not 1D/2D.'):
|
||||
AttackInputData(
|
||||
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||
logits_test=np.array([[[0.01, 1.5], [0.5, -3]],
|
||||
[[0.01, 1.5], [0.5, -3]]]),
|
||||
labels_train=np.array([[0, 0], [0, 1], [1, 0]]),
|
||||
labels_test=np.array([[1, 1], [1, 0]]),
|
||||
).validate()
|
||||
self.assertTrue(
|
||||
AttackInputData(
|
||||
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||
labels_test=np.array([[1, 1], [1, 0]]),
|
||||
).is_multilabel_data(),
|
||||
msg='Multilabel data check fails even though conditions are met.')
|
||||
|
||||
def test_multihot_labels_check_on_null_array_returns_false(self):
|
||||
self.assertFalse(
|
||||
AttackInputData(
|
||||
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||
labels_test=np.array([[1, 1], [1, 0]]),
|
||||
).is_multihot_labels(None, 'null_array'),
|
||||
msg='Multilabel test on a null array should return False.')
|
||||
self.assertFalse(
|
||||
AttackInputData(
|
||||
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||
labels_test=np.array([[1, 1], [1, 0]]),
|
||||
).is_multihot_labels(np.array([1.0, 2.0, 3.0]), '1d_array'),
|
||||
msg='Multilabel test on a 1-D array should return False.')
|
||||
|
||||
def test_multilabel_get_bce_loss_from_probs(self):
|
||||
attack_input = AttackInputData(
|
||||
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||
probs_test=np.array([[0.8, 0.7, 0.9]]),
|
||||
labels_train=np.array([[0, 1, 1], [1, 1, 0]]),
|
||||
labels_test=np.array([[1, 1, 0]]))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748],
|
||||
[0.22314343, 0.51082546, 2.30258409]],
|
||||
atol=1e-6)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]],
|
||||
atol=1e-6)
|
||||
|
||||
def test_multilabel_get_bce_loss_from_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
|
||||
logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
|
||||
labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
|
||||
labels_test=np.array([[1, 1], [1, 0]]))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(),
|
||||
[[0.31326167, 0.126928], [0.69815966, 0.20141327],
|
||||
[0.47407697, 3.04858714]],
|
||||
atol=1e-6)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(),
|
||||
[[0.68815966, 0.20141327], [0.47407697, 0.04858734]],
|
||||
atol=1e-6)
|
||||
|
||||
def test_multilabel_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
|
||||
loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
||||
|
||||
np.testing.assert_equal(attack_input.get_loss_train().tolist(),
|
||||
np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]))
|
||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import collections
|
||||
import copy
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -49,10 +50,21 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
|||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
||||
|
||||
# A slice has the same multilabel status as the original data. This is because
|
||||
# of the way multilabel status is computed. A dataset is multilabel if at
|
||||
# least 1 sample has a label that is multihot encoded with more than one
|
||||
# positive class. A slice of this dataset could have only samples with labels
|
||||
# that have only a single positive class, even if the original dataset were
|
||||
# multilable. Therefore we ensure that the slice inherits the multilabel state
|
||||
# of the original dataset.
|
||||
result.multilabel_data = data.is_multilabel_data()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _slice_by_class(data: AttackInputData, class_value: int) -> AttackInputData:
|
||||
if data.is_multilabel_data():
|
||||
raise ValueError("Slicing by class not supported for multilabel data.")
|
||||
idx_train = data.labels_train == class_value
|
||||
idx_test = data.labels_test == class_value
|
||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||
|
@ -65,6 +77,12 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
|||
# Find from_percentile and to_percentile percentiles in losses.
|
||||
loss_train = data.get_loss_train()
|
||||
loss_test = data.get_loss_test()
|
||||
if data.is_multilabel_data():
|
||||
logging.info("For multilabel data, when slices by percentiles are "
|
||||
"requested, losses are summed over the class axis before "
|
||||
"slicing.")
|
||||
loss_train = np.sum(loss_train, axis=1)
|
||||
loss_test = np.sum(loss_test, axis=1)
|
||||
losses = np.concatenate((loss_train, loss_test))
|
||||
from_loss = np.percentile(losses, from_percentile)
|
||||
to_loss = np.percentile(losses, to_percentile)
|
||||
|
@ -82,6 +100,20 @@ def _indices_by_classification(logits_or_probs, labels, correctly_classified):
|
|||
|
||||
def _slice_by_classification_correctness(data: AttackInputData,
|
||||
correctly_classified: bool):
|
||||
"""Slices attack inputs by whether they were classified correctly.
|
||||
|
||||
Args:
|
||||
data: Data to be used as input to the attack models.
|
||||
correctly_classified: Whether to use the indices corresponding to the
|
||||
correctly classified samples.
|
||||
|
||||
Returns:
|
||||
AttackInputData object containing the sliced data.
|
||||
"""
|
||||
|
||||
if data.is_multilabel_data():
|
||||
raise ValueError("Slicing by classification correctness not supported for "
|
||||
"multilabel data.")
|
||||
idx_train = _indices_by_classification(data.logits_or_probs_train,
|
||||
data.labels_train,
|
||||
correctly_classified)
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from absl.testing import absltest
|
||||
from absl.testing.absltest import mock
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||
|
@ -71,7 +73,7 @@ class SingleSliceSpecsTest(absltest.TestCase):
|
|||
self.assertTrue(_are_all_fields_equal(output[0], expected0))
|
||||
self.assertTrue(_are_all_fields_equal(output[5], expected5))
|
||||
|
||||
def test_slice_by_correcness(self):
|
||||
def test_slice_by_correctness(self):
|
||||
input_data = SlicingSpec(
|
||||
entire_dataset=False, by_classification_correctness=True)
|
||||
expected = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True)
|
||||
|
@ -199,5 +201,83 @@ class GetSliceTest(absltest.TestCase):
|
|||
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
||||
|
||||
|
||||
class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||
|
||||
def __init__(self, methodname):
|
||||
"""Initialize the test class."""
|
||||
super().__init__(methodname)
|
||||
|
||||
# Create test data for 3 class multilabel classification task.
|
||||
logits_train = np.array([[0, 1, 0], [2, 0, 3], [4, 5, 0], [6, 7, 0]])
|
||||
logits_test = np.array([[10, 0, 11], [12, 13, 0], [14, 15, 0], [0, 16, 17]])
|
||||
probs_train = np.array([[0, 1, 0], [0.1, 0, 0.7], [0.4, 0.6, 0],
|
||||
[0.3, 0.7, 0]])
|
||||
probs_test = np.array([[0.4, 0, 0.6], [0.1, 0.9, 0], [0.15, 0.85, 0],
|
||||
[0, 0, 1]])
|
||||
labels_train = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [0, 1, 0]])
|
||||
labels_test = np.array([[1, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
|
||||
loss_train = np.array([[0, 0, 2], [3, 0, 0.3], [2, 0.5, 0], [1.5, 2, 3]])
|
||||
loss_test = np.array([[1.5, 0, 1.0], [0.5, 0.8, 0], [0.5, 3, 0], [0, 0, 0]])
|
||||
entropy_train = np.array([[0.2, 0.2, 5], [10, 0, 2], [7, 3, 0], [6, 8, 9]])
|
||||
entropy_test = np.array([[3, 0, 2], [3, 6, 0], [2, 6, 0], [0, 0, 0]])
|
||||
|
||||
self.input_data = AttackInputData(
|
||||
logits_train=logits_train,
|
||||
logits_test=logits_test,
|
||||
probs_train=probs_train,
|
||||
probs_test=probs_test,
|
||||
labels_train=labels_train,
|
||||
labels_test=labels_test,
|
||||
loss_train=loss_train,
|
||||
loss_test=loss_test,
|
||||
entropy_train=entropy_train,
|
||||
entropy_test=entropy_test)
|
||||
|
||||
def test_slice_entire_dataset(self):
|
||||
entire_dataset_slice = SingleSliceSpec()
|
||||
output = get_slice(self.input_data, entire_dataset_slice)
|
||||
expected = self.input_data
|
||||
expected.slice_spec = entire_dataset_slice
|
||||
self.assertTrue(_are_all_fields_equal(output, self.input_data))
|
||||
|
||||
def test_slice_by_class_fails(self):
|
||||
class_index = 1
|
||||
class_slice = SingleSliceSpec(SlicingFeature.CLASS, class_index)
|
||||
self.assertRaises(ValueError, get_slice, self.input_data, class_slice)
|
||||
|
||||
@mock.patch('logging.Logger.info', wraps=logging.Logger)
|
||||
def test_slice_by_percentile_logs_multilabel_data(self, mock_logger):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||
_ = get_slice(self.input_data, percentile_slice)
|
||||
mock_logger.assert_called_with(
|
||||
('For multilabel data, when slices by percentiles are '
|
||||
'requested, losses are summed over the class axis before '
|
||||
'slicing.'))
|
||||
|
||||
def test_slice_by_percentile(self):
|
||||
# 50th percentile is the lower 50% of losses summed over the classes.
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||
output = get_slice(self.input_data, percentile_slice)
|
||||
|
||||
# Check logits.
|
||||
with self.subTest(msg='Check logits'):
|
||||
self.assertLen(output.logits_train, 2)
|
||||
self.assertLen(output.logits_test, 3)
|
||||
self.assertTrue((output.logits_test[0] == [10, 0, 11]).all())
|
||||
|
||||
# Check labels.
|
||||
with self.subTest(msg='Check labels'):
|
||||
self.assertLen(output.labels_train, 2)
|
||||
self.assertLen(output.labels_test, 3)
|
||||
self.assertTrue((output.labels_train == [[0, 1, 1], [1, 1, 0]]).all())
|
||||
self.assertTrue((output.labels_test == [[1, 0, 1], [0, 1, 0], [0, 0,
|
||||
1]]).all())
|
||||
|
||||
def test_slice_by_correctness_fails(self):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED,
|
||||
False)
|
||||
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -17,9 +17,11 @@ This file belongs to the new API for membership inference attacks. This file
|
|||
will be renamed to membership_inference_attack.py after the old API is removed.
|
||||
"""
|
||||
|
||||
from typing import Iterable
|
||||
import logging
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
from sklearn import metrics
|
||||
from sklearn import model_selection
|
||||
|
||||
|
@ -39,6 +41,9 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.datase
|
|||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.dataset_slicing import get_slice
|
||||
|
||||
|
||||
ArrayLike = Union[np.ndarray, List]
|
||||
|
||||
|
||||
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
||||
if hasattr(data, 'slice_spec'):
|
||||
return data.slice_spec
|
||||
|
@ -81,7 +86,8 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
|
||||
attacker = models.create_attacker(attack_type)
|
||||
attacker.train_model(features[train_indices], labels[train_indices])
|
||||
scores[test_indices] = attacker.predict(features[test_indices])
|
||||
predictions = attacker.predict(features[test_indices])
|
||||
scores[test_indices] = predictions
|
||||
|
||||
# Predict the left out with the last attacker
|
||||
if left_out_indices.size:
|
||||
|
@ -110,6 +116,11 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
loss_test = attack_input.get_loss_test()
|
||||
if loss_train is None or loss_test is None:
|
||||
raise ValueError('Not possible to run threshold attack without losses.')
|
||||
if attack_input.is_multilabel_data():
|
||||
logging.info(('For multilabel data, when a threshold attack is requested, '
|
||||
'losses are summed over the class axis before slicing.'))
|
||||
loss_train = np.sum(loss_train, axis=1)
|
||||
loss_test = np.sum(loss_test, axis=1)
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||
np.concatenate((loss_train, loss_test)))
|
||||
|
@ -126,6 +137,10 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
|
||||
|
||||
def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
||||
"""Runs threshold entropy attack on single label data."""
|
||||
if attack_input.is_multilabel_data():
|
||||
raise NotImplementedError(('Entropy-based attacks are not implemented for '
|
||||
'multilabel data.'))
|
||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||
|
@ -212,6 +227,7 @@ def run_attacks(attack_input: AttackInputData,
|
|||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
for attack_type in attack_types:
|
||||
logging.info('Running attack: %s', attack_type.name)
|
||||
attack_result = _run_attack(attack_input_slice, attack_type,
|
||||
balance_attacker_training, min_num_samples)
|
||||
if attack_result is not None:
|
||||
|
@ -318,15 +334,48 @@ def run_membership_probability_analysis(
|
|||
def _compute_missing_privacy_report_metadata(
|
||||
metadata: PrivacyReportMetadata,
|
||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||
"""Populates metadata fields if they are missing."""
|
||||
"""Populates metadata fields if they are missing.
|
||||
|
||||
Args:
|
||||
metadata: Metadata that is used to create a privacy report based on the
|
||||
attack results.
|
||||
attack_input: The input data used to run a membership attack.
|
||||
|
||||
Returns:
|
||||
A new or updated metadata object containing information to create the
|
||||
privacy report.
|
||||
"""
|
||||
|
||||
if metadata is None:
|
||||
metadata = PrivacyReportMetadata()
|
||||
if attack_input.is_multilabel_data():
|
||||
accuracy_fn = _get_multilabel_accuracy
|
||||
sigmoid_func = special.expit
|
||||
# Multi label accuracy is calculated with the prediction probabilties and
|
||||
# the labels. A threshold with a default of 0.5 is used to get predicted
|
||||
# labels from the probabilities.
|
||||
if (attack_input.probs_train is None and
|
||||
attack_input.logits_train is not None):
|
||||
logits_or_probs_train = sigmoid_func(attack_input.logits_train)
|
||||
else:
|
||||
logits_or_probs_train = attack_input.probs_train
|
||||
if (attack_input.probs_test is None and
|
||||
attack_input.logits_test is not None):
|
||||
logits_or_probs_test = sigmoid_func(attack_input.logits_test)
|
||||
else:
|
||||
logits_or_probs_test = attack_input.probs_test
|
||||
else:
|
||||
accuracy_fn = _get_accuracy
|
||||
# Single label accuracy is calculated with the argmax of the logits which
|
||||
# is the same as the argmax of the probbilities.
|
||||
logits_or_probs_train = attack_input.logits_or_probs_train
|
||||
logits_or_probs_test = attack_input.logits_or_probs_test
|
||||
if metadata.accuracy_train is None:
|
||||
metadata.accuracy_train = _get_accuracy(attack_input.logits_train,
|
||||
attack_input.labels_train)
|
||||
metadata.accuracy_train = accuracy_fn(logits_or_probs_train,
|
||||
attack_input.labels_train)
|
||||
if metadata.accuracy_test is None:
|
||||
metadata.accuracy_test = _get_accuracy(attack_input.logits_test,
|
||||
attack_input.labels_test)
|
||||
metadata.accuracy_test = accuracy_fn(logits_or_probs_test,
|
||||
attack_input.labels_test)
|
||||
loss_train = attack_input.get_loss_train()
|
||||
loss_test = attack_input.get_loss_test()
|
||||
if metadata.loss_train is None and loss_train is not None:
|
||||
|
@ -341,3 +390,28 @@ def _get_accuracy(logits, labels):
|
|||
if logits is None or labels is None:
|
||||
return None
|
||||
return metrics.accuracy_score(labels, np.argmax(logits, axis=1))
|
||||
|
||||
|
||||
def _get_numpy_binary_accuracy(preds: ArrayLike, labels: ArrayLike):
|
||||
"""Computes the multilabel accuracy at threshold=0.5 using Numpy."""
|
||||
return np.mean(np.equal(labels, np.round(preds)))
|
||||
|
||||
|
||||
def _get_multilabel_accuracy(preds: ArrayLike, labels: ArrayLike):
|
||||
"""Computes the accuracy over multilabel data if it is missing.
|
||||
|
||||
Compute multilabel binary accuracy. AUC is a better measure of model quality
|
||||
for multilabel classification than accuracy, in particular when the classes
|
||||
are imbalanced. For consistency with the single label classification case,
|
||||
we compute and return the binary accuracy over the labels and predictions.
|
||||
|
||||
Args:
|
||||
preds: Prediction probabilities.
|
||||
labels: Ground truth multihot labels.
|
||||
|
||||
Returns:
|
||||
The binary accuracy averaged across all labels.
|
||||
"""
|
||||
if preds is None or labels is None:
|
||||
return None
|
||||
return _get_numpy_binary_accuracy(preds, labels)
|
||||
|
|
|
@ -34,6 +34,42 @@ def get_test_input(n_train, n_test):
|
|||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||
|
||||
|
||||
def get_multihot_labels_for_test(num_samples: int,
|
||||
num_classes: int) -> np.ndarray:
|
||||
"""Generate a array of multihot labels.
|
||||
|
||||
Given an integer 'num_samples', generate a deterministic array of
|
||||
'num_classes'multihot labels. Each multihot label is the list of bits (0/1) of
|
||||
the corresponding row number in the array, upto 'num_classes'. If the value
|
||||
of num_classes < num_samples, then the bit list repeats.
|
||||
e.g. if num_samples=10 and num_classes=3, row=3 corresponds to the label
|
||||
vector [0, 1, 1].
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples for which to generate test labels.
|
||||
num_classes: Number of classes for which to generate test multihot labels.
|
||||
|
||||
Returns:
|
||||
Numpy integer array with rows=number of samples, and columns=length of the
|
||||
bit-representation of the number of samples.
|
||||
"""
|
||||
m = 2**num_classes # Number of unique labels given the number of classes.
|
||||
bit_format = f'0{num_classes}b' # Bit representation format with leading 0s.
|
||||
return np.asarray(
|
||||
[list(format(i % m, bit_format)) for i in range(num_samples)]).astype(int)
|
||||
|
||||
|
||||
def get_multilabel_test_input(n_train, n_test):
|
||||
"""Get example multilabel inputs for attacks."""
|
||||
rng = np.random.RandomState(4)
|
||||
num_classes = max(n_train // 20, 5) # use at least 5 classes.
|
||||
return AttackInputData(
|
||||
logits_train=rng.randn(n_train, num_classes) + 0.2,
|
||||
logits_test=rng.randn(n_test, num_classes) + 0.2,
|
||||
labels_train=get_multihot_labels_for_test(n_train, num_classes),
|
||||
labels_test=get_multihot_labels_for_test(n_test, num_classes))
|
||||
|
||||
|
||||
def get_test_input_logits_only(n_train, n_test):
|
||||
"""Get example input logits for attacks."""
|
||||
rng = np.random.RandomState(4)
|
||||
|
@ -172,5 +208,60 @@ class RunAttacksTest(absltest.TestCase):
|
|||
DataSize(ntrain=20, ntest=16))
|
||||
|
||||
|
||||
class RunAttacksTestOnMultilabelData(absltest.TestCase):
|
||||
|
||||
def test_run_attacks_size(self):
|
||||
result = mia.run_attacks(
|
||||
get_multilabel_test_input(100, 100), SlicingSpec(),
|
||||
(AttackType.LOGISTIC_REGRESSION,))
|
||||
|
||||
self.assertLen(result.single_attack_results, 1)
|
||||
|
||||
def test_run_attack_trained_sets_attack_type(self):
|
||||
result = mia._run_attack(
|
||||
get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
|
||||
|
||||
self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||
|
||||
def test_run_attack_threshold_sets_attack_type(self):
|
||||
result = mia._run_attack(
|
||||
get_multilabel_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
def test_run_attack_threshold_entropy_fails(self):
|
||||
self.assertRaises(NotImplementedError, mia._run_threshold_entropy_attack,
|
||||
get_multilabel_test_input(100, 100))
|
||||
|
||||
def test_run_attack_by_percentiles_slice(self):
|
||||
result = mia.run_attacks(
|
||||
get_multilabel_test_input(100, 100),
|
||||
SlicingSpec(entire_dataset=True, by_class=False, by_percentiles=True),
|
||||
(AttackType.THRESHOLD_ATTACK,))
|
||||
|
||||
# 1 attack on entire dataset, 1 attack each of 10 percentile ranges, total
|
||||
# of 11.
|
||||
self.assertLen(result.single_attack_results, 11)
|
||||
expected_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (20, 30))
|
||||
# First slice (Slice #0) is entire dataset. Hence Slice #3 is the 3rd
|
||||
# percentile range 20-30.
|
||||
self.assertEqual(result.single_attack_results[3].slice_spec, expected_slice)
|
||||
|
||||
def test_numpy_multilabel_accuracy(self):
|
||||
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
|
||||
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
|
||||
# At a threshold=0.5, 5 of the total 9 lables are correct.
|
||||
self.assertAlmostEqual(
|
||||
mia._get_numpy_binary_accuracy(predictions, labels), 5 / 9, places=6)
|
||||
|
||||
def test_multilabel_accuracy(self):
|
||||
predictions = [[0.5, 0.2, 0.3], [0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
|
||||
labels = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
|
||||
# At a threshold=0.5, 5 of the total 9 lables are correct.
|
||||
self.assertAlmostEqual(
|
||||
mia._get_multilabel_accuracy(predictions, labels), 5 / 9, places=6)
|
||||
self.assertIsNone(mia._get_accuracy(None, labels))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -47,6 +47,25 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
self.assertLen(attacker_data.fold_indices, 5)
|
||||
self.assertEmpty(attacker_data.left_out_indices)
|
||||
|
||||
def test_multilabel_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||
logits_test=np.array([[10, 11], [14, 15]]),
|
||||
labels_train=np.array([[0, 1], [1, 1], [1, 0]]),
|
||||
labels_test=np.array([[1, 0], [1, 1]]),
|
||||
loss_train=np.array([[1, 3], [6, 7], [8, 9]]),
|
||||
loss_test=np.array([[4, 2], [4, 6]]))
|
||||
attacker_data = models.create_attacker_data(attack_input, balance=False)
|
||||
self.assertLen(attacker_data.features_all, 5)
|
||||
self.assertLen(attacker_data.fold_indices, 5)
|
||||
self.assertEmpty(attacker_data.left_out_indices)
|
||||
self.assertEqual(
|
||||
attacker_data.features_all.shape[1],
|
||||
attack_input.logits_train.shape[1] + attack_input.loss_train.shape[1])
|
||||
self.assertTrue(
|
||||
attack_input.is_multilabel_data(),
|
||||
msg='Expected multilabel check to pass.')
|
||||
|
||||
def test_unbalanced_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Utility functions for membership inference attacks."""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
|
||||
|
@ -76,3 +77,48 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
|
|||
the squared loss of each sample.
|
||||
"""
|
||||
return (y_true - y_pred)**2
|
||||
|
||||
|
||||
def multilabel_bce_loss(labels: np.ndarray,
|
||||
pred: np.ndarray,
|
||||
from_logits=False,
|
||||
small_value=1e-8) -> np.ndarray:
|
||||
"""Computes the per-multi-label-example cross entropy loss.
|
||||
|
||||
Args:
|
||||
labels: numpy array of shape (num_samples, num_classes). labels[i] is the
|
||||
true multi-hot encoded label (vector) of the i-th sample and each element
|
||||
of the vector is one of {0, 1}.
|
||||
pred: numpy array of shape (num_samples, num_classes). pred[i] is the
|
||||
logits or probability vector of the i-th sample.
|
||||
from_logits: whether `pred` is logits or probability vector.
|
||||
small_value: a scalar. np.log can become -inf if the probability is too
|
||||
close to 0, so the probability is clipped below by small_value.
|
||||
|
||||
Returns:
|
||||
the cross-entropy loss of each sample for each class.
|
||||
"""
|
||||
# Check arrays.
|
||||
if labels.shape[0] != pred.shape[0]:
|
||||
raise ValueError('labels and pred should have the same number of examples,',
|
||||
f'but got {labels.shape[0]} and {pred.shape[0]}.')
|
||||
if not ((labels == 0) | (labels == 1)).all():
|
||||
raise ValueError(
|
||||
'labels should be in {0, 1}. For multi-label classification the labels '
|
||||
'should be multihot encoded.')
|
||||
# Check if labels vectors are multi-label.
|
||||
summed_labels = np.sum(labels, axis=1)
|
||||
if ((summed_labels == 0) | (summed_labels == 1)).all():
|
||||
logging.info(
|
||||
('Labels are one-hot encoded single label. Every sample has at most one'
|
||||
' positive label.'))
|
||||
if not from_logits and ((pred < 0.0) | (pred > 1.0)).any():
|
||||
raise ValueError(('Prediction probabilities are not in [0, 1] and '
|
||||
'`from_logits` is set to False.'))
|
||||
|
||||
# Multi-class multi-label binary cross entropy loss
|
||||
if from_logits:
|
||||
pred = special.expit(pred)
|
||||
bce = labels * np.log(pred + small_value)
|
||||
bce += (1 - labels) * np.log(1 - pred + small_value)
|
||||
return -bce
|
||||
|
|
|
@ -124,5 +124,62 @@ class TestSquaredLoss(parameterized.TestCase):
|
|||
np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
|
||||
|
||||
|
||||
class TestMultilabelBCELoss(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('probs_example1', np.array(
|
||||
[[0, 1, 1], [1, 1, 0]]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||
np.array([[0.22314343, 1.20397247, 0.3566748],
|
||||
[0.22314343, 0.51082546, 2.30258409]]), False),
|
||||
('probs_example2', np.array([[0, 1, 0], [1, 1, 0]]),
|
||||
np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
|
||||
np.array([[0.01005033, 3.91202251, 0.04082198],
|
||||
[0.22314354, 0.35667493, 2.30258499]]), False),
|
||||
('logits_example1', np.array([[0, 1, 1], [1, 1, 0]]),
|
||||
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
|
||||
np.array([[0.26328245, 0.85435522, 0.11551951],
|
||||
[0.69314716, 0.47407697, 1.70141322]]), True),
|
||||
('logits_example2', np.array([[0, 1, 0], [1, 1, 0]]),
|
||||
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]),
|
||||
np.array([[0.26328245, 0.85435522, 2.21551943],
|
||||
[0.69314716, 0.47407697, 1.70141322]]), True),
|
||||
)
|
||||
def test_multilabel_bce_loss(self, label, pred, expected_losses, from_logits):
|
||||
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
|
||||
np.testing.assert_allclose(loss, expected_losses, atol=1e-6)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('from_logits_true_and_incorrect_values_example1',
|
||||
np.array([[0, 1, 1], [1, 1, 0]
|
||||
]), np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||
np.array([[0.22314343, 1.20397247, 0.3566748],
|
||||
[0.22314343, 0.51082546, 2.30258409]]), True),
|
||||
('from_logits_true_and_incorrect_values_example2',
|
||||
np.array([[0, 1, 0], [1, 1, 0]
|
||||
]), np.array([[0.01, 0.02, 0.04], [0.8, 0.7, 0.9]]),
|
||||
np.array([[0.01005033, 3.91202251, 0.04082198],
|
||||
[0.22314354, 0.35667493, 2.30258499]]), True),
|
||||
)
|
||||
def test_multilabel_bce_loss_incorrect_value(self, label, pred,
|
||||
expected_losses, from_logits):
|
||||
loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
|
||||
self.assertFalse(np.allclose(loss, expected_losses))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('from_logits_false_and_pred_not_in_0to1_example1',
|
||||
np.array([[0, 1, 1], [1, 1, 0]
|
||||
]), np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
|
||||
(r'Prediction probabilities are not in \[0, 1\] and '
|
||||
'`from_logits` is set to False.')),
|
||||
('labels_not_0_or_1', np.array([[0, 1, 0], [1, 2, 0]]),
|
||||
np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]), False,
|
||||
('labels should be in {0, 1}. For multi-label classification the labels '
|
||||
'should be multihot encoded.')),
|
||||
)
|
||||
def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex):
|
||||
self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label,
|
||||
pred, from_logits)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue