Refactor: move loss computation utilities under privacy_tests.

PiperOrigin-RevId: 463391913
Author: Shuang Song, 2022-07-26 11:49:14 -07:00 (committed by A. Unique TensorFlower)
parent 44dc40454b
commit 17cd0c52bc
10 changed files with 119 additions and 88 deletions
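
For downstream callers, the practical effect is an import-path change: log_loss, squared_loss, multilabel_bce_loss, and the new LossFunction enum and get_loss helper now live in tensorflow_privacy.privacy.privacy_tests.utils rather than in the membership_inference_attack subpackage. A minimal before/after sketch (the sample arrays are illustrative and not part of this commit):

import numpy as np

# Old import path (removed by this commit):
#   from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
# New import path:
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss

# Per-example cross-entropy loss from integer labels and probability vectors.
# The third argument mirrors loss_function_using_logits in the diffs below:
# False means the predictions are probabilities rather than logits.
labels = np.array([0, 1])
probs = np.array([[0.9, 0.1], [0.2, 0.8]])
per_example_loss = log_loss(labels, probs, False)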

View file

@@ -1,6 +1,6 @@
-load("@rules_python//python:defs.bzl", "py_library")
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
 
-package(default_visibility = ["//visibility:private"])
+package(default_visibility = ["//visibility:public"])
 
 licenses(["notice"])
@@ -8,3 +8,18 @@ py_library(
     name = "privacy_tests",
     srcs = ["__init__.py"],
 )
+
+py_test(
+    name = "utils_test",
+    timeout = "long",
+    srcs = ["utils_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    deps = [":utils"],
+)
+
+py_library(
+    name = "utils",
+    srcs = ["utils.py"],
+    srcs_version = "PY3",
+)

View file

@@ -15,21 +15,6 @@ py_library(
     srcs_version = "PY3",
 )
 
-py_library(
-    name = "utils",
-    srcs = ["utils.py"],
-    srcs_version = "PY3",
-)
-
-py_test(
-    name = "utils_test",
-    timeout = "long",
-    srcs = ["utils_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [":utils"],
-)
-
 py_test(
     name = "membership_inference_attack_test",
     timeout = "long",
@@ -45,7 +30,10 @@ py_test(
     srcs = ["data_structures_test.py"],
     python_version = "PY3",
     srcs_version = "PY3",
-    deps = [":membership_inference_attack"],
+    deps = [
+        ":membership_inference_attack",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
+    ],
 )
 
 py_test(
@@ -95,7 +83,7 @@ py_library(
         "seq2seq_mia.py",
     ],
     srcs_version = "PY3",
-    deps = [":utils"],
+    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
 )
 
 py_library(
@@ -122,8 +110,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":membership_inference_attack",
-        ":utils",
         ":utils_tensorboard",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )
@@ -144,8 +132,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":membership_inference_attack",
-        ":utils",
         ":utils_tensorboard",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )
@@ -185,7 +173,7 @@ py_library(
         "advanced_mia.py",
     ],
     srcs_version = "PY3",
-    deps = [":utils"],
+    deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
 )
 
 py_test(
@@ -205,6 +193,6 @@ py_binary(
     deps = [
         ":advanced_mia",
         ":membership_inference_attack",
-        ":utils",
+        "//tensorflow_privacy/privacy/privacy_tests:utils",
     ],
 )

View file

@@ -17,7 +17,7 @@ import functools
 from typing import Sequence, Union
 
 import numpy as np
 import scipy.stats
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
+from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
 
 def replace_nan_with_column_mean(a: np.ndarray):

View file

@@ -21,11 +21,10 @@ from absl import flags
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 
 FLAGS = flags.FLAGS

View file

@@ -26,7 +26,7 @@ import numpy as np
 import pandas as pd
 from scipy import special
 from sklearn import metrics
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
+from tensorflow_privacy.privacy.privacy_tests import utils
 
 # The minimum TPR or FPR below which they are considered equal.
 _ABSOLUTE_TOLERANCE = 1e-3
@@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30):
   return -np.log(np.maximum(probs, small_value))
 
 
-class LossFunction(enum.Enum):
-  """An enum that defines loss function to use in `AttackInputData`."""
-  CROSS_ENTROPY = 'cross_entropy'
-  SQUARED = 'squared'
-
-
 @dataclasses.dataclass
 class AttackInputData:
   """Input data for running an attack.
@@ -225,7 +219,7 @@ class AttackInputData:
   # If a callable is provided, it should take in two argument, the 1st is
   # labels, the 2nd is logits or probs.
   loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
-                       LossFunction] = LossFunction.CROSS_ENTROPY
+                       utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
   # Whether `loss_function` will be called with logits or probs. If not set
   # (None), will decide by availablity of logits and probs and logits is
   # preferred when both are available.
@@ -298,52 +292,6 @@ class AttackInputData:
                                                     true_labels]
     return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
 
-  @staticmethod
-  def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
-                logits: Optional[np.ndarray], probs: Optional[np.ndarray],
-                loss_function: Union[Callable[[np.ndarray, np.ndarray],
-                                              np.ndarray], LossFunction],
-                loss_function_using_logits: Optional[bool],
-                multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
-    """Calculates (if needed) losses.
-
-    Args:
-      loss: the loss of each example.
-      labels: the scalar label of each example.
-      logits: the logits vector of each example.
-      probs: the probability vector of each example.
-      loss_function: if `loss` is not available, `labels` and one of `logits`
-        and `probs` are available, we will use this function to compute loss. It
-        is supposed to take in (label, logits / probs) as input.
-      loss_function_using_logits: if `loss_function` expects `logits` or
-        `probs`.
-      multilabel_data: if the data is from a multilabel classification problem.
-
-    Returns:
-      Loss (or None if neither the loss nor the labels are present).
-    """
-    if loss is not None:
-      return loss
-    if labels is None or (logits is None and probs is None):
-      return None
-    if loss_function_using_logits and logits is None:
-      raise ValueError('We need logits to compute loss, but it is set to None.')
-    if not loss_function_using_logits and probs is None:
-      raise ValueError('We need probs to compute loss, but it is set to None.')
-
-    predictions = logits if loss_function_using_logits else probs
-    if loss_function == LossFunction.CROSS_ENTROPY:
-      if multilabel_data:
-        loss = utils.multilabel_bce_loss(labels, predictions,
-                                         loss_function_using_logits)
-      else:
-        loss = utils.log_loss(labels, predictions, loss_function_using_logits)
-    elif loss_function == LossFunction.SQUARED:
-      loss = utils.squared_loss(labels, predictions)
-    else:
-      loss = loss_function(labels, predictions)
-    return loss
-
   def __post_init__(self):
     """Checks performed after instantiation of the AttackInputData dataclass."""
     # Check if the data is multilabel.
@@ -358,7 +306,7 @@ class AttackInputData:
     """
     if self.loss_function_using_logits is None:
       self.loss_function_using_logits = (self.logits_train is not None)
-    return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
+    return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
                           self.probs_train, self.loss_function,
                           self.loss_function_using_logits, self.multilabel_data)
@@ -370,7 +318,7 @@ class AttackInputData:
     """
     if self.loss_function_using_logits is None:
      self.loss_function_using_logits = bool(self.logits_test)
-    return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
+    return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
                          self.probs_test, self.loss_function,
                          self.loss_function_using_logits, self.multilabel_data)
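
Callers that previously imported LossFunction from data_structures now take it from privacy_tests.utils; constructing AttackInputData is otherwise unchanged. A minimal sketch mirroring the updated tests below (values are illustrative, not from this commit):

import numpy as np

from tensorflow_privacy.privacy.privacy_tests import utils
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# Squared loss over scalar predictions; the enum now comes from privacy_tests.utils.
attack_input = AttackInputData(
    probs_train=np.array([0.1, 0.7]),
    probs_test=np.array([0.8, 0.4]),
    labels_train=np.array([1, 0.]),
    labels_test=np.array([0, 1.]),
    loss_function=utils.LossFunction.SQUARED,
)
# No precomputed loss or logits are supplied, so get_loss_train() falls back to
# utils.get_loss, which computes (labels_train - probs_train) ** 2 here.
train_loss = attack_input.get_loss_train()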

View file

@@ -20,13 +20,13 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
 import pandas as pd
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
@@ -123,7 +123,7 @@ class AttackInputDataTest(parameterized.TestCase):
         probs_test=np.array([1, 1.]),
         labels_train=np.array([1, 0.]),
         labels_test=np.array([0, 2.]),
-        loss_function=LossFunction.SQUARED,
+        loss_function=utils.LossFunction.SQUARED,
         loss_function_using_logits=loss_function_using_logits,
     )
     np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
@@ -175,7 +175,7 @@ class AttackInputDataTest(parameterized.TestCase):
         probs_test=probs,
         labels_train=np.array([1, 0.]),
         labels_test=np.array([1, 0.]),
-        loss_function=LossFunction.SQUARED,
+        loss_function=utils.LossFunction.SQUARED,
     )
     np.testing.assert_allclose(attack_input.get_loss_train(), expected)
     np.testing.assert_allclose(attack_input.get_loss_test(), expected)

View file

@@ -24,7 +24,6 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard

View file

@@ -20,9 +20,9 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 from tensorflow import estimator as tf_estimator
+from tensorflow_privacy.privacy.privacy_tests import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard

View file

@@ -13,7 +13,10 @@
 # limitations under the License.
 
 """Utility functions for membership inference attacks."""
 
+import enum
 import logging
+from typing import Callable, Optional, Union
 
 import numpy as np
 from scipy import special
@@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray,
   bce = labels * np.log(pred + small_value)
   bce += (1 - labels) * np.log(1 - pred + small_value)
   return -bce
+
+
+class LossFunction(enum.Enum):
+  """An enum that defines loss function."""
+  CROSS_ENTROPY = 'cross_entropy'
+  SQUARED = 'squared'
+
+
+def string_to_loss_function(string: str):
+  """Convert string to the corresponding LossFunction."""
+  if string == LossFunction.CROSS_ENTROPY.value:
+    return LossFunction.CROSS_ENTROPY
+  if string == LossFunction.SQUARED.value:
+    return LossFunction.SQUARED
+  raise ValueError(f'{string} is not a valid loss function name.')
+
+
+def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
+             logits: Optional[np.ndarray], probs: Optional[np.ndarray],
+             loss_function: Union[Callable[[np.ndarray, np.ndarray],
+                                           np.ndarray], LossFunction],
+             loss_function_using_logits: Optional[bool],
+             multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
+  """Calculates (if needed) losses.
+
+  Args:
+    loss: the loss of each example.
+    labels: the scalar label of each example.
+    logits: the logits vector of each example.
+    probs: the probability vector of each example.
+    loss_function: if `loss` is not available, `labels` and one of `logits`
+      and `probs` are available, we will use this function to compute loss. It
+      is supposed to take in (label, logits / probs) as input.
+    loss_function_using_logits: if `loss_function` expects `logits` or
+      `probs`.
+    multilabel_data: if the data is from a multilabel classification problem.
+
+  Returns:
+    Loss (or None if neither the loss nor the labels are present).
+  """
+  if loss is not None:
+    return loss
+  if labels is None or (logits is None and probs is None):
+    return None
+  if loss_function_using_logits and logits is None:
+    raise ValueError('We need logits to compute loss, but it is set to None.')
+  if not loss_function_using_logits and probs is None:
+    raise ValueError('We need probs to compute loss, but it is set to None.')
+
+  predictions = logits if loss_function_using_logits else probs
+  if loss_function == LossFunction.CROSS_ENTROPY:
+    if multilabel_data:
+      loss = multilabel_bce_loss(labels, predictions,
+                                 loss_function_using_logits)
+    else:
+      loss = log_loss(labels, predictions, loss_function_using_logits)
+  elif loss_function == LossFunction.SQUARED:
+    loss = squared_loss(labels, predictions)
+  else:
+    loss = loss_function(labels, predictions)
+  return loss
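
Because the relocated get_loss is a plain module-level function, it can also be called directly, which is what AttackInputData does after this change. A small sketch against the signature shown above (inputs are illustrative):

import numpy as np

from tensorflow_privacy.privacy.privacy_tests import utils

labels = np.array([0, 1])
logits = np.array([[2.0, -1.0], [-1.0, 1.0]])

# No precomputed loss is given, so get_loss computes per-example cross-entropy
# from the logits; loss_function_using_logits=True selects logits over probs.
per_example_loss = utils.get_loss(
    loss=None,
    labels=labels,
    logits=logits,
    probs=None,
    loss_function=utils.LossFunction.CROSS_ENTROPY,
    loss_function_using_logits=True,
    multilabel_data=False)

# string_to_loss_function maps the enum's string values back onto its members.
assert utils.string_to_loss_function('squared') is utils.LossFunction.SQUARED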

View file

@@ -16,7 +16,24 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
+from tensorflow_privacy.privacy.privacy_tests import utils
+
+
+class LossFunctionFromStringTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      (utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'),
+      (utils.LossFunction.SQUARED, 'squared'),
+  )
+  def test_from_str(self, en, string):
+    self.assertEqual(utils.string_to_loss_function(string), en)
+
+  @parameterized.parameters(
+      ('random string'),
+      (''),
+  )
+  def test_from_str_wrong_input(self, string):
+    self.assertRaises(ValueError, utils.string_to_loss_function, string)
 
 
 class TestLogLoss(parameterized.TestCase):