Refactor: move loss computation utilities under privacy_tests
PiperOrigin-RevId: 463391913
parent 44dc40454b
commit 17cd0c52bc
10 changed files with 119 additions and 88 deletions
@@ -1,6 +1,6 @@
-load("@rules_python//python:defs.bzl", "py_library")
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
 
-package(default_visibility = ["//visibility:private"])
+package(default_visibility = ["//visibility:public"])
 
 licenses(["notice"])
 
@@ -8,3 +8,18 @@ py_library(
     name = "privacy_tests",
     srcs = ["__init__.py"],
 )
+
+py_test(
+    name = "utils_test",
+    timeout = "long",
+    srcs = ["utils_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [":utils"],
+)
+
+py_library(
+    name = "utils",
+    srcs = ["utils.py"],
+    srcs_version = "PY3",
+)

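With the shared target in place, the loss helpers are importable from the package-level module rather than from membership_inference_attack.utils. A minimal usage sketch (the array values are illustrative, not taken from this change; log_loss is assumed to keep its existing default of treating inputs as probabilities):

import numpy as np
from tensorflow_privacy.privacy.privacy_tests import utils

# Per-example cross-entropy loss from integer labels and probability vectors.
labels = np.array([0, 1])
probs = np.array([[0.9, 0.1], [0.2, 0.8]])
per_example_loss = utils.log_loss(labels, probs)  # one loss value per example
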
@@ -15,21 +15,6 @@ py_library(
     srcs_version = "PY3",
 )
 
-py_library(
-    name = "utils",
-    srcs = ["utils.py"],
-    srcs_version = "PY3",
-)
-
-py_test(
-    name = "utils_test",
-    timeout = "long",
-    srcs = ["utils_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [":utils"],
-)
-
 py_test(
     name = "membership_inference_attack_test",
     timeout = "long",
@@ -45,7 +30,10 @@ py_test(
     srcs = ["data_structures_test.py"],
     python_version = "PY3",
     srcs_version = "PY3",
-    deps = [":membership_inference_attack"],
+    deps = [
+        ":membership_inference_attack",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
+    ],
 )
 
 py_test(
@@ -95,7 +83,7 @@ py_library(
         "seq2seq_mia.py",
     ],
     srcs_version = "PY3",
-    deps = [":utils"],
+    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
 )
 
 py_library(
@@ -122,8 +110,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":membership_inference_attack",
-        ":utils",
         ":utils_tensorboard",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )
 
@@ -144,8 +132,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":membership_inference_attack",
-        ":utils",
         ":utils_tensorboard",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )
 
@@ -185,7 +173,7 @@ py_library(
         "advanced_mia.py",
     ],
     srcs_version = "PY3",
-    deps = [":utils"],
+    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
 )
 
 py_test(
@@ -205,6 +193,6 @@ py_binary(
     deps = [
         ":advanced_mia",
         ":membership_inference_attack",
-        ":utils",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )

@@ -17,7 +17,7 @@ import functools
 from typing import Sequence, Union
 import numpy as np
 import scipy.stats
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
+from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
 
 
 def replace_nan_with_column_mean(a: np.ndarray):

@@ -21,11 +21,10 @@ from absl import flags
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
-
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 
 FLAGS = flags.FLAGS

@@ -26,7 +26,7 @@ import numpy as np
 import pandas as pd
 from scipy import special
 from sklearn import metrics
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
+from tensorflow_privacy.privacy.privacy_tests import utils
 
 # The minimum TPR or FPR below which they are considered equal.
 _ABSOLUTE_TOLERANCE = 1e-3
@@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30):
   return -np.log(np.maximum(probs, small_value))
 
 
-class LossFunction(enum.Enum):
-  """An enum that defines loss function to use in `AttackInputData`."""
-  CROSS_ENTROPY = 'cross_entropy'
-  SQUARED = 'squared'
-
-
 @dataclasses.dataclass
 class AttackInputData:
   """Input data for running an attack.
@@ -225,7 +219,7 @@ class AttackInputData:
   # If a callable is provided, it should take in two argument, the 1st is
   # labels, the 2nd is logits or probs.
   loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
-                       LossFunction] = LossFunction.CROSS_ENTROPY
+                       utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
   # Whether `loss_function` will be called with logits or probs. If not set
   # (None), will decide by availablity of logits and probs and logits is
   # preferred when both are available.
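With the enum relocated, callers spell it through the shared module. A hedged sketch of constructing AttackInputData against the new spelling (shapes mirror the existing squared-loss tests; the values themselves are illustrative):

import numpy as np
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Squared loss selected via the shared enum; cross-entropy remains the default.
attack_input = AttackInputData(
    probs_train=np.array([0.2, 0.9]),
    probs_test=np.array([0.7, 0.4]),
    labels_train=np.array([1, 0.]),
    labels_test=np.array([0, 1.]),
    loss_function=utils.LossFunction.SQUARED,
)
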
@@ -298,52 +292,6 @@ class AttackInputData:
                             true_labels]
     return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
 
-  @staticmethod
-  def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
-                logits: Optional[np.ndarray], probs: Optional[np.ndarray],
-                loss_function: Union[Callable[[np.ndarray, np.ndarray],
-                                              np.ndarray], LossFunction],
-                loss_function_using_logits: Optional[bool],
-                multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
-    """Calculates (if needed) losses.
-
-    Args:
-      loss: the loss of each example.
-      labels: the scalar label of each example.
-      logits: the logits vector of each example.
-      probs: the probability vector of each example.
-      loss_function: if `loss` is not available, `labels` and one of `logits`
-        and `probs` are available, we will use this function to compute loss. It
-        is supposed to take in (label, logits / probs) as input.
-      loss_function_using_logits: if `loss_function` expects `logits` or
-        `probs`.
-      multilabel_data: if the data is from a multilabel classification problem.
-
-    Returns:
-      Loss (or None if neither the loss nor the labels are present).
-    """
-    if loss is not None:
-      return loss
-    if labels is None or (logits is None and probs is None):
-      return None
-    if loss_function_using_logits and logits is None:
-      raise ValueError('We need logits to compute loss, but it is set to None.')
-    if not loss_function_using_logits and probs is None:
-      raise ValueError('We need probs to compute loss, but it is set to None.')
-
-    predictions = logits if loss_function_using_logits else probs
-    if loss_function == LossFunction.CROSS_ENTROPY:
-      if multilabel_data:
-        loss = utils.multilabel_bce_loss(labels, predictions,
-                                         loss_function_using_logits)
-      else:
-        loss = utils.log_loss(labels, predictions, loss_function_using_logits)
-    elif loss_function == LossFunction.SQUARED:
-      loss = utils.squared_loss(labels, predictions)
-    else:
-      loss = loss_function(labels, predictions)
-    return loss
-
   def __post_init__(self):
     """Checks performed after instantiation of the AttackInputData dataclass."""
     # Check if the data is multilabel.
@@ -358,7 +306,7 @@ class AttackInputData:
     """
     if self.loss_function_using_logits is None:
       self.loss_function_using_logits = (self.logits_train is not None)
-    return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
+    return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
                           self.probs_train, self.loss_function,
                           self.loss_function_using_logits, self.multilabel_data)
 
@@ -370,7 +318,7 @@ class AttackInputData:
     """
     if self.loss_function_using_logits is None:
       self.loss_function_using_logits = bool(self.logits_test)
-    return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
+    return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
                           self.probs_test, self.loss_function,
                           self.loss_function_using_logits, self.multilabel_data)
 
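The per-split getters now delegate to utils.get_loss, which also covers the callable branch documented on the loss_function field. A hypothetical sketch of that branch (the custom loss below is invented for illustration and is not part of the library):

import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

def zero_one_loss(labels, logits):
  # Invented example loss: 1.0 wherever the argmax prediction misses the label.
  return (np.argmax(logits, axis=1) != labels).astype(float)

attack_input = AttackInputData(
    logits_train=np.array([[0.2, 1.3], [0.8, -0.4]]),
    logits_test=np.array([[0.1, 0.9], [1.2, -0.7]]),
    labels_train=np.array([1, 0]),
    labels_test=np.array([1, 0]),
    loss_function=zero_one_loss,
    loss_function_using_logits=True,
)
per_example_train_loss = attack_input.get_loss_train()  # dispatches to utils.get_loss
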
@@ -20,13 +20,13 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
 import pandas as pd
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -123,7 +123,7 @@ class AttackInputDataTest(parameterized.TestCase):
         probs_test=np.array([1, 1.]),
         labels_train=np.array([1, 0.]),
         labels_test=np.array([0, 2.]),
-        loss_function=LossFunction.SQUARED,
+        loss_function=utils.LossFunction.SQUARED,
         loss_function_using_logits=loss_function_using_logits,
     )
     np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
@@ -175,7 +175,7 @@ class AttackInputDataTest(parameterized.TestCase):
         probs_test=probs,
         labels_train=np.array([1, 0.]),
         labels_test=np.array([1, 0.]),
-        loss_function=LossFunction.SQUARED,
+        loss_function=utils.LossFunction.SQUARED,
     )
     np.testing.assert_allclose(attack_input.get_loss_train(), expected)
     np.testing.assert_allclose(attack_input.get_loss_test(), expected)

@@ -24,7 +24,6 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
 
 

@@ -20,9 +20,9 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 from tensorflow import estimator as tf_estimator
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
 
 

@@ -13,7 +13,10 @@
 # limitations under the License.
 """Utility functions for membership inference attacks."""
 
+import enum
 import logging
+from typing import Callable, Optional, Union
+
 import numpy as np
 from scipy import special
 
@@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray,
   bce = labels * np.log(pred + small_value)
   bce += (1 - labels) * np.log(1 - pred + small_value)
   return -bce
+
+
+class LossFunction(enum.Enum):
+  """An enum that defines loss function."""
+  CROSS_ENTROPY = 'cross_entropy'
+  SQUARED = 'squared'
+
+
+def string_to_loss_function(string: str):
+  """Convert string to the corresponding LossFunction."""
+
+  if string == LossFunction.CROSS_ENTROPY.value:
+    return LossFunction.CROSS_ENTROPY
+  if string == LossFunction.SQUARED.value:
+    return LossFunction.SQUARED
+  raise ValueError(f'{string} is not a valid loss function name.')
+
+
+def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
+             logits: Optional[np.ndarray], probs: Optional[np.ndarray],
+             loss_function: Union[Callable[[np.ndarray, np.ndarray],
+                                           np.ndarray], LossFunction],
+             loss_function_using_logits: Optional[bool],
+             multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
+  """Calculates (if needed) losses.
+
+  Args:
+    loss: the loss of each example.
+    labels: the scalar label of each example.
+    logits: the logits vector of each example.
+    probs: the probability vector of each example.
+    loss_function: if `loss` is not available, `labels` and one of `logits`
+      and `probs` are available, we will use this function to compute loss. It
+      is supposed to take in (label, logits / probs) as input.
+    loss_function_using_logits: if `loss_function` expects `logits` or
+      `probs`.
+    multilabel_data: if the data is from a multilabel classification problem.
+
+  Returns:
+    Loss (or None if neither the loss nor the labels are present).
+  """
+  if loss is not None:
+    return loss
+  if labels is None or (logits is None and probs is None):
+    return None
+  if loss_function_using_logits and logits is None:
+    raise ValueError('We need logits to compute loss, but it is set to None.')
+  if not loss_function_using_logits and probs is None:
+    raise ValueError('We need probs to compute loss, but it is set to None.')
+
+  predictions = logits if loss_function_using_logits else probs
+  if loss_function == LossFunction.CROSS_ENTROPY:
+    if multilabel_data:
+      loss = multilabel_bce_loss(labels, predictions,
+                                 loss_function_using_logits)
+    else:
+      loss = log_loss(labels, predictions, loss_function_using_logits)
+  elif loss_function == LossFunction.SQUARED:
+    loss = squared_loss(labels, predictions)
+  else:
+    loss = loss_function(labels, predictions)
+  return loss

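For callers outside AttackInputData, the relocated get_loss can be used directly. A small hedged sketch with made-up inputs, following the signature added above:

import numpy as np
from tensorflow_privacy.privacy.privacy_tests import utils

labels = np.array([0, 1])
logits = np.array([[2.0, -1.0], [-0.5, 1.5]])

# No precomputed loss is supplied, so get_loss derives per-example
# cross-entropy from the logits.
loss = utils.get_loss(
    loss=None,
    labels=labels,
    logits=logits,
    probs=None,
    loss_function=utils.LossFunction.CROSS_ENTROPY,
    loss_function_using_logits=True,
    multilabel_data=False,
)
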
@@ -16,7 +16,24 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
+from tensorflow_privacy.privacy.privacy_tests import utils
 
 
+class LossFunctionFromStringTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      (utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'),
+      (utils.LossFunction.SQUARED, 'squared'),
+  )
+  def test_from_str(self, en, string):
+    self.assertEqual(utils.string_to_loss_function(string), en)
+
+  @parameterized.parameters(
+      ('random string'),
+      (''),
+  )
+  def test_from_str_wrong_input(self, string):
+    self.assertRaises(ValueError, utils.string_to_loss_function, string)
+
+
 class TestLogLoss(parameterized.TestCase):
|
Loading…
Reference in a new issue
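The new tests above exercise string_to_loss_function; downstream code can use the same helper to map a configuration string onto the enum. A hedged sketch:

from tensorflow_privacy.privacy.privacy_tests import utils

# 'cross_entropy' and 'squared' are the two recognized names;
# anything else raises ValueError.
loss_fn = utils.string_to_loss_function('squared')
assert loss_fn is utils.LossFunction.SQUARED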