Fix the calculation of loss from logits.

PiperOrigin-RevId: 329966058
This commit is contained in:
Vadym Doroshenko 2020-09-03 12:06:07 -07:00 committed by A. Unique TensorFlower
parent f4fc9b2623
commit 8f3a61b50d
10 changed files with 144 additions and 109 deletions

View file

@ -22,7 +22,7 @@ from dataclasses import dataclass
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from sklearn import metrics from sklearn import metrics
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
@ -173,21 +173,19 @@ class AttackInputData:
'Please set labels_train and labels_test') 'Please set labels_train and labels_test')
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1 return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
@staticmethod
def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
return logits[range(logits.shape[0]), true_labels]
def get_loss_train(self): def get_loss_train(self):
"""Calculates cross-entropy losses for the training set.""" """Calculates (if needed) cross-entropy losses for the training set."""
if self.loss_train is not None: if self.loss_train is None:
return self.loss_train self.loss_train = utils.log_loss_from_logits(self.labels_train,
return self._get_loss(self.logits_train, self.labels_train) self.logits_train)
return self.loss_train
def get_loss_test(self): def get_loss_test(self):
"""Calculates cross-entropy losses for the test set.""" """Calculates (if needed) cross-entropy losses for the test set."""
if self.loss_test is not None: if self.loss_test is None:
return self.loss_test self.loss_test = utils.log_loss_from_logits(self.labels_test,
return self._get_loss(self.logits_test, self.labels_test) self.logits_test)
return self.loss_test
def get_train_size(self): def get_train_size(self):
"""Returns size of the training set.""" """Returns size of the training set."""
@ -365,11 +363,13 @@ class AttackResults:
advantages.append(float(attack_result.get_attacker_advantage())) advantages.append(float(attack_result.get_attacker_advantage()))
aucs.append(float(attack_result.get_auc())) aucs.append(float(attack_result.get_auc()))
df = pd.DataFrame({'slice feature': slice_features, df = pd.DataFrame({
'slice value': slice_values, 'slice feature': slice_features,
'attack type': attack_types, 'slice value': slice_values,
'attack advantage': advantages, 'attack type': attack_types,
'roc auc': aucs}) 'attack advantage': advantages,
'roc auc': aucs
})
return df return df
def summary(self, by_slices=False) -> str: def summary(self, by_slices=False) -> str:
@ -452,3 +452,27 @@ class AttackResults:
"""Loads AttackResults from a pickle file.""" """Loads AttackResults from a pickle file."""
with open(filepath, 'rb') as inp: with open(filepath, 'rb') as inp:
return pickle.load(inp) return pickle.load(inp)
def get_flattened_attack_metrics(results: AttackResults):
  """Flattens attack results into parallel lists of descriptors and values.

  Each single attack result contributes two entries, in this order: the
  attacker advantage ('adv') and the AUC ('auc').

  Args:
    results: membership inference attack results.

  Returns:
    properties: a list of [slice, attack_type, metric name] lists of strings.
    values: a list of metric values; the i-th element corresponds to
      properties[i].
  """
  properties = []
  values = []
  for single_result in results.single_attack_results:
    # Shared prefix identifying the slice and attack the metrics belong to.
    prefix = [str(single_result.slice_spec), str(single_result.attack_type)]
    properties.append(prefix + ['adv'])
    values.append(float(single_result.get_attacker_advantage()))
    properties.append(prefix + ['auc'])
    values.append(float(single_result.get_auc()))
  return properties, values

View file

@ -47,13 +47,15 @@ class AttackInputDataTest(absltest.TestCase):
def test_get_loss(self): def test_get_loss(self):
attack_input = AttackInputData( attack_input = AttackInputData(
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]), logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]), logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
labels_train=np.array([1, 0]), labels_train=np.array([1, 0]),
labels_test=np.array([0, 1])) labels_test=np.array([0, 2]))
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2]) np.testing.assert_allclose(
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5]) attack_input.get_loss_train(), [0.36313551, 1.37153903], atol=1e-7)
np.testing.assert_allclose(
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
def test_get_loss_explicitly_provided(self): def test_get_loss_explicitly_provided(self):
attack_input = AttackInputData( attack_input = AttackInputData(
@ -237,8 +239,9 @@ class AttackResultsTest(absltest.TestCase):
self.assertEqual(repr(results), repr(loaded_results)) self.assertEqual(repr(results), repr(loaded_results))
def test_calculate_pd_dataframe(self): def test_calculate_pd_dataframe(self):
single_results = [self.perfect_classifier_result, single_results = [
self.random_classifier_result] self.perfect_classifier_result, self.random_classifier_result
]
results = AttackResults(single_results) results = AttackResults(single_results)
df = results.calculate_pd_dataframe() df = results.calculate_pd_dataframe()
df_expected = pd.DataFrame({ df_expected = pd.DataFrame({

View file

@ -24,8 +24,8 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -86,7 +86,7 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
self._attack_types) self._attack_types)
logging.info(results) logging.info(results)
attack_properties, attack_values = get_all_attack_results(results) attack_properties, attack_values = get_flattened_attack_metrics(results)
print('Attack result:') print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)])) zip(attack_properties, attack_values)]))

