Refactor: move loss computation utilities under privacy_tests.

PiperOrigin-RevId: 463391913
Authored by Shuang Song on 2022-07-26 11:49:14 -07:00; committed by A. Unique TensorFlower
parent 44dc40454b
commit 17cd0c52bc
10 changed files with 119 additions and 88 deletions
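
For code that imported the loss helpers from the membership_inference_attack package, the practical effect of this commit is an import-path change. A minimal before/after sketch (both import lines are taken verbatim from the diffs below):

# Old import path (removed in this commit):
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
# New import path:
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss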

@@ -1,6 +1,6 @@
load("@rules_python//python:defs.bzl", "py_library")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
@@ -8,3 +8,18 @@ py_library(
    name = "privacy_tests",
    srcs = ["__init__.py"],
)

py_test(
    name = "utils_test",
    timeout = "long",
    srcs = ["utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":utils"],
)

py_library(
    name = "utils",
    srcs = ["utils.py"],
    srcs_version = "PY3",
)

@@ -15,21 +15,6 @@ py_library(
    srcs_version = "PY3",
)

py_library(
    name = "utils",
    srcs = ["utils.py"],
    srcs_version = "PY3",
)

py_test(
    name = "utils_test",
    timeout = "long",
    srcs = ["utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":utils"],
)

py_test(
    name = "membership_inference_attack_test",
    timeout = "long",
@@ -45,7 +30,10 @@ py_test(
    srcs = ["data_structures_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":membership_inference_attack"],
    deps = [
        ":membership_inference_attack",
        "//tensorflow_privacy/privacy/privacy_tests:utils",
    ],
)

py_test(
@@ -95,7 +83,7 @@ py_library(
        "seq2seq_mia.py",
    ],
    srcs_version = "PY3",
    deps = [":utils"],
    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
)

py_library(
@@ -122,8 +110,8 @@ py_library(
    srcs_version = "PY3",
    deps = [
        ":membership_inference_attack",
        ":utils",
        ":utils_tensorboard",
        "//tensorflow_privacy/privacy/privacy_tests:utils",
    ],
)
@@ -144,8 +132,8 @@ py_library(
    srcs_version = "PY3",
    deps = [
        ":membership_inference_attack",
        ":utils",
        ":utils_tensorboard",
        "//tensorflow_privacy/privacy/privacy_tests:utils",
    ],
)
@@ -185,7 +173,7 @@ py_library(
        "advanced_mia.py",
    ],
    srcs_version = "PY3",
    deps = [":utils"],
    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
)

py_test(
@@ -205,6 +193,6 @@ py_binary(
    deps = [
        ":advanced_mia",
        ":membership_inference_attack",
        ":utils",
        "//tensorflow_privacy/privacy/privacy_tests:utils",
    ],
)

@@ -17,7 +17,7 @@ import functools
from typing import Sequence, Union
import numpy as np
import scipy.stats
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
def replace_nan_with_column_mean(a: np.ndarray):

@@ -21,11 +21,10 @@ from absl import flags
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
FLAGS = flags.FLAGS

@@ -26,7 +26,7 @@ import numpy as np
import pandas as pd
from scipy import special
from sklearn import metrics
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests import utils
# The minimum TPR or FPR below which they are considered equal.
_ABSOLUTE_TOLERANCE = 1e-3
@@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30):
  return -np.log(np.maximum(probs, small_value))


class LossFunction(enum.Enum):
  """An enum that defines loss function to use in `AttackInputData`."""
  CROSS_ENTROPY = 'cross_entropy'
  SQUARED = 'squared'


@dataclasses.dataclass
class AttackInputData:
  """Input data for running an attack.
@@ -225,7 +219,7 @@ class AttackInputData:
  # If a callable is provided, it should take in two argument, the 1st is
  # labels, the 2nd is logits or probs.
  loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
                       LossFunction] = LossFunction.CROSS_ENTROPY
                       utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
  # Whether `loss_function` will be called with logits or probs. If not set
  # (None), will decide by availablity of logits and probs and logits is
  # preferred when both are available.
@@ -298,52 +292,6 @@ class AttackInputData:
                       true_labels]
    return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)

  @staticmethod
  def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
                logits: Optional[np.ndarray], probs: Optional[np.ndarray],
                loss_function: Union[Callable[[np.ndarray, np.ndarray],
                                              np.ndarray], LossFunction],
                loss_function_using_logits: Optional[bool],
                multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
    """Calculates (if needed) losses.

    Args:
      loss: the loss of each example.
      labels: the scalar label of each example.
      logits: the logits vector of each example.
      probs: the probability vector of each example.
      loss_function: if `loss` is not available, `labels` and one of `logits`
        and `probs` are available, we will use this function to compute loss. It
        is supposed to take in (label, logits / probs) as input.
      loss_function_using_logits: if `loss_function` expects `logits` or
        `probs`.
      multilabel_data: if the data is from a multilabel classification problem.

    Returns:
      Loss (or None if neither the loss nor the labels are present).
    """
    if loss is not None:
      return loss
    if labels is None or (logits is None and probs is None):
      return None
    if loss_function_using_logits and logits is None:
      raise ValueError('We need logits to compute loss, but it is set to None.')
    if not loss_function_using_logits and probs is None:
      raise ValueError('We need probs to compute loss, but it is set to None.')

    predictions = logits if loss_function_using_logits else probs
    if loss_function == LossFunction.CROSS_ENTROPY:
      if multilabel_data:
        loss = utils.multilabel_bce_loss(labels, predictions,
                                         loss_function_using_logits)
      else:
        loss = utils.log_loss(labels, predictions, loss_function_using_logits)
    elif loss_function == LossFunction.SQUARED:
      loss = utils.squared_loss(labels, predictions)
    else:
      loss = loss_function(labels, predictions)
    return loss

  def __post_init__(self):
    """Checks performed after instantiation of the AttackInputData dataclass."""
    # Check if the data is multilabel.
@@ -358,7 +306,7 @@ class AttackInputData:
    """
    if self.loss_function_using_logits is None:
      self.loss_function_using_logits = (self.logits_train is not None)
    return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
    return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
                          self.probs_train, self.loss_function,
                          self.loss_function_using_logits, self.multilabel_data)
@@ -370,7 +318,7 @@ class AttackInputData:
    """
    if self.loss_function_using_logits is None:
      self.loss_function_using_logits = bool(self.logits_test)
    return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
    return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
                          self.probs_test, self.loss_function,
                          self.loss_function_using_logits, self.multilabel_data)
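
For context, a hypothetical construction of AttackInputData against the relocated enum; the field names follow the dataclass above, while the numeric values are illustrative only:

import numpy as np
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Two-class example with logits; loss_function now comes from privacy_tests.utils.
attack_input = AttackInputData(
    logits_train=np.array([[2.0, 0.5], [0.1, 1.5]]),
    logits_test=np.array([[0.2, 0.3], [1.0, 0.1]]),
    labels_train=np.array([0, 1]),
    labels_test=np.array([1, 0]),
    loss_function=utils.LossFunction.CROSS_ENTROPY)
train_loss = attack_input.get_loss_train()  # delegates to utils.get_loss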

@@ -20,13 +20,13 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import pandas as pd
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -123,7 +123,7 @@ class AttackInputDataTest(parameterized.TestCase):
        probs_test=np.array([1, 1.]),
        labels_train=np.array([1, 0.]),
        labels_test=np.array([0, 2.]),
        loss_function=LossFunction.SQUARED,
        loss_function=utils.LossFunction.SQUARED,
        loss_function_using_logits=loss_function_using_logits,
    )
    np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
@@ -175,7 +175,7 @@ class AttackInputDataTest(parameterized.TestCase):
        probs_test=probs,
        labels_train=np.array([1, 0.]),
        labels_test=np.array([1, 0.]),
        loss_function=LossFunction.SQUARED,
        loss_function=utils.LossFunction.SQUARED,
    )
    np.testing.assert_allclose(attack_input.get_loss_train(), expected)
    np.testing.assert_allclose(attack_input.get_loss_test(), expected)

@@ -24,7 +24,6 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard

@@ -20,9 +20,9 @@ from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard

@@ -13,7 +13,10 @@
# limitations under the License.
"""Utility functions for membership inference attacks."""
import enum
import logging
from typing import Callable, Optional, Union
import numpy as np
from scipy import special
@@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray,
  bce = labels * np.log(pred + small_value)
  bce += (1 - labels) * np.log(1 - pred + small_value)
  return -bce


class LossFunction(enum.Enum):
  """An enum that defines loss function."""
  CROSS_ENTROPY = 'cross_entropy'
  SQUARED = 'squared'


def string_to_loss_function(string: str):
  """Convert string to the corresponding LossFunction."""
  if string == LossFunction.CROSS_ENTROPY.value:
    return LossFunction.CROSS_ENTROPY
  if string == LossFunction.SQUARED.value:
    return LossFunction.SQUARED
  raise ValueError(f'{string} is not a valid loss function name.')


def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
             logits: Optional[np.ndarray], probs: Optional[np.ndarray],
             loss_function: Union[Callable[[np.ndarray, np.ndarray],
                                           np.ndarray], LossFunction],
             loss_function_using_logits: Optional[bool],
             multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
  """Calculates (if needed) losses.

  Args:
    loss: the loss of each example.
    labels: the scalar label of each example.
    logits: the logits vector of each example.
    probs: the probability vector of each example.
    loss_function: if `loss` is not available, `labels` and one of `logits`
      and `probs` are available, we will use this function to compute loss. It
      is supposed to take in (label, logits / probs) as input.
    loss_function_using_logits: if `loss_function` expects `logits` or
      `probs`.
    multilabel_data: if the data is from a multilabel classification problem.

  Returns:
    Loss (or None if neither the loss nor the labels are present).
  """
  if loss is not None:
    return loss
  if labels is None or (logits is None and probs is None):
    return None
  if loss_function_using_logits and logits is None:
    raise ValueError('We need logits to compute loss, but it is set to None.')
  if not loss_function_using_logits and probs is None:
    raise ValueError('We need probs to compute loss, but it is set to None.')

  predictions = logits if loss_function_using_logits else probs
  if loss_function == LossFunction.CROSS_ENTROPY:
    if multilabel_data:
      loss = multilabel_bce_loss(labels, predictions,
                                 loss_function_using_logits)
    else:
      loss = log_loss(labels, predictions, loss_function_using_logits)
  elif loss_function == LossFunction.SQUARED:
    loss = squared_loss(labels, predictions)
  else:
    loss = loss_function(labels, predictions)
  return loss
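
A minimal usage sketch of the relocated get_loss helper; the argument values are illustrative and not part of the commit:

import numpy as np
from tensorflow_privacy.privacy.privacy_tests import utils

labels = np.array([0, 1])                    # class indices
logits = np.array([[2.0, 0.5], [0.1, 1.5]])  # per-class logits
loss = utils.get_loss(
    loss=None, labels=labels, logits=logits, probs=None,
    loss_function=utils.LossFunction.CROSS_ENTROPY,
    loss_function_using_logits=True, multilabel_data=False)
# Follows the CROSS_ENTROPY branch above, i.e. log_loss(labels, logits, True).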

@@ -16,7 +16,24 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
from tensorflow_privacy.privacy.privacy_tests import utils


class LossFunctionFromStringTest(parameterized.TestCase):

  @parameterized.parameters(
      (utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'),
      (utils.LossFunction.SQUARED, 'squared'),
  )
  def test_from_str(self, en, string):
    self.assertEqual(utils.string_to_loss_function(string), en)

  @parameterized.parameters(
      ('random string'),
      (''),
  )
  def test_from_str_wrong_input(self, string):
    self.assertRaises(ValueError, utils.string_to_loss_function, string)


class TestLogLoss(parameterized.TestCase):