diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py index 5674c1c..332972a 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py @@ -14,7 +14,7 @@ """Functions for advanced membership inference attacks.""" import functools -from typing import Sequence, Union +from typing import Optional, Sequence, Union import numpy as np import scipy.stats from tensorflow_privacy.privacy.privacy_tests.utils import log_loss @@ -197,6 +197,7 @@ def convert_logit_to_prob(logit: np.ndarray) -> np.ndarray: def calculate_statistic(pred: np.ndarray, labels: np.ndarray, + sample_weight: Optional[np.ndarray] = None, is_logits: bool = True, option: str = 'logit', small_value: float = 1e-45): @@ -215,6 +216,10 @@ def calculate_statistic(pred: np.ndarray, An array of size n by c where n is the number of samples and c is the number of classes labels: true labels of samples (integer valued) + sample_weight: a vector of weights of shape (num_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. is_logits: whether pred is logits or probability vectors option: confidence using probability, xe loss, logit of confidence, confidence using logits, hinge loss @@ -241,7 +246,7 @@ def calculate_statistic(pred: np.ndarray, if option in ['conf with prob', 'conf with logit']: return pred[range(n), labels] if option == 'xe': - return log_loss(labels, pred) + return log_loss(labels, pred, sample_weight=sample_weight) if option == 'logit': p_true = pred[range(n), labels] pred[range(n), labels] = 0 diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py index 38b7a43..f033aa8 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py @@ -16,6 +16,8 @@ import functools import gc import os +from typing import Optional + from absl import app from absl import flags import matplotlib.pyplot as plt @@ -69,7 +71,11 @@ def plot_curve_with_area(x, y, xlabel, ylabel, ax, label, title=None): ax.title.set_text(title) -def get_stat_and_loss_aug(model, x, y, batch_size=4096): +def get_stat_and_loss_aug(model, + x, + y, + sample_weight: Optional[np.ndarray] = None, + batch_size=4096): """A helper function to get the statistics and losses. Here we get the statistics and losses for the original and @@ -80,6 +86,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096): model: model to make prediction x: samples y: true labels of samples (integer valued) + sample_weight: a vector of weights of shape (n_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. batch_size: the batch size for model.predict Returns: @@ -89,8 +99,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096): for data in [x, x[:, :, ::-1, :]]: prob = amia.convert_logit_to_prob( model.predict(data, batch_size=batch_size)) - losses.append(utils.log_loss(y, prob)) - stat.append(amia.calculate_statistic(prob, y, convert_to_prob=False)) + losses.append(utils.log_loss(y, prob, sample_weight=sample_weight)) + stat.append( + amia.calculate_statistic( + prob, y, sample_weight=sample_weight, convert_to_prob=False)) return np.vstack(stat).transpose(1, 0), np.vstack(losses).transpose(1, 0) @@ -103,6 +115,8 @@ def main(unused_argv): # Load data. x, y = load_cifar10() + # Sample weights are set to `None` by default, but can be changed here. + sample_weight = None n = x.shape[0] # Train the target and shadow models. We will use one of the model in `models` @@ -144,7 +158,7 @@ def main(unused_argv): print(f'Trained model #{i} with {in_indices[-1].sum()} examples.') # Get the statistics of the current model. - s, l = get_stat_and_loss_aug(model, x, y) + s, l = get_stat_and_loss_aug(model, x, y, sample_weight) stat.append(s) losses.append(l) @@ -175,7 +189,9 @@ def main(unused_argv): stat_target, stat_in, stat_out, fix_variance=True) attack_input = AttackInputData( loss_train=scores[in_indices_target], - loss_test=scores[~in_indices_target]) + loss_test=scores[~in_indices_target], + sample_weight_train=sample_weight, + sample_weight_test=sample_weight) result_lira = mia.run_attacks(attack_input).single_attack_results[0] print('Advanced MIA attack with Gaussian:', f'auc = {result_lira.get_auc():.4f}', @@ -187,7 +203,9 @@ def main(unused_argv): scores = -amia.compute_score_offset(stat_target, stat_in, stat_out) attack_input = AttackInputData( loss_train=scores[in_indices_target], - loss_test=scores[~in_indices_target]) + loss_test=scores[~in_indices_target], + sample_weight_train=sample_weight, + sample_weight_test=sample_weight) result_offset = mia.run_attacks(attack_input).single_attack_results[0] print('Advanced MIA attack with offset:', f'auc = {result_offset.get_auc():.4f}', @@ -197,7 +215,9 @@ def main(unused_argv): loss_target = losses[idx][:, 0] attack_input = AttackInputData( loss_train=loss_target[in_indices_target], - loss_test=loss_target[~in_indices_target]) + loss_test=loss_target[~in_indices_target], + sample_weight_train=sample_weight, + sample_weight_test=sample_weight) result_baseline = mia.run_attacks(attack_input).single_attack_results[0] print('Baseline MIA attack:', f'auc = {result_baseline.get_auc():.4f}', f'adv = {result_baseline.get_attacker_advantage():.4f}') diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py index 1865d84..268e865 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py @@ -158,19 +158,21 @@ class TestCalculateStatistic(absltest.TestCase): # [0.09003057, 0.66524096, 0.24472847]]) labels = np.array([1, 2]) - stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with prob') + stat = amia.calculate_statistic(logit, labels, None, is_logits, + 'conf with prob') np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847])) - stat = amia.calculate_statistic(logit, labels, is_logits, 'xe') + stat = amia.calculate_statistic(logit, labels, None, is_logits, 'xe') np.testing.assert_allclose(stat, np.array([0.31817543, 1.40760596])) - stat = amia.calculate_statistic(logit, labels, is_logits, 'logit') + stat = amia.calculate_statistic(logit, labels, None, is_logits, 'logit') np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802])) - stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with logit') + stat = amia.calculate_statistic(logit, labels, None, is_logits, + 'conf with logit') np.testing.assert_allclose(stat, np.array([2, 0.])) - stat = amia.calculate_statistic(logit, labels, is_logits, 'hinge') + stat = amia.calculate_statistic(logit, labels, None, is_logits, 'hinge') np.testing.assert_allclose(stat, np.array([1, -1.])) def test_calculate_statistic_prob(self): @@ -179,19 +181,74 @@ class TestCalculateStatistic(absltest.TestCase): prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]]) labels = np.array([1, 2]) - stat = amia.calculate_statistic(prob, labels, is_logits, 'conf with prob') + stat = amia.calculate_statistic(prob, labels, None, is_logits, + 'conf with prob') np.testing.assert_allclose(stat, np.array([0.85, 0.4])) - stat = amia.calculate_statistic(prob, labels, is_logits, 'xe') + stat = amia.calculate_statistic(prob, labels, None, is_logits, 'xe') np.testing.assert_allclose(stat, np.array([0.16251893, 0.91629073])) - stat = amia.calculate_statistic(prob, labels, is_logits, 'logit') + stat = amia.calculate_statistic(prob, labels, None, is_logits, 'logit') np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511])) np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels, - is_logits, 'conf with logit') + None, is_logits, 'conf with logit') np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels, - is_logits, 'hinge') + None, is_logits, 'hinge') + + def test_calculate_statistic_logit_with_sample_weights(self): + """Test calculate_statistic with input as logit.""" + is_logits = True + logit = np.array([[1, 2, -3.], [-1, 1, 0]]) + # expected probability vector + # array([[0.26762315, 0.72747516, 0.00490169], + # [0.09003057, 0.66524096, 0.24472847]]) + labels = np.array([1, 2]) + sample_weight = np.array([1.0, 0.5]) + + stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits, + 'conf with prob') + np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847])) + + stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits, + 'xe') + np.testing.assert_allclose(stat, np.array([0.31817543, 0.70380298])) + + stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits, + 'logit') + np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802])) + + stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits, + 'conf with logit') + np.testing.assert_allclose(stat, np.array([2, 0.])) + + stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits, + 'hinge') + np.testing.assert_allclose(stat, np.array([1, -1.])) + + def test_calculate_statistic_prob_with_sample_weights(self): + """Test calculate_statistic with input as probability vector.""" + is_logits = False + prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]]) + labels = np.array([1, 2]) + sample_weight = np.array([1.0, 0.5]) + + stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits, + 'conf with prob') + np.testing.assert_allclose(stat, np.array([0.85, 0.4])) + + stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits, + 'xe') + np.testing.assert_allclose(stat, np.array([0.16251893, 0.458145365])) + + stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits, + 'logit') + np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511])) + + np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels, + None, is_logits, 'conf with logit') + np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels, + None, is_logits, 'hinge') if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index dc558c1..39a4773 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -20,7 +20,7 @@ import glob import logging import os import pickle -from typing import Any, Callable, Iterable, MutableSequence, Optional, Union +from typing import Any, Iterable, MutableSequence, Optional, Union import numpy as np import pandas as pd @@ -203,6 +203,10 @@ class AttackInputData: labels_train: Optional[np.ndarray] = None labels_test: Optional[np.ndarray] = None + # Sample weights, if provided. + sample_weight_train: Optional[np.ndarray] = None + sample_weight_test: Optional[np.ndarray] = None + # Explicitly specified loss. If provided, this is used instead of deriving # loss from logits and labels loss_train: Optional[np.ndarray] = None @@ -219,8 +223,7 @@ class AttackInputData: # string representation, or a callable. # If a callable is provided, it should take in two argument, the 1st is # labels, the 2nd is logits or probs. - loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str, - utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY + loss_function: utils.LossFunctionCallable = utils.LossFunction.CROSS_ENTROPY # Whether `loss_function` will be called with logits or probs. If not set # (None), will decide by availablity of logits and probs and logits is # preferred when both are available. @@ -309,7 +312,8 @@ class AttackInputData: self.loss_function_using_logits = (self.logits_train is not None) return utils.get_loss(self.loss_train, self.labels_train, self.logits_train, self.probs_train, self.loss_function, - self.loss_function_using_logits, self.multilabel_data) + self.loss_function_using_logits, self.multilabel_data, + self.sample_weight_train) def get_loss_test(self): """Calculates (if needed) cross-entropy losses for the test set. @@ -321,7 +325,8 @@ class AttackInputData: self.loss_function_using_logits = bool(self.logits_test) return utils.get_loss(self.loss_test, self.labels_test, self.logits_test, self.probs_test, self.loss_function, - self.loss_function_using_logits, self.multilabel_data) + self.loss_function_using_logits, self.multilabel_data, + self.sample_weight_test) def get_entropy_train(self): """Calculates prediction entropy for the training set.""" @@ -367,6 +372,11 @@ class AttackInputData: """Returns the number of examples of the test set.""" return self.get_test_shape()[0] + def has_nonnull_sample_weights(self): + """Whether both the train and test input data have sample weights.""" + return (self.sample_weight_train is not None and + self.sample_weight_test is not None) + def is_multihot_labels(self, arr, arr_name) -> bool: """Check if the 2D array is multihot, with values in [0, 1]. @@ -556,6 +566,8 @@ class AttackInputData: _append_array_shape(self.probs_test, 'probs_test', result) _append_array_shape(self.labels_train, 'labels_train', result) _append_array_shape(self.labels_test, 'labels_test', result) + _append_array_shape(self.sample_weight_train, 'sample_weight_train', result) + _append_array_shape(self.sample_weight_test, 'sample_weight_test', result) result.append(')') return '\n'.join(result) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index 23432fe..aad3a09 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -63,6 +63,20 @@ class AttackInputDataTest(parameterized.TestCase): np.testing.assert_allclose( attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7) + def test_get_xe_loss_from_logits_with_sample_weights(self): + attack_input = AttackInputData( + logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]), + logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]), + labels_train=np.array([1, 0]), + labels_test=np.array([0, 2]), + sample_weight_train=np.array([1.0, 0.5]), + sample_weight_test=np.array([0.5, 1.0])) + + np.testing.assert_allclose( + attack_input.get_loss_train(), [0.36313551, 0.685769515], atol=1e-7) + np.testing.assert_allclose( + attack_input.get_loss_test(), [0.149304485, 0.95618669], atol=1e-7) + def test_get_xe_loss_from_probs(self): attack_input = AttackInputData( probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]), @@ -88,6 +102,30 @@ class AttackInputDataTest(parameterized.TestCase): np.testing.assert_allclose( attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2) + def test_get_binary_xe_loss_from_logits_with_sample_weights(self): + attack_input = AttackInputData( + logits_train=np.array([-10, -5, 0., 5, 10]), + logits_test=np.array([-10, -5, 0., 5, 10]), + labels_train=np.zeros((5,)), + labels_test=np.ones((5,)), + sample_weight_train=np.array([1.0, 0.0, 0.5, 0.2, 1.0]), + sample_weight_test=np.array([0.0, 0.1, 0.5, 0.2, 1.0]), + loss_function_using_logits=True) + expected_train_loss = np.array( + [4.539890e-05, 0.0, 0.3465736, 1.001343, 10.0]) + expected_test_loss = np.array( + [4.539890e-05, 0.001343070, 0.3465736, 0.5006715, 0.0]) + np.testing.assert_allclose( + attack_input.get_loss_train(), + expected_train_loss, + rtol=1e-2, + err_msg='Failure in binary xe training loss calculation.') + np.testing.assert_allclose( + attack_input.get_loss_test(), + expected_test_loss[::-1], + rtol=1e-2, + err_msg='Failure in binary xe test loss calculation.') + def test_get_binary_xe_loss_from_probs(self): attack_input = AttackInputData( probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]), @@ -137,7 +175,8 @@ class AttackInputDataTest(parameterized.TestCase): def test_get_customized_loss(self, loss_function_using_logits, expected_train, expected_test): - def fake_loss(x, y): + def fake_loss(x, y, sample_weight=None): + del sample_weight # Unused. return 2 * x + y attack_input = AttackInputData( @@ -332,6 +371,23 @@ class AttackInputDataTest(parameterized.TestCase): attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]], atol=1e-6) + def test_multilabel_get_bce_loss_from_probs_with_sample_weights(self): + attack_input = AttackInputData( + probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), + probs_test=np.array([[0.8, 0.7, 0.9]]), + labels_train=np.array([[0, 1, 1], [1, 1, 0]]), + labels_test=np.array([[1, 1, 0]]), + sample_weight_train=np.array([1.0, 0.5]), + sample_weight_test=np.array([0.5])) + + np.testing.assert_allclose( + attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748], + [0.111571715, 0.25541273, 1.151292045]], + atol=1e-6) + np.testing.assert_allclose( + attack_input.get_loss_test(), [[0.11157177, 0.17833747, 1.1512925]], + atol=1e-6) + def test_multilabel_get_bce_loss_from_logits(self): attack_input = AttackInputData( logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), @@ -349,6 +405,25 @@ class AttackInputDataTest(parameterized.TestCase): [[0.68815966, 0.20141327], [0.47407697, 0.04858734]], atol=1e-6) + def test_multilabel_get_bce_loss_from_logits_with_sample_weights(self): + attack_input = AttackInputData( + logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]), + logits_test=np.array([[0.01, 1.5], [0.5, -3]]), + labels_train=np.array([[0, 0], [0, 1], [1, 1]]), + labels_test=np.array([[1, 1], [1, 0]]), + sample_weight_train=np.array([1.0, 0.0, 0.5]), + sample_weight_test=np.array([0.0, 0.5])) + + np.testing.assert_allclose( + attack_input.get_loss_train(), + [[0.31326167, 0.126928], [0.0, 0.0], [0.23703848, 1.52429357]], + atol=1e-6, + err_msg='Failure in multilabel bce training loss calculation.') + np.testing.assert_allclose( + attack_input.get_loss_test(), [[0.0, 0.0], [0.23703848, 0.02429367]], + atol=1e-6, + err_msg='Failure in multilabel bce test loss calculation.') + def test_multilabel_get_loss_explicitly_provided(self): attack_input = AttackInputData( loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]), @@ -359,6 +434,23 @@ class AttackInputDataTest(parameterized.TestCase): np.testing.assert_equal(attack_input.get_loss_test().tolist(), np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]])) + def test_multilabel_get_loss_explicitly_provided_with_sample_weights(self): + attack_input = AttackInputData( + loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]), + loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]), + sample_weight_train=np.array([1.0, 0.5]), + sample_weight_test=np.array([0.0, 0.5])) + + # Since loss is provided, sample weights have no effect. + np.testing.assert_equal( + attack_input.get_loss_train().tolist(), + np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]), + err_msg='Failure in multilabel get training loss calculation.') + np.testing.assert_equal( + attack_input.get_loss_test().tolist(), + np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]), + err_msg='Failure in multilabel get test loss calculation.') + def test_validate_with_force_multilabel_false(self): attack_input = AttackInputData( probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]), diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index 5d4f381..ca8a494 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ -42,6 +42,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train) result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train) + # Copy over sample weights if provided. + result.sample_weight_train = data.sample_weight_train + result.sample_weight_test = data.sample_weight_test # Slice test data. result.logits_test = _slice_if_not_none(data.logits_test, idx_test) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py index 3f12908..c0cdd51 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py @@ -115,6 +115,8 @@ class GetSliceTest(absltest.TestCase): loss_test = np.array([0.5, 3.5, 7, 4.5]) entropy_train = np.array([0.4, 8, 0.6, 10]) entropy_test = np.array([15, 10.5, 4.5, 0.3]) + sample_weight_train = np.array([1.0, 0.5]) + sample_weight_test = np.array([0.5, 1.0]) self.input_data = AttackInputData( logits_train=logits_train, @@ -126,7 +128,9 @@ class GetSliceTest(absltest.TestCase): loss_train=loss_train, loss_test=loss_test, entropy_train=entropy_train, - entropy_test=entropy_test) + entropy_test=entropy_test, + sample_weight_train=sample_weight_train, + sample_weight_test=sample_weight_test) def test_slice_entire_dataset(self): entire_dataset_slice = SingleSliceSpec() @@ -168,6 +172,12 @@ class GetSliceTest(absltest.TestCase): self.assertTrue((output.entropy_train == [0.4, 0.6]).all()) self.assertTrue((output.entropy_test == [15]).all()) + # Check sample weights + self.assertLen(output.sample_weight_train, 2) + np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5]) + self.assertLen(output.sample_weight_test, 2) + np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0]) + def test_slice_by_percentile(self): percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) output = get_slice(self.input_data, percentile_slice) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py index 40c2e7b..fd0f294 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py @@ -72,6 +72,19 @@ def get_multilabel_test_input(n_train, n_test): labels_test=get_multihot_labels_for_test(n_test, num_classes)) +def get_multilabel_test_input_with_sample_weights(n_train, n_test): + """Get example multilabel inputs for attacks.""" + rng = np.random.RandomState(4) + num_classes = max(n_train // 20, 5) # use at least 5 classes. + return AttackInputData( + logits_train=rng.randn(n_train, num_classes) + 0.2, + logits_test=rng.randn(n_test, num_classes) + 0.2, + labels_train=get_multihot_labels_for_test(n_train, num_classes), + labels_test=get_multihot_labels_for_test(n_test, num_classes), + sample_weight_train=rng.randn(n_train, 1), + sample_weight_test=rng.randn(n_test, 1)) + + def get_test_input_logits_only(n_train, n_test): """Get example input logits for attacks.""" rng = np.random.RandomState(4) @@ -80,6 +93,16 @@ def get_test_input_logits_only(n_train, n_test): logits_test=rng.randn(n_test, 5) + 0.2) +def get_test_input_logits_only_with_sample_weights(n_train, n_test): + """Get example input logits for attacks.""" + rng = np.random.RandomState(4) + return AttackInputData( + logits_train=rng.randn(n_train, 5) + 0.2, + logits_test=rng.randn(n_test, 5) + 0.2, + sample_weight_train=rng.randn(n_train, 1), + sample_weight_test=rng.randn(n_test, 1)) + + class RunAttacksTest(parameterized.TestCase): def test_run_attacks_size(self): @@ -113,6 +136,17 @@ class RunAttacksTest(parameterized.TestCase): self.assertLen(result.single_attack_results, 2) + def test_run_attacks_parallel_backend_with_sample_weights(self): + result = mia.run_attacks( + get_multilabel_test_input_with_sample_weights(100, 100), + SlicingSpec(), ( + AttackType.LOGISTIC_REGRESSION, + AttackType.RANDOM_FOREST, + ), + backend='threading') + + self.assertLen(result.single_attack_results, 2) + def test_trained_attacks_logits_only_size(self): result = mia.run_attacks( get_test_input_logits_only(100, 100), SlicingSpec(), @@ -120,6 +154,13 @@ class RunAttacksTest(parameterized.TestCase): self.assertLen(result.single_attack_results, 1) + def test_trained_attacks_logits_only_with_sample_weights_size(self): + result = mia.run_attacks( + get_test_input_logits_only_with_sample_weights(100, 100), SlicingSpec(), + (AttackType.LOGISTIC_REGRESSION,)) + + self.assertLen(result.single_attack_results, 1) + def test_run_attack_trained_sets_attack_type(self): result = mia._run_attack( get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION) @@ -271,6 +312,15 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase): self.assertLen(result.single_attack_results, 1) + def test_run_attacks_parallel_backend_with_sample_weights(self): + result = mia.run_attacks( + get_multilabel_test_input_with_sample_weights(100, 100), + SlicingSpec(), + (AttackType.LOGISTIC_REGRESSION, AttackType.RANDOM_FOREST), + backend='threading') + + self.assertLen(result.single_attack_results, 2) + def test_run_attack_trained_sets_attack_type(self): result = mia._run_attack( get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py index 6e6c02e..77045b1 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py @@ -39,6 +39,8 @@ class AttackerData: features_all: Optional[np.ndarray] = None # Indicator for whether the example is in-training (0) or out-of-training (1). labels_all: Optional[np.ndarray] = None + # Sample weights of in-training and out-of-training examples, if provided. + sample_weights_all: Optional[np.ndarray] = None # Indices for `features_all` and `labels_all` that are going to be used for # training the attackers. @@ -75,6 +77,12 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData, ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0] features_all = np.concatenate((attack_input_train, attack_input_test)) labels_all = np.concatenate((np.zeros(ntrain), np.ones(ntest))) + if attack_input_data.has_nonnull_sample_weights(): + sample_weights_all = np.concatenate((attack_input_data.sample_weight_train, + attack_input_data.sample_weight_test), + axis=0) + else: + sample_weights_all = None fold_indices = np.arange(ntrain + ntest) left_out_indices = np.asarray([], dtype=np.int32) @@ -99,6 +107,7 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData, return AttackerData( features_all=features_all, labels_all=labels_all, + sample_weights_all=sample_weights_all, fold_indices=fold_indices, left_out_indices=left_out_indices, data_size=data_structures.DataSize(ntrain=ntrain, ntest=ntest)) @@ -158,7 +167,7 @@ class TrainedAttacker(object): n_jobs=self.n_jobs) logging.info('Using %s backend for training.', backend) - def train_model(self, input_features, is_training_labels): + def train_model(self, input_features, is_training_labels, sample_weight=None): """Train an attacker model. This is trained on examples from train and test datasets. @@ -168,6 +177,10 @@ class TrainedAttacker(object): number of features. is_training_labels : a vector of booleans of shape (n_samples, ) representing whether the sample is in the training set or not. + sample_weight: a vector of weights of shape (n_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. """ raise NotImplementedError() @@ -193,7 +206,7 @@ class LogisticRegressionAttacker(TrainedAttacker): def __init__(self, backend: Optional[str] = None): super().__init__(backend=backend) - def train_model(self, input_features, is_training_labels): + def train_model(self, input_features, is_training_labels, sample_weight=None): with self.ctx_mgr: lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs) param_grid = { @@ -201,7 +214,7 @@ class LogisticRegressionAttacker(TrainedAttacker): } model = model_selection.GridSearchCV( lr, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0) - model.fit(input_features, is_training_labels) + model.fit(input_features, is_training_labels, sample_weight=sample_weight) self.model = model @@ -211,7 +224,8 @@ class MultilayerPerceptronAttacker(TrainedAttacker): def __init__(self, backend: Optional[str] = None): super().__init__(backend=backend) - def train_model(self, input_features, is_training_labels): + def train_model(self, input_features, is_training_labels, sample_weight=None): + del sample_weight # MLP attacker does not use sample weights. with self.ctx_mgr: mlp_model = neural_network.MLPClassifier() param_grid = { @@ -231,7 +245,7 @@ class RandomForestAttacker(TrainedAttacker): def __init__(self, backend: Optional[str] = None): super().__init__(backend=backend) - def train_model(self, input_features, is_training_labels): + def train_model(self, input_features, is_training_labels, sample_weight=None): """Setup a random forest pipeline with cross-validation.""" with self.ctx_mgr: rf_model = ensemble.RandomForestClassifier(n_jobs=self.n_jobs) @@ -241,11 +255,11 @@ class RandomForestAttacker(TrainedAttacker): 'max_features': ['auto', 'sqrt'], 'max_depth': [5, 10, 20, None], 'min_samples_split': [2, 5, 10], - 'min_samples_leaf': [1, 2, 4] + 'min_samples_leaf': [1, 2, 4], } model = model_selection.GridSearchCV( rf_model, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0) - model.fit(input_features, is_training_labels) + model.fit(input_features, is_training_labels, sample_weight=sample_weight) self.model = model @@ -255,7 +269,8 @@ class KNearestNeighborsAttacker(TrainedAttacker): def __init__(self, backend: Optional[str] = None): super().__init__(backend=backend) - def train_model(self, input_features, is_training_labels): + def train_model(self, input_features, is_training_labels, sample_weight=None): + del sample_weight # K-NN attacker does not use sample weights. with self.ctx_mgr: knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs) param_grid = { diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py index 4a4bcf9..032864b 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py @@ -68,6 +68,25 @@ class TrainedAttackerTest(parameterized.TestCase): attack_input.is_multilabel_data(), msg='Expected multilabel check to pass.') + def test_multilabel_create_attacker_data_logits_labels_sample_weights(self): + attack_input = AttackInputData( + logits_train=np.array([[1, 2], [5, 6], [8, 9]]), + logits_test=np.array([[10, 11], [14, 15]]), + labels_train=np.array([[0, 1], [1, 1], [1, 0]]), + labels_test=np.array([[1, 0], [1, 1]]), + sample_weight_train=np.array([1.0, 0.5, 0.0]), + sample_weight_test=np.array([1.0, 0.5])) + attacker_data = models.create_attacker_data(attack_input, balance=False) + self.assertLen(attacker_data.features_all, 5) + self.assertLen(attacker_data.fold_indices, 5) + self.assertEmpty(attacker_data.left_out_indices) + self.assertTrue( + attack_input.is_multilabel_data(), + msg='Expected multilabel check to pass.') + self.assertTrue( + attack_input.has_nonnull_sample_weights(), + msg='Expected to have sample weights.') + def test_unbalanced_create_attacker_data_loss_and_logits(self): attack_input = AttackInputData( logits_train=np.array([[1, 2], [5, 6], [8, 9]]), diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py index 9d87c4b..78e7f2e 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py @@ -14,7 +14,7 @@ """A hook and a function in tf estimator for membership inference attack.""" import os -from typing import Iterable +from typing import Iterable, Optional from absl import logging import numpy as np @@ -26,7 +26,10 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard -def calculate_losses(estimator, input_fn, labels): +def calculate_losses(estimator, + input_fn, + labels, + sample_weight: Optional[np.ndarray] = None): """Get predictions and losses for samples. The assumptions are 1) the loss is cross-entropy loss, and 2) user have @@ -38,13 +41,17 @@ def calculate_losses(estimator, input_fn, labels): estimator: model to make prediction input_fn: input function to be used in estimator.predict labels: array of size (n_samples, ), true labels of samples (integer valued) + sample_weight: a vector of weights of shape (n_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. Returns: preds: probability vector of each sample loss: cross entropy loss of each sample """ pred = np.array(list(estimator.predict(input_fn=input_fn))) - loss = utils.log_loss(labels, pred) + loss = utils.log_loss(labels, pred, sample_weight=sample_weight) return pred, loss @@ -65,8 +72,10 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook): Args: estimator: model to be tested - in_train: (in_training samples, in_training labels) - out_train: (out_training samples, out_training labels) + in_train: (in_training samples, in_training labels, + in_train_sample_weights=None) + out_train: (out_training samples, out_training labels, + out_train_sample_weights=None) input_fn_constructor: a function that receives sample, label and construct the input_fn for model prediction slicing_spec: slicing specification of the attack @@ -75,8 +84,24 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook): tensorboard_merge_classifiers: if true, plot different classifiers with the same slicing_spec and metric in the same figure """ - in_train_data, self._in_train_labels = in_train - out_train_data, self._out_train_labels = out_train + if len(in_train) == 2: + in_train_data, self._in_train_labels = in_train + self._in_train_sample_weights = None + elif len(in_train) == 3: + (in_train_data, self._in_train_labels, + self._in_train_sample_weights) = in_train + else: + raise ValueError('`in_train` should be length 2 or 3, received ' + f'{len(in_train)}') + if len(out_train) == 2: + out_train_data, self._out_train_labels = out_train + self._out_train_sample_weights = None + elif len(out_train) == 3: + (out_train_data, self._out_train_labels, + self._out_train_sample_weights) = in_train + else: + raise ValueError('`out_train` should be length 2 or 3, received ' + f'{len(out_train)}') # Define the input functions for both in and out-training samples. self._in_train_input_fn = input_fn_constructor(in_train_data, @@ -105,8 +130,10 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook): def end(self, session): results = run_attack_helper(self._estimator, self._in_train_input_fn, self._out_train_input_fn, self._in_train_labels, - self._out_train_labels, self._slicing_spec, - self._attack_types) + self._out_train_labels, + self._in_train_sample_weights, + self._out_train_sample_weights, + self._slicing_spec, self._attack_types) logging.info(results) att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( @@ -137,8 +164,10 @@ def run_attack_on_tf_estimator_model( Args: estimator: model to be tested - in_train: (in_training samples, in_training labels) - out_train: (out_training samples, out_training labels) + in_train: (in_training samples, in_training labels, + in_training_sample_weights=None) + out_train: (out_training samples, out_training labels, + out_training_sample_weights=None) input_fn_constructor: a function that receives sample, label and construct the input_fn for model prediction slicing_spec: slicing specification of the attack @@ -147,8 +176,20 @@ def run_attack_on_tf_estimator_model( Returns: Results of the attack """ - in_train_data, in_train_labels = in_train - out_train_data, out_train_labels = out_train + + def unpack(data): + sample_weights = None + if len(data) == 2: + inputs, labels = data + elif len(in_train) == 3: + inputs, labels, sample_weights = in_train + else: + raise ValueError('`data` should be length 2 or 3, received ' + f'{len(data)}') + return inputs, labels, sample_weights + + in_train_data, in_train_labels, in_train_sample_weights = unpack(in_train) + out_train_data, out_train_labels, out_train_sample_weights = unpack(out_train) # Define the input functions for both in and out-training samples. in_train_input_fn = input_fn_constructor(in_train_data, in_train_labels) @@ -156,8 +197,9 @@ def run_attack_on_tf_estimator_model( # Call the helper to run the attack. results = run_attack_helper(estimator, in_train_input_fn, out_train_input_fn, - in_train_labels, out_train_labels, slicing_spec, - attack_types) + in_train_labels, out_train_labels, + in_train_sample_weights, out_train_sample_weights, + slicing_spec, attack_types) logging.info('End of training attack:') logging.info(results) return results @@ -168,6 +210,8 @@ def run_attack_helper(estimator, out_train_input_fn, in_train_labels, out_train_labels, + in_train_sample_weight: Optional[np.ndarray] = None, + out_train_sample_weight: Optional[np.ndarray] = None, slicing_spec: data_structures.SlicingSpec = None, attack_types: Iterable[data_structures.AttackType] = ( data_structures.AttackType.THRESHOLD_ATTACK,)): @@ -179,6 +223,14 @@ def run_attack_helper(estimator, out_train_input_fn: input_fn for out of training data in_train_labels: in training labels out_train_labels: out of training labels + in_train_sample_weight: a vector of weights of shape (n_train, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. + out_train_sample_weight: a vector of weights of shape (n_test, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. slicing_spec: slicing specification of the attack attack_types: a list of attacks, each of type AttackType @@ -186,18 +238,25 @@ def run_attack_helper(estimator, Results of the attack """ # Compute predictions and losses - in_train_pred, in_train_loss = calculate_losses(estimator, in_train_input_fn, - in_train_labels) - out_train_pred, out_train_loss = calculate_losses(estimator, - out_train_input_fn, - out_train_labels) + in_train_pred, in_train_loss = calculate_losses( + estimator, + in_train_input_fn, + in_train_labels, + sample_weight=in_train_sample_weight) + out_train_pred, out_train_loss = calculate_losses( + estimator, + out_train_input_fn, + out_train_labels, + sample_weight=out_train_sample_weight) attack_input = data_structures.AttackInputData( logits_train=in_train_pred, logits_test=out_train_pred, labels_train=in_train_labels, labels_test=out_train_labels, loss_train=in_train_loss, - loss_test=out_train_loss) + loss_test=out_train_loss, + sample_weight_train=in_train_sample_weight, + sample_weight_test=out_train_sample_weight) results = mia.run_attacks( attack_input, slicing_spec=slicing_spec, attack_types=attack_types) return results diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_example.py index 2b2d87b..b3b4ffa 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_example.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_example.py @@ -96,6 +96,8 @@ def main(unused_argv): # Load training and test data. x_train, y_train, x_test, y_test = load_cifar10() + # Sample weights are set to `None` by default, but can be changed here. + sample_weight_train, sample_weight_test = None, None # Instantiate the tf.Estimator. classifier = tf_estimator.Estimator( @@ -142,7 +144,8 @@ def main(unused_argv): print('End of training attack') attack_results = run_attack_on_tf_estimator_model( - classifier, (x_train, y_train), (x_test, y_test), + classifier, (x_train, y_train, sample_weight_train), + (x_test, y_test, sample_weight_test), input_fn_constructor, slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), attack_types=[ diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py index cdef4f5..e26fa9f 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation_test.py @@ -36,6 +36,8 @@ class UtilsTest(absltest.TestCase): self.test_data = np.random.rand(self.ntest, self.ndim) self.train_labels = np.random.randint(self.nclass, size=self.ntrain) self.test_labels = np.random.randint(self.nclass, size=self.ntest) + self.sample_weight_train = np.random.rand(self.ntrain) + self.sample_weight_test = np.random.rand(self.ntest) # Define a simple model function def model_fn(features, labels, mode): @@ -70,6 +72,14 @@ class UtilsTest(absltest.TestCase): self.assertEqual(pred.shape, (self.ntrain, self.nclass)) self.assertEqual(loss.shape, (self.ntrain,)) + pred, loss = tf_estimator_evaluation.calculate_losses( + self.classifier, + self.input_fn_train, + self.train_labels, + sample_weight=self.sample_weight_train) + self.assertEqual(pred.shape, (self.ntrain, self.nclass)) + self.assertEqual(loss.shape, (self.ntrain,)) + pred, loss = tf_estimator_evaluation.calculate_losses( self.classifier, self.input_fn_test, self.test_labels) self.assertEqual(pred.shape, (self.ntest, self.nclass)) @@ -83,6 +93,27 @@ class UtilsTest(absltest.TestCase): self.input_fn_test, self.train_labels, self.test_labels, + self.sample_weight_train, + self.sample_weight_test, + attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) + self.assertIsInstance(results, data_structures.AttackResults) + att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( + results) + self.assertLen(att_types, 2) + self.assertLen(att_slices, 2) + self.assertLen(att_metrics, 2) + self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV + + def test_run_attack_helper_with_sample_weights(self): + """Test the attack.""" + results = tf_estimator_evaluation.run_attack_helper( + self.classifier, + self.input_fn_train, + self.input_fn_test, + self.train_labels, + self.test_labels, + in_train_sample_weight=self.sample_weight_train, + out_train_sample_weight=self.sample_weight_test, attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, data_structures.AttackResults) att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( @@ -112,6 +143,27 @@ class UtilsTest(absltest.TestCase): self.assertLen(att_metrics, 2) self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV + def test_run_attack_on_tf_estimator_model_with_sample_weights(self): + """Test the attack on the final models.""" + + def input_fn_constructor(x, y): + return tf_compat_v1_estimator.inputs.numpy_input_fn( + x={'x': x}, y=y, shuffle=False) + + results = tf_estimator_evaluation.run_attack_on_tf_estimator_model( + self.classifier, + (self.train_data, self.train_labels, self.sample_weight_train), + (self.test_data, self.test_labels), + input_fn_constructor, + attack_types=[data_structures.AttackType.THRESHOLD_ATTACK]) + self.assertIsInstance(results, data_structures.AttackResults) + att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics( + results) + self.assertLen(att_types, 2) + self.assertLen(att_slices, 2) + self.assertLen(att_metrics, 2) + self.assertLen(att_values, 3) # Attacker Advantage, AUC, PPV + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_privacy/privacy/privacy_tests/utils.py b/tensorflow_privacy/privacy/privacy_tests/utils.py index ae97a8d..e8f31aa 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils.py @@ -21,8 +21,20 @@ import numpy as np from scipy import special +class LossFunction(enum.Enum): + """An enum that defines loss function.""" + CROSS_ENTROPY = 'cross_entropy' + SQUARED = 'squared' + + +LossFunctionCallable = Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], + np.ndarray] +LossFunctionType = Union[LossFunctionCallable, LossFunction, str] + + def log_loss(labels: np.ndarray, pred: np.ndarray, + sample_weight: Optional[np.ndarray] = None, from_logits=False, small_value=1e-8) -> np.ndarray: """Computes the per-example cross entropy loss. @@ -35,6 +47,10 @@ def log_loss(labels: np.ndarray, num_classes) and pred[i] is the logits or probability vector of the i-th sample. For binary logistic loss, the shape should be (num_samples,) and pred[i] is the probability of the positive class. + sample_weight: a vector of weights of shape (num_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. from_logits: whether `pred` is logits or probability vector. small_value: a scalar. np.log can become -inf if the probability is too close to 0, so the probability is clipped below by small_value. @@ -46,6 +62,15 @@ def log_loss(labels: np.ndarray, raise ValueError('labels and pred should have the same number of examples,', f'but got {labels.shape[0]} and {pred.shape[0]}.') classes = np.unique(labels) + if sample_weight is None: + # If sample weights are not provided, set them to 1.0. + sample_weight = 1.0 + else: + if np.shape(sample_weight)[0] != np.shape(labels)[0]: + # Number of elements should be the same. + raise ValueError( + 'Expected sample weights to have the same length as the labels, ' + f'received {np.shape(sample_weight)[0]} and {np.shape(labels)[0]}.') # Binary logistic loss if pred.size == pred.shape[0]: @@ -59,22 +84,29 @@ def log_loss(labels: np.ndarray, indices_class0 = (labels == 0) prob_correct = np.copy(pred) prob_correct[indices_class0] = 1 - prob_correct[indices_class0] - return -np.log(np.maximum(prob_correct, small_value)) + return -np.log(np.maximum(prob_correct, small_value)) * sample_weight # Multi-class categorical cross entropy loss if classes.min() < 0 or classes.max() >= pred.shape[1]: raise ValueError('labels should be in the range [0, num_classes-1].') if from_logits: pred = special.softmax(pred, axis=-1) - return -np.log(np.maximum(pred[range(labels.size), labels], small_value)) + return (-np.log(np.maximum(pred[range(labels.size), labels], small_value)) * + sample_weight) -def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: +def squared_loss(y_true: np.ndarray, + y_pred: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> np.ndarray: """Computes the per-example squared loss. Args: y_true: numpy array of shape (num_samples,) representing the true labels. y_pred: numpy array of shape (num_samples,) representing the predictions. + sample_weight: a vector of weights of shape (num_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. Returns: the squared loss of each sample. @@ -93,11 +125,15 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: raise ValueError('Squared loss expects the labels and predictions to have ' 'shape (num_examples, ), but after np.squeeze, the shapes ' 'are %s and %s.' % (y_true.shape, y_pred.shape)) - return (y_true - y_pred)**2 + if sample_weight is None: + # If sample weights are not provided, set them to 1.0. + sample_weight = 1.0 + return sample_weight * (y_true - y_pred)**2 def multilabel_bce_loss(labels: np.ndarray, pred: np.ndarray, + sample_weight: Optional[np.ndarray] = None, from_logits=False, small_value=1e-8) -> np.ndarray: """Computes the per-multi-label-example cross entropy loss. @@ -108,6 +144,10 @@ def multilabel_bce_loss(labels: np.ndarray, of the vector is one of {0, 1}. pred: numpy array of shape (num_samples, num_classes). pred[i] is the logits or probability vector of the i-th sample. + sample_weight: a vector of weights of shape (num_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. from_logits: whether `pred` is logits or probability vector. small_value: a scalar. np.log can become -inf if the probability is too close to 0, so the probability is clipped below by small_value. @@ -132,19 +172,21 @@ def multilabel_bce_loss(labels: np.ndarray, if not from_logits and ((pred < 0.0) | (pred > 1.0)).any(): raise ValueError(('Prediction probabilities are not in [0, 1] and ' '`from_logits` is set to False.')) + if sample_weight is None: + # If sample weights are not provided, set them to 1.0. + sample_weight = 1.0 + if isinstance(sample_weight, list): + sample_weight = np.asarray(sample_weight) + if isinstance(sample_weight, np.ndarray) and (sample_weight.ndim == 1): + # NOMUTANTS--np.reshape(X, (-1, 1)) == np.reshape(X, (-N, 1)), N >=1. + sample_weight = np.reshape(sample_weight, (-1, 1)) # Multi-class multi-label binary cross entropy loss if from_logits: pred = special.expit(pred) bce = labels * np.log(pred + small_value) bce += (1 - labels) * np.log(1 - pred + small_value) - return -bce - - -class LossFunction(enum.Enum): - """An enum that defines loss function.""" - CROSS_ENTROPY = 'cross_entropy' - SQUARED = 'squared' + return -bce * sample_weight def string_to_loss_function(string: str): @@ -157,12 +199,15 @@ def string_to_loss_function(string: str): raise ValueError(f'{string} is not a valid loss function name.') -def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], - logits: Optional[np.ndarray], probs: Optional[np.ndarray], - loss_function: Union[Callable[[np.ndarray, np.ndarray], - np.ndarray], LossFunction, str], - loss_function_using_logits: Optional[bool], - multilabel_data: Optional[bool]) -> Optional[np.ndarray]: +def get_loss( + loss: Optional[np.ndarray], + labels: Optional[np.ndarray], + logits: Optional[np.ndarray], + probs: Optional[np.ndarray], + loss_function: LossFunctionCallable, + loss_function_using_logits: Optional[bool], + multilabel_data: Optional[bool], + sample_weight: Optional[np.ndarray] = None) -> Optional[np.ndarray]: """Calculates (if needed) losses. Args: @@ -176,6 +221,10 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], loss_function_using_logits: if `loss_function` expects `logits` or `probs`. multilabel_data: if the data is from a multilabel classification problem. + sample_weight: a vector of weights of shape (num_samples, ) that are + assigned to individual samples. If not provided, then each sample is + given unit weight. Only the LogisticRegressionAttacker and the + RandomForestAttacker support sample weights. Returns: Loss (or None if neither the loss nor the labels are present). @@ -195,12 +244,13 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], loss_function = string_to_loss_function(loss_function) if loss_function == LossFunction.CROSS_ENTROPY: if multilabel_data: - loss = multilabel_bce_loss(labels, predictions, + loss = multilabel_bce_loss(labels, predictions, sample_weight, loss_function_using_logits) else: - loss = log_loss(labels, predictions, loss_function_using_logits) + loss = log_loss(labels, predictions, sample_weight, + loss_function_using_logits) elif loss_function == LossFunction.SQUARED: - loss = squared_loss(labels, predictions) + loss = squared_loss(labels, predictions, sample_weight) else: - loss = loss_function(labels, predictions) + loss = loss_function(labels, predictions, sample_weight) return loss diff --git a/tensorflow_privacy/privacy/privacy_tests/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/utils_test.py index 4c37680..cb592a8 100644 --- a/tensorflow_privacy/privacy/privacy_tests/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils_test.py @@ -89,6 +89,15 @@ class TestLogLoss(parameterized.TestCase): loss = utils.log_loss(labels, logits, from_logits=True) np.testing.assert_allclose(expected_loss, loss, atol=1e-7) + def test_log_loss_from_logits_with_sample_weights(self): + logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]]) + labels = np.array([0, 3, 1]) + sample_weight = np.array([1.0, 0.0, 0.5]) + expected_loss = np.array([1.4401897, 0.0, 0.05572139]) + + loss = utils.log_loss(labels, logits, sample_weight, from_logits=True) + np.testing.assert_allclose(expected_loss, loss, atol=1e-7) + @parameterized.named_parameters( ('label0', 0, np.array([ @@ -161,6 +170,14 @@ class TestSquaredLoss(parameterized.TestCase): y_pred = np.array([4, 3, 2]) self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred) + def test_squared_loss_with_sample_weights(self): + y_true = np.array([1, 2, 3, 4.]) + y_pred = np.array([4, 3, 2, 1.]) + sample_weight = np.array([1.0, 0.0, 0.5, 1.0]) + expected_loss = np.array([9, 0, 0.5, 9.]) + loss = utils.squared_loss(y_true, y_pred, sample_weight) + np.testing.assert_allclose(loss, expected_loss, atol=1e-7) + class TestMultilabelBCELoss(parameterized.TestCase): @@ -186,6 +203,16 @@ class TestMultilabelBCELoss(parameterized.TestCase): loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits) np.testing.assert_allclose(loss, expected_losses, atol=1e-6) + def test_multilabel_bce_loss_with_sample_weights(self): + label = np.array([[0, 1, 0], [1, 1, 0]]) + pred = np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]]) + sample_weight = np.array([1.0, 0.5]) + expected_loss = np.array([[0.26328245, 0.85435522, 2.21551943], + [0.34657358, 0.23703848, 0.85070661]]) + loss = utils.multilabel_bce_loss( + label, pred, sample_weight=sample_weight, from_logits=True) + np.testing.assert_allclose(loss, expected_loss, atol=1e-6) + @parameterized.named_parameters( ('from_logits_true_and_incorrect_values_example1', np.array([[0, 1, 1], [1, 1, 0] @@ -216,7 +243,7 @@ class TestMultilabelBCELoss(parameterized.TestCase): ) def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex): self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label, - pred, from_logits) + pred, None, from_logits) class TestGetLoss(parameterized.TestCase):