View file

@ -21,11 +21,10 @@ from absl import flags
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -40,11 +39,16 @@ flags.DEFINE_string('model_dir', None, 'Model directory.')
def cnn_model(): def cnn_model():
"""Define a CNN model.""" """Define a CNN model."""
model = tf.keras.Sequential([ model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 8, strides=2, padding='same', tf.keras.layers.Conv2D(
activation='relu', input_shape=(28, 28, 1)), 16,
8,
strides=2,
padding='same',
activation='relu',
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPool2D(2, 1), tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Conv2D(32, 4, strides=2, padding='valid', tf.keras.layers.Conv2D(
activation='relu'), 32, 4, strides=2, padding='valid', activation='relu'),
tf.keras.layers.MaxPool2D(2, 1), tf.keras.layers.MaxPool2D(2, 1),
tf.keras.layers.Flatten(), tf.keras.layers.Flatten(),
tf.keras.layers.Dense(32, activation='relu'), tf.keras.layers.Dense(32, activation='relu'),
@ -83,31 +87,34 @@ def main(unused_argv):
# Get callback for membership inference attack. # Get callback for membership inference attack.
mia_callback = MembershipInferenceCallback( mia_callback = MembershipInferenceCallback(
(train_data, train_labels), (train_data, train_labels), (test_data, test_labels),
(test_data, test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK], attack_types=[AttackType.THRESHOLD_ATTACK],
tensorboard_dir=FLAGS.model_dir) tensorboard_dir=FLAGS.model_dir)
# Train model with Keras # Train model with Keras
model.fit(train_data, train_labels, model.fit(
epochs=FLAGS.epochs, train_data,
validation_data=(test_data, test_labels), train_labels,
batch_size=FLAGS.batch_size, epochs=FLAGS.epochs,
callbacks=[mia_callback], validation_data=(test_data, test_labels),
verbose=2) batch_size=FLAGS.batch_size,
callbacks=[mia_callback],
verbose=2)
print('End of training attack:') print('End of training attack:')
attack_results = run_attack_on_keras_model( attack_results = run_attack_on_keras_model(
model, model, (train_data, train_labels), (test_data, test_labels),
(train_data, train_labels),
(test_data, test_labels),
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] attack_types=[
) AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
])
attack_properties, attack_values = get_all_attack_results(attack_results) attack_properties, attack_values = get_flattened_attack_metrics(
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in attack_results)
zip(attack_properties, attack_values)])) print('\n'.join([
' %s: %.4f' % (', '.join(p), r)
for p, r in zip(attack_properties, attack_values)
]))
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
class UtilsTest(absltest.TestCase): class UtilsTest(absltest.TestCase):
@ -67,7 +67,7 @@ class UtilsTest(absltest.TestCase):
(self.test_data, self.test_labels), (self.test_data, self.test_labels),
attack_types=[AttackType.THRESHOLD_ATTACK]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults) self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results) attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2) self.assertLen(attack_values, 2)

View file

@ -26,8 +26,8 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
@ -101,7 +101,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
self._attack_types) self._attack_types)
logging.info(results) logging.info(results)
attack_properties, attack_values = get_all_attack_results(results) attack_properties, attack_values = get_flattened_attack_metrics(results)
print('Attack result:') print('Attack result:')
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
zip(attack_properties, attack_values)])) zip(attack_properties, attack_values)]))

View file

