Refactor: move loss computation utilities under privacy_tests
.
PiperOrigin-RevId: 463391913
This commit is contained in:
parent
44dc40454b
commit
17cd0c52bc
10 changed files with 119 additions and 88 deletions
|
@ -1,6 +1,6 @@
|
||||||
load("@rules_python//python:defs.bzl", "py_library")
|
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:private"])
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -8,3 +8,18 @@ py_library(
|
||||||
name = "privacy_tests",
|
name = "privacy_tests",
|
||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "utils_test",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["utils_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":utils"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "utils",
|
||||||
|
srcs = ["utils.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
)
|
||||||
|
|
|
@ -15,21 +15,6 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "utils",
|
|
||||||
srcs = ["utils.py"],
|
|
||||||
srcs_version = "PY3",
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "utils_test",
|
|
||||||
timeout = "long",
|
|
||||||
srcs = ["utils_test.py"],
|
|
||||||
python_version = "PY3",
|
|
||||||
srcs_version = "PY3",
|
|
||||||
deps = [":utils"],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "membership_inference_attack_test",
|
name = "membership_inference_attack_test",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
|
@ -45,7 +30,10 @@ py_test(
|
||||||
srcs = ["data_structures_test.py"],
|
srcs = ["data_structures_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":membership_inference_attack"],
|
deps = [
|
||||||
|
":membership_inference_attack",
|
||||||
|
"//tensorflow_privacy/privacy/privacy_tests:utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -95,7 +83,7 @@ py_library(
|
||||||
"seq2seq_mia.py",
|
"seq2seq_mia.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":utils"],
|
deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -122,8 +110,8 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":membership_inference_attack",
|
":membership_inference_attack",
|
||||||
":utils",
|
|
||||||
":utils_tensorboard",
|
":utils_tensorboard",
|
||||||
|
"//tensorflow_privacy/privacy/privacy_tests:utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -144,8 +132,8 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
":membership_inference_attack",
|
":membership_inference_attack",
|
||||||
":utils",
|
|
||||||
":utils_tensorboard",
|
":utils_tensorboard",
|
||||||
|
"//tensorflow_privacy/privacy/privacy_tests:utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -185,7 +173,7 @@ py_library(
|
||||||
"advanced_mia.py",
|
"advanced_mia.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [":utils"],
|
deps = ["//tensorflow_privacy/privacy/privacy_tests:utils"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -205,6 +193,6 @@ py_binary(
|
||||||
deps = [
|
deps = [
|
||||||
":advanced_mia",
|
":advanced_mia",
|
||||||
":membership_inference_attack",
|
":membership_inference_attack",
|
||||||
":utils",
|
"//tensorflow_privacy/privacy/privacy_tests:utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,7 +17,7 @@ import functools
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
|
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
|
||||||
|
|
||||||
|
|
||||||
def replace_nan_with_column_mean(a: np.ndarray):
|
def replace_nan_with_column_mean(a: np.ndarray):
|
||||||
|
|
|
@ -21,11 +21,10 @@ from absl import flags
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests import utils
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
|
@ -26,7 +26,7 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from scipy import special
|
from scipy import special
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
from tensorflow_privacy.privacy.privacy_tests import utils
|
||||||
|
|
||||||
# The minimum TPR or FPR below which they are considered equal.
|
# The minimum TPR or FPR below which they are considered equal.
|
||||||
_ABSOLUTE_TOLERANCE = 1e-3
|
_ABSOLUTE_TOLERANCE = 1e-3
|
||||||
|
@ -183,12 +183,6 @@ def _log_value(probs, small_value=1e-30):
|
||||||
return -np.log(np.maximum(probs, small_value))
|
return -np.log(np.maximum(probs, small_value))
|
||||||
|
|
||||||
|
|
||||||
class LossFunction(enum.Enum):
|
|
||||||
"""An enum that defines loss function to use in `AttackInputData`."""
|
|
||||||
CROSS_ENTROPY = 'cross_entropy'
|
|
||||||
SQUARED = 'squared'
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AttackInputData:
|
class AttackInputData:
|
||||||
"""Input data for running an attack.
|
"""Input data for running an attack.
|
||||||
|
@ -225,7 +219,7 @@ class AttackInputData:
|
||||||
# If a callable is provided, it should take in two argument, the 1st is
|
# If a callable is provided, it should take in two argument, the 1st is
|
||||||
# labels, the 2nd is logits or probs.
|
# labels, the 2nd is logits or probs.
|
||||||
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray],
|
||||||
LossFunction] = LossFunction.CROSS_ENTROPY
|
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
|
||||||
# Whether `loss_function` will be called with logits or probs. If not set
|
# Whether `loss_function` will be called with logits or probs. If not set
|
||||||
# (None), will decide by availablity of logits and probs and logits is
|
# (None), will decide by availablity of logits and probs and logits is
|
||||||
# preferred when both are available.
|
# preferred when both are available.
|
||||||
|
@ -298,52 +292,6 @@ class AttackInputData:
|
||||||
true_labels]
|
true_labels]
|
||||||
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
|
||||||
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
|
||||||
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
|
||||||
np.ndarray], LossFunction],
|
|
||||||
loss_function_using_logits: Optional[bool],
|
|
||||||
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
|
||||||
"""Calculates (if needed) losses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss: the loss of each example.
|
|
||||||
labels: the scalar label of each example.
|
|
||||||
logits: the logits vector of each example.
|
|
||||||
probs: the probability vector of each example.
|
|
||||||
loss_function: if `loss` is not available, `labels` and one of `logits`
|
|
||||||
and `probs` are available, we will use this function to compute loss. It
|
|
||||||
is supposed to take in (label, logits / probs) as input.
|
|
||||||
loss_function_using_logits: if `loss_function` expects `logits` or
|
|
||||||
`probs`.
|
|
||||||
multilabel_data: if the data is from a multilabel classification problem.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Loss (or None if neither the loss nor the labels are present).
|
|
||||||
"""
|
|
||||||
if loss is not None:
|
|
||||||
return loss
|
|
||||||
if labels is None or (logits is None and probs is None):
|
|
||||||
return None
|
|
||||||
if loss_function_using_logits and logits is None:
|
|
||||||
raise ValueError('We need logits to compute loss, but it is set to None.')
|
|
||||||
if not loss_function_using_logits and probs is None:
|
|
||||||
raise ValueError('We need probs to compute loss, but it is set to None.')
|
|
||||||
|
|
||||||
predictions = logits if loss_function_using_logits else probs
|
|
||||||
if loss_function == LossFunction.CROSS_ENTROPY:
|
|
||||||
if multilabel_data:
|
|
||||||
loss = utils.multilabel_bce_loss(labels, predictions,
|
|
||||||
loss_function_using_logits)
|
|
||||||
else:
|
|
||||||
loss = utils.log_loss(labels, predictions, loss_function_using_logits)
|
|
||||||
elif loss_function == LossFunction.SQUARED:
|
|
||||||
loss = utils.squared_loss(labels, predictions)
|
|
||||||
else:
|
|
||||||
loss = loss_function(labels, predictions)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Checks performed after instantiation of the AttackInputData dataclass."""
|
"""Checks performed after instantiation of the AttackInputData dataclass."""
|
||||||
# Check if the data is multilabel.
|
# Check if the data is multilabel.
|
||||||
|
@ -358,7 +306,7 @@ class AttackInputData:
|
||||||
"""
|
"""
|
||||||
if self.loss_function_using_logits is None:
|
if self.loss_function_using_logits is None:
|
||||||
self.loss_function_using_logits = (self.logits_train is not None)
|
self.loss_function_using_logits = (self.logits_train is not None)
|
||||||
return self._get_loss(self.loss_train, self.labels_train, self.logits_train,
|
return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
|
||||||
self.probs_train, self.loss_function,
|
self.probs_train, self.loss_function,
|
||||||
self.loss_function_using_logits, self.multilabel_data)
|
self.loss_function_using_logits, self.multilabel_data)
|
||||||
|
|
||||||
|
@ -370,7 +318,7 @@ class AttackInputData:
|
||||||
"""
|
"""
|
||||||
if self.loss_function_using_logits is None:
|
if self.loss_function_using_logits is None:
|
||||||
self.loss_function_using_logits = bool(self.logits_test)
|
self.loss_function_using_logits = bool(self.logits_test)
|
||||||
return self._get_loss(self.loss_test, self.labels_test, self.logits_test,
|
return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
|
||||||
self.probs_test, self.loss_function,
|
self.probs_test, self.loss_function,
|
||||||
self.loss_function_using_logits, self.multilabel_data)
|
self.loss_function_using_logits, self.multilabel_data)
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,13 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests import utils
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import _log_value
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import LossFunction
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
|
@ -123,7 +123,7 @@ class AttackInputDataTest(parameterized.TestCase):
|
||||||
probs_test=np.array([1, 1.]),
|
probs_test=np.array([1, 1.]),
|
||||||
labels_train=np.array([1, 0.]),
|
labels_train=np.array([1, 0.]),
|
||||||
labels_test=np.array([0, 2.]),
|
labels_test=np.array([0, 2.]),
|
||||||
loss_function=LossFunction.SQUARED,
|
loss_function=utils.LossFunction.SQUARED,
|
||||||
loss_function_using_logits=loss_function_using_logits,
|
loss_function_using_logits=loss_function_using_logits,
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
|
np.testing.assert_allclose(attack_input.get_loss_train(), expected_train)
|
||||||
|
@ -175,7 +175,7 @@ class AttackInputDataTest(parameterized.TestCase):
|
||||||
probs_test=probs,
|
probs_test=probs,
|
||||||
labels_train=np.array([1, 0.]),
|
labels_train=np.array([1, 0.]),
|
||||||
labels_test=np.array([1, 0.]),
|
labels_test=np.array([1, 0.]),
|
||||||
loss_function=LossFunction.SQUARED,
|
loss_function=utils.LossFunction.SQUARED,
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
|
np.testing.assert_allclose(attack_input.get_loss_train(), expected)
|
||||||
np.testing.assert_allclose(attack_input.get_loss_test(), expected)
|
np.testing.assert_allclose(attack_input.get_loss_test(), expected)
|
||||||
|
|
|
@ -24,7 +24,6 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@ from absl import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow import estimator as tf_estimator
|
from tensorflow import estimator as tf_estimator
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests import utils
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utility functions for membership inference attacks."""
|
"""Utility functions for membership inference attacks."""
|
||||||
|
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import special
|
from scipy import special
|
||||||
|
|
||||||
|
@ -122,3 +125,65 @@ def multilabel_bce_loss(labels: np.ndarray,
|
||||||
bce = labels * np.log(pred + small_value)
|
bce = labels * np.log(pred + small_value)
|
||||||
bce += (1 - labels) * np.log(1 - pred + small_value)
|
bce += (1 - labels) * np.log(1 - pred + small_value)
|
||||||
return -bce
|
return -bce
|
||||||
|
|
||||||
|
|
||||||
|
class LossFunction(enum.Enum):
|
||||||
|
"""An enum that defines loss function."""
|
||||||
|
CROSS_ENTROPY = 'cross_entropy'
|
||||||
|
SQUARED = 'squared'
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_loss_function(string: str):
|
||||||
|
"""Convert string to the corresponding LossFunction."""
|
||||||
|
|
||||||
|
if string == LossFunction.CROSS_ENTROPY.value:
|
||||||
|
return LossFunction.CROSS_ENTROPY
|
||||||
|
if string == LossFunction.SQUARED.value:
|
||||||
|
return LossFunction.SQUARED
|
||||||
|
raise ValueError(f'{string} is not a valid loss function name.')
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
|
||||||
|
logits: Optional[np.ndarray], probs: Optional[np.ndarray],
|
||||||
|
loss_function: Union[Callable[[np.ndarray, np.ndarray],
|
||||||
|
np.ndarray], LossFunction],
|
||||||
|
loss_function_using_logits: Optional[bool],
|
||||||
|
multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
|
||||||
|
"""Calculates (if needed) losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss: the loss of each example.
|
||||||
|
labels: the scalar label of each example.
|
||||||
|
logits: the logits vector of each example.
|
||||||
|
probs: the probability vector of each example.
|
||||||
|
loss_function: if `loss` is not available, `labels` and one of `logits`
|
||||||
|
and `probs` are available, we will use this function to compute loss. It
|
||||||
|
is supposed to take in (label, logits / probs) as input.
|
||||||
|
loss_function_using_logits: if `loss_function` expects `logits` or
|
||||||
|
`probs`.
|
||||||
|
multilabel_data: if the data is from a multilabel classification problem.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss (or None if neither the loss nor the labels are present).
|
||||||
|
"""
|
||||||
|
if loss is not None:
|
||||||
|
return loss
|
||||||
|
if labels is None or (logits is None and probs is None):
|
||||||
|
return None
|
||||||
|
if loss_function_using_logits and logits is None:
|
||||||
|
raise ValueError('We need logits to compute loss, but it is set to None.')
|
||||||
|
if not loss_function_using_logits and probs is None:
|
||||||
|
raise ValueError('We need probs to compute loss, but it is set to None.')
|
||||||
|
|
||||||
|
predictions = logits if loss_function_using_logits else probs
|
||||||
|
if loss_function == LossFunction.CROSS_ENTROPY:
|
||||||
|
if multilabel_data:
|
||||||
|
loss = multilabel_bce_loss(labels, predictions,
|
||||||
|
loss_function_using_logits)
|
||||||
|
else:
|
||||||
|
loss = log_loss(labels, predictions, loss_function_using_logits)
|
||||||
|
elif loss_function == LossFunction.SQUARED:
|
||||||
|
loss = squared_loss(labels, predictions)
|
||||||
|
else:
|
||||||
|
loss = loss_function(labels, predictions)
|
||||||
|
return loss
|
|
@ -16,7 +16,24 @@ from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
from tensorflow_privacy.privacy.privacy_tests import utils
|
||||||
|
|
||||||
|
|
||||||
|
class LossFunctionFromStringTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(utils.LossFunction.CROSS_ENTROPY, 'cross_entropy'),
|
||||||
|
(utils.LossFunction.SQUARED, 'squared'),
|
||||||
|
)
|
||||||
|
def test_from_str(self, en, string):
|
||||||
|
self.assertEqual(utils.string_to_loss_function(string), en)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
('random string'),
|
||||||
|
(''),
|
||||||
|
)
|
||||||
|
def test_from_str_wrong_input(self, string):
|
||||||
|
self.assertRaises(ValueError, utils.string_to_loss_function, string)
|
||||||
|
|
||||||
|
|
||||||
class TestLogLoss(parameterized.TestCase):
|
class TestLogLoss(parameterized.TestCase):
|
Loading…
Reference in a new issue