From 17cd0c52bc4d1f10ff5b4405290064c78ddd73b0 Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Tue, 26 Jul 2022 11:49:14 -0700 Subject: [PATCH] Refactor: move loss computation utilities under `privacy_tests`. PiperOrigin-RevId: 463391913 --- .../privacy/privacy_tests/BUILD | 19 +++++- .../membership_inference_attack/BUILD | 30 +++------ .../advanced_mia.py | 2 +- .../advanced_mia_example.py | 3 +- .../data_structures.py | 60 ++--------------- .../data_structures_test.py | 6 +- .../keras_evaluation.py | 1 - .../tf_estimator_evaluation.py | 2 +- .../utils.py | 65 +++++++++++++++++++ .../utils_test.py | 19 +++++- 10 files changed, 119 insertions(+), 88 deletions(-) rename tensorflow_privacy/privacy/privacy_tests/{membership_inference_attack => }/utils.py (67%) rename tensorflow_privacy/privacy/privacy_tests/{membership_inference_attack => }/utils_test.py (93%) diff --git a/tensorflow_privacy/privacy/privacy_tests/BUILD b/tensorflow_privacy/privacy/privacy_tests/BUILD index 02dc728..b99205b 100644 --- a/tensorflow_privacy/privacy/privacy_tests/BUILD +++ b/tensorflow_privacy/privacy/privacy_tests/BUILD @@ -1,6 +1,6 @@ -load("@rules_python//python:defs.bzl", "py_library") +load("@rules_python//python:defs.bzl", "py_library", "py_test") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -8,3 +8,18 @@ py_library( name = "privacy_tests", srcs = ["__init__.py"], ) + +py_test( + name = "utils_test", + timeout = "long", + srcs = ["utils_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":utils"], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY3", +) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/BUILD b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/BUILD index 1044742..ff75a51 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/BUILD +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/BUILD @@ -15,21 +15,6 @@ py_library( srcs_version = "PY3", ) -py_library( - name = "utils", - srcs = ["utils.py"], - srcs_version = "PY3", -) - -py_test( - name = "utils_test", - timeout = "long", - srcs = ["utils_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [":utils"], -) - py_test( name = "membership_inference_attack_test", timeout = "long", @@ -45,7 +30,10 @@ py_test( srcs = ["data_structures_test.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":membership_inference_attack"], + deps = [ + ":membership_inference_attack", + "//tensorflow_privacy/privacy/privacy_tests:utils", + ], ) py_test( @@ -95,7 +83,7 @@ py_library( "seq2seq_mia.py", ], srcs_version = "PY3", - deps = [":utils"], + deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"], ) py_library( @@ -122,8 +110,8 @@ py_library( srcs_version = "PY3", deps = [ ":membership_inference_attack", - ":utils", ":utils_tensorboard", + "//tensorflow_privacy/privacy/privacy_tests:utils", ], ) @@ -144,8 +132,8 @@ py_library( srcs_version = "PY3", deps = [ ":membership_inference_attack", - ":utils", ":utils_tensorboard", + "//tensorflow_privacy/privacy/privacy_tests:utils", ], ) @@ -185,7 +173,7 @@ py_library( "advanced_mia.py", ], srcs_version = "PY3", - deps = [":utils"], + deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"], ) py_test( @@ -205,6 +193,6 @@ py_binary( deps = [ ":advanced_mia", ":membership_inference_attack", - ":utils", + "//tensorflow_privacy/privacy/privacy_tests:utils", ], ) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py index 4dd109a..5674c1c 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py @@ -17,7 +17,7 @@ import functools from typing import Sequence, Union import numpy as np import scipy.stats -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss +from tensorflow_privacy.privacy.privacy_tests.utils import log_loss def replace_nan_with_column_mean(a: np.ndarray): diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py index 0cbae68..38b7a43 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py @@ -21,11 +21,10 @@ from absl import flags import matplotlib.pyplot as plt import numpy as np import tensorflow as tf - +from tensorflow_privacy.privacy.privacy_tests import utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData FLAGS = flags.FLAGS diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index c124656..c122443 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -26,7 +26,7 @@ import numpy as np import pandas as pd from scipy import special from sklearn import metrics -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils +from tensorflow_privacy.privacy.privacy_tests import utils # The minimum TPR or FPR below which they are considered equal. _ABSOLUTE_TOLERANCE = 1e-3 @@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30): return -np.log(np.maximum(probs, small_value)) -class LossFunction(enum.Enum): - """An enum that defines loss function to use in `AttackInputData`.""" - CROSS_ENTROPY = 'cross_entropy' - SQUARED = 'squared' - - @dataclasses.dataclass class AttackInputData: """Input data for running an attack. @@ -225,7 +219,7 @@ class AttackInputData: # If a callable is provided, it should take in two argument, the 1st is # labels, the 2nd is logits or probs. loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], - LossFunction] = LossFunction.CROSS_ENTROPY + utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY # Whether `loss_function` will be called with logits or probs. If not set # (None), will decide by availablity of logits and probs and logits is # preferred when both are available. @@ -298,52 +292,6 @@ class AttackInputData: true_labels] return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1) - @staticmethod - def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], - logits: Optional[np.ndarray], probs: Optional[np.ndarray], - loss_function: Union[Callable[[np.ndarray, np.ndarray], - np.ndarray], LossFunction], - loss_function_using_logits: Optional[bool], - multilabel_data: Optional[bool]) -> Optional[np.ndarray]: - """Calculates (if needed) losses. - - Args: - loss: the loss of each example. - labels: the scalar label of each example. - logits: the logits vector of each example. - probs: the probability vector of each example. - loss_function: if `loss` is not available, `labels` and one of `logits` - and `probs` are available, we will use this function to compute loss. It - is supposed to take in (label, logits / probs) as input. - loss_function_using_logits: if `loss_function` expects `logits` or - `probs`. - multilabel_data: if the data is from a multilabel classification problem. - - Returns: - Loss (or None if neither the loss nor the labels are present). - """ - if loss is not None: - return loss - if labels is None or (logits is None and probs is None): - return None - if loss_function_using_logits and logits is None: - raise ValueError('We need logits to compute loss, but it is set to None.') - if not loss_function_using_logits and probs is None: - raise ValueError('We need probs to compute loss, but it is set to None.') - - predictions = logits if loss_function_using_logits else probs - if loss_function == LossFunction.CROSS_ENTROPY: - if multilabel_data: - loss = utils.multilabel_bce_loss(labels, predictions, - loss_function_using_logits) - else: - loss = utils.log_loss(labels, predictions, loss_function_using_logits) - elif loss_function == LossFunction.SQUARED: - loss = utils.squared_loss(labels, predictions) - else: - loss = loss_function(labels, predictions) - return loss - def __post_init__(self): """Checks performed after instantiation of the AttackInputData dataclass.""" # Check if the data is multilabel. @@ -358,7 +306,7 @@ class AttackInputData: """ if self.loss_function_using_logits is None: self.loss_function_using_logits = (self.logits_train is not None) - return self._get_loss(self.loss_train, self.labels_train, self.logits_train, + return utils.get_loss(self.loss_train, self.labels_train, self.logits_train, self.probs_train, self.loss_function, self.loss_function_using_logits, self.multilabel_data) @@ -370,7 +318,7 @@ class AttackInputData: """ if self.loss_function_using_logits is None: self.loss_function_using_logits = bool(self.logits_test) - return self._get_loss(self.loss_test, self.labels_test, self.logits_test, + return utils.get_loss(self.loss_test, self.labels_test, self.logits_test, self.probs_test, self.loss_function, self.loss_function_using_logits, self.multilabel_data) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py index cb4369c..23432fe 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py @@ -20,13 +20,13 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np import pandas as pd +from tensorflow_privacy.privacy.privacy_tests import utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult @@ -123,7 +123,7 @@ class AttackInputDataTest(parameterized.TestCase): probs_test=np.array([1, 1.]), labels_train=np.array([1, 0.]), labels_test=np.array([0, 2.]), - loss_function=LossFunction.SQUARED, + loss_function=utils.LossFunction.SQUARED, loss_function_using_logits=loss_function_using_logits, ) np.testing.assert_allclose(attack_input.get_loss_train(), expected_train) @@ -175,7 +175,7 @@ class AttackInputDataTest(parameterized.TestCase): probs_test=probs, labels_train=np.array([1, 0.]), labels_test=np.array([1, 0.]), - loss_function=LossFunction.SQUARED, + loss_function=utils.LossFunction.SQUARED, ) np.testing.assert_allclose(attack_input.get_loss_train(), expected) np.testing.assert_allclose(attack_input.get_loss_test(), expected) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py index fa0875c..53c8393 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py @@ -24,7 +24,6 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py index bf13558..9d87c4b 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/tf_estimator_evaluation.py @@ -20,9 +20,9 @@ from absl import logging import numpy as np import tensorflow as tf from tensorflow import estimator as tf_estimator +from tensorflow_privacy.privacy.privacy_tests import utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/privacy_tests/utils.py similarity index 67% rename from tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py rename to tensorflow_privacy/privacy/privacy_tests/utils.py index 9d02d83..b09384a 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils.py @@ -13,7 +13,10 @@ # limitations under the License. """Utility functions for membership inference attacks.""" +import enum import logging +from typing import Callable, Optional, Union + import numpy as np from scipy import special @@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray, bce = labels * np.log(pred + small_value) bce += (1 - labels) * np.log(1 - pred + small_value) return -bce + + +class LossFunction(enum.Enum): + """An enum that defines loss function.""" + CROSS_ENTROPY = 'cross_entropy' + SQUARED = 'squared' + + +def string_to_loss_function(string: str): + """Convert string to the corresponding LossFunction.""" + + if string == LossFunction.CROSS_ENTROPY.value: + return LossFunction.CROSS_ENTROPY + if string == LossFunction.SQUARED.value: + return LossFunction.SQUARED + raise ValueError(f'{string} is not a valid loss function name.') + + +def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray], + logits: Optional[np.ndarray], probs: Optional[np.ndarray], + loss_function: Union[Callable[[np.ndarray, np.ndarray], + np.ndarray], LossFunction], + loss_function_using_logits: Optional[bool], + multilabel_data: Optional[bool]) -> Optional[np.ndarray]: + """Calculates (if needed) losses. + + Args: + loss: the loss of each example. + labels: the scalar label of each example. + logits: the logits vector of each example. + probs: the probability vector of each example. + loss_function: if `loss` is not available, `labels` and one of `logits` + and `probs` are available, we will use this function to compute loss. It + is supposed to take in (label, logits / probs) as input. + loss_function_using_logits: if `loss_function` expects `logits` or + `probs`. + multilabel_data: if the data is from a multilabel classification problem. + + Returns: + Loss (or None if neither the loss nor the labels are present). + """ + if loss is not None: + return loss + if labels is None or (logits is None and probs is None): + return None + if loss_function_using_logits and logits is None: + raise ValueError('We need logits to compute loss, but it is set to None.') + if not loss_function_using_logits and probs is None: + raise ValueError('We need probs to compute loss, but it is set to None.') + + predictions = logits if loss_function_using_logits else probs + if loss_function == LossFunction.CROSS_ENTROPY: + if multilabel_data: + loss = multilabel_bce_loss(labels, predictions, + loss_function_using_logits) + else: + loss = log_loss(labels, predictions, loss_function_using_logits) + elif loss_function == LossFunction.SQUARED: + loss = squared_loss(labels, predictions) + else: + loss = loss_function(labels, predictions) + return loss diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/privacy_tests/utils_test.py similarity index 93% rename from tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py rename to tensorflow_privacy/privacy/privacy_tests/utils_test.py index 725aeb8..b65a3b9 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/utils_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/utils_test.py @@ -16,7 +16,24 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np -from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils +from tensorflow_privacy.privacy.privacy_tests import utils + + +class LossFunctionFromStringTest(parameterized.TestCase): + + @parameterized.parameters( + (utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'), + (utils.LossFunction.SQUARED, 'squared'), + ) + def test_from_str(self, en, string): + self.assertEqual(utils.string_to_loss_function(string), en) + + @parameterized.parameters( + ('random string'), + (''), + ) + def test_from_str_wrong_input(self, string): + self.assertRaises(ValueError, utils.string_to_loss_function, string) class TestLogLoss(parameterized.TestCase):