@ -22,10 +22,10 @@ from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
GradientDescentOptimizer = tf.train.GradientDescentOptimizer GradientDescentOptimizer = tf.train.GradientDescentOptimizer
@ -63,9 +63,7 @@ def cnn_model_fn(features, labels, mode):
global_step = tf.train.get_global_step() global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step) train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=mode, mode=mode, loss=scalar_loss, train_op=train_op)
loss=scalar_loss,
train_op=train_op)
# Add evaluation metrics (for EVAL mode). # Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL: elif mode == tf.estimator.ModeKeys.EVAL:
@ -108,8 +106,8 @@ def main(unused_argv):
train_data, train_labels, test_data, test_labels = load_mnist() train_data, train_labels, test_data, test_labels = load_mnist()
# Instantiate the tf.Estimator. # Instantiate the tf.Estimator.
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, mnist_classifier = tf.estimator.Estimator(
model_dir=FLAGS.model_dir) model_fn=cnn_model_fn, model_dir=FLAGS.model_dir)
# A function to construct input_fn given (data, label), to be used by the # A function to construct input_fn given (data, label), to be used by the
# membership inference training hook. # membership inference training hook.
@ -124,9 +122,7 @@ def main(unused_argv):
else: else:
summary_writer = None summary_writer = None
mia_hook = MembershipInferenceTrainingHook( mia_hook = MembershipInferenceTrainingHook(
mnist_classifier, mnist_classifier, (train_data, train_labels), (test_data, test_labels),
(train_data, train_labels),
(test_data, test_labels),
input_fn_constructor, input_fn_constructor,
attack_types=[AttackType.THRESHOLD_ATTACK], attack_types=[AttackType.THRESHOLD_ATTACK],
writer=summary_writer) writer=summary_writer)
@ -145,8 +141,8 @@ def main(unused_argv):
steps_per_epoch = 60000 // FLAGS.batch_size steps_per_epoch = 60000 // FLAGS.batch_size
for epoch in range(1, FLAGS.epochs + 1): for epoch in range(1, FLAGS.epochs + 1):
# Train the model, with the membership inference hook. # Train the model, with the membership inference hook.
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch, mnist_classifier.train(
hooks=[mia_hook]) input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
# Evaluate the model and print results # Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
@ -155,16 +151,18 @@ def main(unused_argv):
print('End of training attack') print('End of training attack')
attack_results = run_attack_on_tf_estimator_model( attack_results = run_attack_on_tf_estimator_model(
mnist_classifier, mnist_classifier, (train_data, train_labels), (test_data, test_labels),
(train_data, train_labels),
(test_data, test_labels),
input_fn_constructor, input_fn_constructor,
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] attack_types=[
) AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
attack_properties, attack_values = get_all_attack_results(attack_results) ])
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in attack_properties, attack_values = get_flattened_attack_metrics(
zip(attack_properties, attack_values)])) attack_results)
print('\n'.join([
' %s: %.4f' % (', '.join(p), r)
for p, r in zip(attack_properties, attack_values)
]))
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
class UtilsTest(absltest.TestCase): class UtilsTest(absltest.TestCase):
@ -88,7 +88,7 @@ class UtilsTest(absltest.TestCase):
self.test_labels, self.test_labels,
attack_types=[AttackType.THRESHOLD_ATTACK]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults) self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results) attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2) self.assertLen(attack_values, 2)
@ -104,7 +104,7 @@ class UtilsTest(absltest.TestCase):
input_fn_constructor, input_fn_constructor,
attack_types=[AttackType.THRESHOLD_ATTACK]) attack_types=[AttackType.THRESHOLD_ATTACK])
self.assertIsInstance(results, AttackResults) self.assertIsInstance(results, AttackResults)
attack_properties, attack_values = get_all_attack_results(results) attack_properties, attack_values = get_flattened_attack_metrics(results)
self.assertLen(attack_properties, 2) self.assertLen(attack_properties, 2)
self.assertLen(attack_values, 2) self.assertLen(attack_values, 2)

View file

@ -18,10 +18,9 @@
from typing import Text, Dict, Union, List, Any, Tuple from typing import Text, Dict, Union, List, Any, Tuple
import numpy as np import numpy as np
import scipy.special
from sklearn import metrics from sklearn import metrics
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
ArrayDict = Dict[Text, np.ndarray] ArrayDict = Dict[Text, np.ndarray]
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
@ -74,25 +73,6 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
return {prefix + k: v for k, v in in_dict.items()} return {prefix + k: v for k, v in in_dict.items()}
# ------------------------------------------------------------------------------
# Utilities for managing result.
# ------------------------------------------------------------------------------
def get_all_attack_results(results: AttackResults):
"""Get all results as a list of attack properties and a list of attack result."""
properties = []
values = []
for attack_result in results.single_attack_results:
slice_spec = attack_result.slice_spec
prop = [str(slice_spec), str(attack_result.attack_type)]
properties += [prop + ['adv'], prop + ['auc']]
values += [float(attack_result.get_attacker_advantage()),
float(attack_result.get_auc())]
return properties, values
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Subsampling and data selection functionality # Subsampling and data selection functionality
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
@ -245,19 +225,24 @@ def compute_performance_metrics(true_labels: np.ndarray,
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
def log_loss(y, pred, small_value=1e-8): def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
"""Compute the cross entropy loss. """Compute the cross entropy loss.
Args: Args:
y: numpy array, y[i] is the true label (scalar) of the i-th sample labels: numpy array, labels[i] is the true label (scalar) of the i-th sample
pred: numpy array, pred[i] is the probability vector of the i-th sample pred: numpy array, pred[i] is the probability vector of the i-th sample
small_value: np.log can become -inf if the probability is too close to 0, small_value: np.log can become -inf if the probability is too close to 0, so
so the probability is clipped below by small_value. the probability is clipped below by small_value.
Returns: Returns:
the cross-entropy loss of each sample the cross-entropy loss of each sample
""" """
return -np.log(np.maximum(pred[range(y.size), y], small_value)) return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
  """Compute the cross entropy loss from logits.

  The logits are converted to probabilities with a softmax over the last axis
  and then passed to `log_loss`.

  Args:
    labels: numpy array, labels[i] is the true label (scalar) of the i-th sample
    logits: numpy array, logits[i] is the logit vector of the i-th sample

  Returns:
    the cross-entropy loss of each sample
  """
  probabilities = scipy.special.softmax(logits, axis=-1)
  return log_loss(labels, probabilities)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------

View file

@ -39,8 +39,10 @@ class UtilsTest(absltest.TestCase):
results = utils.compute_performance_metrics(true, pred, threshold=0.5) results = utils.compute_performance_metrics(true, pred, threshold=0.5)
for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr', for k in [
'thresholds', 'auc', 'advantage']: 'precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
'thresholds', 'auc', 'advantage'
]:
self.assertIn(k, results) self.assertIn(k, results)
np.testing.assert_almost_equal(results['accuracy'], 1. / 2.) np.testing.assert_almost_equal(results['accuracy'], 1. / 2.)
@ -107,10 +109,16 @@ class UtilsTest(absltest.TestCase):
[0.75, 0.25], [0.9, 0.1], [0.99, 0.01]]) [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
# Test the cases when true label (for all samples) is 0 and 1 # Test the cases when true label (for all samples) is 0 and 1
expected_losses = { expected_losses = {
0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207, 0:
0.10536052, 0.01005034]), np.array([
1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436, 4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
2.30258509, 4.60517019]) 0.10536052, 0.01005034
]),
1:
np.array([
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
2.30258509, 4.60517019
])
} }
for c in [0, 1]: # true label for c in [0, 1]: # true label
y = np.ones(shape=pred.shape[0], dtype=int) * c y = np.ones(shape=pred.shape[0], dtype=int) * c
@ -139,8 +147,18 @@ class UtilsTest(absltest.TestCase):
expected_losses = np.array([18.42068074, 46.05170186, 115.12925465]) expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
for i, small_value in enumerate(small_values): for i, small_value in enumerate(small_values):
loss = utils.log_loss(y, pred, small_value) loss = utils.log_loss(y, pred, small_value)
np.testing.assert_allclose(loss, np.array([expected_losses[i], 0]), np.testing.assert_allclose(
atol=1e-7) loss, np.array([expected_losses[i], 0]), atol=1e-7)
def test_log_loss_from_logits(self):
  """Checks the per-sample cross-entropy loss computed from raw logits."""
  labels = np.array([0, 3, 1])
  logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
  # Reference values: -log(softmax(logits[i])[labels[i]]) for each row i.
  expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])
  np.testing.assert_allclose(
      expected_loss, utils.log_loss_from_logits(labels, logits), atol=1e-7)
if __name__ == '__main__': if __name__ == '__main__':