From 8f3a61b50d2255ef405e6cb52ca615c9f5072fc0 Mon Sep 17 00:00:00 2001 From: Vadym Doroshenko Date: Thu, 3 Sep 2020 12:06:07 -0700 Subject: [PATCH] Fixing calculating loss on logits. PiperOrigin-RevId: 329966058 --- .../data_structures.py | 60 +++++++++++++------ .../data_structures_test.py | 17 +++--- .../keras_evaluation.py | 4 +- .../keras_evaluation_example.py | 51 +++++++++------- .../keras_evaluation_test.py | 4 +- .../tf_estimator_evaluation.py | 4 +- .../tf_estimator_evaluation_example.py | 36 ++++++----- .../tf_estimator_evaluation_test.py | 6 +- .../membership_inference_attack/utils.py | 37 ++++-------- .../membership_inference_attack/utils_test.py | 34 ++++++++--- 10 files changed, 144 insertions(+), 109 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index a309b35..9f81f53 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -22,7 +22,7 @@ from dataclasses import dataclass import numpy as np import pandas as pd from sklearn import metrics - +import tensorflow_privacy.privacy.membership_inference_attack.utils as utils ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)' @@ -173,21 +173,19 @@ class AttackInputData: 'Please set labels_train and labels_test') return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1 - @staticmethod - def _get_loss(logits: np.ndarray, true_labels: np.ndarray): - return logits[range(logits.shape[0]), true_labels] - def get_loss_train(self): - """Calculates cross-entropy losses for the training set.""" - if self.loss_train is not None: - return self.loss_train - return self._get_loss(self.logits_train, self.labels_train) + """Calculates (if needed) cross-entropy losses for the training set.""" + if self.loss_train is None: + self.loss_train = 
utils.log_loss_from_logits(self.labels_train, + self.logits_train) + return self.loss_train def get_loss_test(self): - """Calculates cross-entropy losses for the test set.""" - if self.loss_test is not None: - return self.loss_test - return self._get_loss(self.logits_test, self.labels_test) + """Calculates (if needed) cross-entropy losses for the test set.""" + if self.loss_test is None: + self.loss_test = utils.log_loss_from_logits(self.labels_test, + self.logits_test) + return self.loss_test def get_train_size(self): """Returns size of the training set.""" @@ -365,11 +363,13 @@ class AttackResults: advantages.append(float(attack_result.get_attacker_advantage())) aucs.append(float(attack_result.get_auc())) - df = pd.DataFrame({'slice feature': slice_features, - 'slice value': slice_values, - 'attack type': attack_types, - 'attack advantage': advantages, - 'roc auc': aucs}) + df = pd.DataFrame({ + 'slice feature': slice_features, + 'slice value': slice_values, + 'attack type': attack_types, + 'attack advantage': advantages, + 'roc auc': aucs + }) return df def summary(self, by_slices=False) -> str: @@ -452,3 +452,27 @@ class AttackResults: """Loads AttackResults from a pickle file.""" with open(filepath, 'rb') as inp: return pickle.load(inp) + + +def get_flattened_attack_metrics(results: AttackResults): + """Get flattened attack metrics. + + Args: + results: membership inference attack results. 
+ + Returns: + properties: a list of (slice, attack_type, metric name) + values: a list of metric values, i-th element corresponds to properties[i] + """ + properties = [] + values = [] + for attack_result in results.single_attack_results: + slice_spec = attack_result.slice_spec + prop = [str(slice_spec), str(attack_result.attack_type)] + properties += [prop + ['adv'], prop + ['auc']] + values += [ + float(attack_result.get_attacker_advantage()), + float(attack_result.get_auc()) + ] + + return properties, values diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 8314d95..e38c174 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -47,13 +47,15 @@ class AttackInputDataTest(absltest.TestCase): def test_get_loss(self): attack_input = AttackInputData( - logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]), - logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]), + logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]), + logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]), labels_train=np.array([1, 0]), - labels_test=np.array([0, 1])) + labels_test=np.array([0, 2])) - np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2]) - np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5]) + np.testing.assert_allclose( + attack_input.get_loss_train(), [0.36313551, 1.37153903], atol=1e-7) + np.testing.assert_allclose( + attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7) def test_get_loss_explicitly_provided(self): attack_input = AttackInputData( @@ -237,8 +239,9 @@ class AttackResultsTest(absltest.TestCase): self.assertEqual(repr(results), repr(loaded_results)) def test_calculate_pd_dataframe(self): - single_results = [self.perfect_classifier_result, -
self.random_classifier_result] + single_results = [ + self.perfect_classifier_result, self.random_classifier_result + ] results = AttackResults(single_results) df = results.calculate_pd_dataframe() df_expected = pd.DataFrame({ diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py index 938c42e..e43346a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py @@ -24,8 +24,8 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec -from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard @@ -86,7 +86,7 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback): self._attack_types) logging.info(results) - attack_properties, attack_values = get_all_attack_results(results) + attack_properties, attack_values = get_flattened_attack_metrics(results) print('Attack result:') print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in zip(attack_properties, attack_values)])) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py index 
2cdc029..50102dc 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py @@ -21,11 +21,10 @@ from absl import flags import numpy as np import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model -from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results - GradientDescentOptimizer = tf.train.GradientDescentOptimizer @@ -40,11 +39,16 @@ flags.DEFINE_string('model_dir', None, 'Model directory.') def cnn_model(): """Define a CNN model.""" model = tf.keras.Sequential([ - tf.keras.layers.Conv2D(16, 8, strides=2, padding='same', - activation='relu', input_shape=(28, 28, 1)), + tf.keras.layers.Conv2D( + 16, + 8, + strides=2, + padding='same', + activation='relu', + input_shape=(28, 28, 1)), tf.keras.layers.MaxPool2D(2, 1), - tf.keras.layers.Conv2D(32, 4, strides=2, padding='valid', - activation='relu'), + tf.keras.layers.Conv2D( + 32, 4, strides=2, padding='valid', activation='relu'), tf.keras.layers.MaxPool2D(2, 1), tf.keras.layers.Flatten(), tf.keras.layers.Dense(32, activation='relu'), @@ -83,31 +87,34 @@ def main(unused_argv): # Get callback for membership inference attack. 
mia_callback = MembershipInferenceCallback( - (train_data, train_labels), - (test_data, test_labels), + (train_data, train_labels), (test_data, test_labels), attack_types=[AttackType.THRESHOLD_ATTACK], tensorboard_dir=FLAGS.model_dir) # Train model with Keras - model.fit(train_data, train_labels, - epochs=FLAGS.epochs, - validation_data=(test_data, test_labels), - batch_size=FLAGS.batch_size, - callbacks=[mia_callback], - verbose=2) + model.fit( + train_data, + train_labels, + epochs=FLAGS.epochs, + validation_data=(test_data, test_labels), + batch_size=FLAGS.batch_size, + callbacks=[mia_callback], + verbose=2) print('End of training attack:') attack_results = run_attack_on_keras_model( - model, - (train_data, train_labels), - (test_data, test_labels), + model, (train_data, train_labels), (test_data, test_labels), slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), - attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] - ) + attack_types=[ + AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS + ]) - attack_properties, attack_values = get_all_attack_results(attack_results) - print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in - zip(attack_properties, attack_values)])) + attack_properties, attack_values = get_flattened_attack_metrics( + attack_results) + print('\n'.join([ + ' %s: %.4f' % (', '.join(p), r) + for p, r in zip(attack_properties, attack_values) + ])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py index 8ce5c19..6bdb0fc 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py @@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation from 
tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType -from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics class UtilsTest(absltest.TestCase): @@ -67,7 +67,7 @@ class UtilsTest(absltest.TestCase): (self.test_data, self.test_labels), attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_all_attack_results(results) + attack_properties, attack_values = get_flattened_attack_metrics(results) self.assertLen(attack_properties, 2) self.assertLen(attack_values, 2) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py index 954ad5b..9a61cd0 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation.py @@ -26,8 +26,8 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec -from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss from 
tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard @@ -101,7 +101,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook): self._attack_types) logging.info(results) - attack_properties, attack_values = get_all_attack_results(results) + attack_properties, attack_values = get_flattened_attack_metrics(results) print('Attack result:') print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in zip(attack_properties, attack_values)])) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py index 6d4a1ba..c579491 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_example.py @@ -22,10 +22,10 @@ from absl import logging import numpy as np import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model -from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results GradientDescentOptimizer = tf.train.GradientDescentOptimizer @@ -63,9 +63,7 @@ def cnn_model_fn(features, labels, mode): global_step = tf.train.get_global_step() train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step) return tf.estimator.EstimatorSpec( - mode=mode, - loss=scalar_loss, - train_op=train_op) + mode=mode, 
loss=scalar_loss, train_op=train_op) # Add evaluation metrics (for EVAL mode). elif mode == tf.estimator.ModeKeys.EVAL: @@ -108,8 +106,8 @@ def main(unused_argv): train_data, train_labels, test_data, test_labels = load_mnist() # Instantiate the tf.Estimator. - mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, - model_dir=FLAGS.model_dir) + mnist_classifier = tf.estimator.Estimator( + model_fn=cnn_model_fn, model_dir=FLAGS.model_dir) # A function to construct input_fn given (data, label), to be used by the # membership inference training hook. @@ -124,9 +122,7 @@ def main(unused_argv): else: summary_writer = None mia_hook = MembershipInferenceTrainingHook( - mnist_classifier, - (train_data, train_labels), - (test_data, test_labels), + mnist_classifier, (train_data, train_labels), (test_data, test_labels), input_fn_constructor, attack_types=[AttackType.THRESHOLD_ATTACK], writer=summary_writer) @@ -145,8 +141,8 @@ def main(unused_argv): steps_per_epoch = 60000 // FLAGS.batch_size for epoch in range(1, FLAGS.epochs + 1): # Train the model, with the membership inference hook. 
- mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch, - hooks=[mia_hook]) + mnist_classifier.train( + input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook]) # Evaluate the model and print results eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) @@ -155,16 +151,18 @@ def main(unused_argv): print('End of training attack') attack_results = run_attack_on_tf_estimator_model( - mnist_classifier, - (train_data, train_labels), - (test_data, test_labels), + mnist_classifier, (train_data, train_labels), (test_data, test_labels), input_fn_constructor, slicing_spec=SlicingSpec(entire_dataset=True, by_class=True), - attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS] - ) - attack_properties, attack_values = get_all_attack_results(attack_results) - print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in - zip(attack_properties, attack_values)])) + attack_types=[ + AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS + ]) + attack_properties, attack_values = get_flattened_attack_metrics( + attack_results) + print('\n'.join([ + ' %s: %.4f' % (', '.join(p), r) + for p, r in zip(attack_properties, attack_values) + ])) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py index bfb1585..f8f8ce2 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/tf_estimator_evaluation_test.py @@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType -from 
tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics class UtilsTest(absltest.TestCase): @@ -88,7 +88,7 @@ class UtilsTest(absltest.TestCase): self.test_labels, attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_all_attack_results(results) + attack_properties, attack_values = get_flattened_attack_metrics(results) self.assertLen(attack_properties, 2) self.assertLen(attack_values, 2) @@ -104,7 +104,7 @@ class UtilsTest(absltest.TestCase): input_fn_constructor, attack_types=[AttackType.THRESHOLD_ATTACK]) self.assertIsInstance(results, AttackResults) - attack_properties, attack_values = get_all_attack_results(results) + attack_properties, attack_values = get_flattened_attack_metrics(results) self.assertLen(attack_properties, 2) self.assertLen(attack_values, 2) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils.py b/tensorflow_privacy/privacy/membership_inference_attack/utils.py index cb2c660..67d7e99 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils.py @@ -18,10 +18,9 @@ from typing import Text, Dict, Union, List, Any, Tuple import numpy as np +import scipy.special from sklearn import metrics import tensorflow.compat.v1 as tf -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults - ArrayDict = Dict[Text, np.ndarray] Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] @@ -74,25 +73,6 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]: return {prefix + k: v for k, v in in_dict.items()} -# ------------------------------------------------------------------------------ -# Utilities for managing result. 
-# ------------------------------------------------------------------------------ - - -def get_all_attack_results(results: AttackResults): - """Get all results as a list of attack properties and a list of attack result.""" - properties = [] - values = [] - for attack_result in results.single_attack_results: - slice_spec = attack_result.slice_spec - prop = [str(slice_spec), str(attack_result.attack_type)] - properties += [prop + ['adv'], prop + ['auc']] - values += [float(attack_result.get_attacker_advantage()), - float(attack_result.get_auc())] - - return properties, values - - # ------------------------------------------------------------------------------ # Subsampling and data selection functionality # ------------------------------------------------------------------------------ @@ -245,19 +225,24 @@ def compute_performance_metrics(true_labels: np.ndarray, # ------------------------------------------------------------------------------ -def log_loss(y, pred, small_value=1e-8): +def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8): """Compute the cross entropy loss. Args: - y: numpy array, y[i] is the true label (scalar) of the i-th sample + labels: numpy array, labels[i] is the true label (scalar) of the i-th sample pred: numpy array, pred[i] is the probability vector of the i-th sample - small_value: np.log can become -inf if the probability is too close to 0, - so the probability is clipped below by small_value. + small_value: np.log can become -inf if the probability is too close to 0, so + the probability is clipped below by small_value. 
Returns: the cross-entropy loss of each sample """ - return -np.log(np.maximum(pred[range(y.size), y], small_value)) + return -np.log(np.maximum(pred[range(labels.size), labels], small_value)) + + +def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray): + """Compute the cross entropy loss from logits.""" + return log_loss(labels, scipy.special.softmax(logits, axis=-1)) # ------------------------------------------------------------------------------ diff --git a/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py b/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py index fd9fa31..d96bbbc 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/utils_test.py @@ -39,8 +39,10 @@ class UtilsTest(absltest.TestCase): results = utils.compute_performance_metrics(true, pred, threshold=0.5) - for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr', - 'thresholds', 'auc', 'advantage']: + for k in [ + 'precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr', + 'thresholds', 'auc', 'advantage' + ]: self.assertIn(k, results) np.testing.assert_almost_equal(results['accuracy'], 1. / 2.) 
@@ -107,10 +109,16 @@ class UtilsTest(absltest.TestCase): [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]]) # Test the cases when true label (for all samples) is 0 and 1 expected_losses = { - 0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207, - 0.10536052, 0.01005034]), - 1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436, - 2.30258509, 4.60517019]) + 0: + np.array([ + 4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207, + 0.10536052, 0.01005034 + ]), + 1: + np.array([ + 0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436, + 2.30258509, 4.60517019 + ]) } for c in [0, 1]: # true label y = np.ones(shape=pred.shape[0], dtype=int) * c @@ -139,8 +147,18 @@ class UtilsTest(absltest.TestCase): expected_losses = np.array([18.42068074, 46.05170186, 115.12925465]) for i, small_value in enumerate(small_values): loss = utils.log_loss(y, pred, small_value) - np.testing.assert_allclose(loss, np.array([expected_losses[i], 0]), - atol=1e-7) + np.testing.assert_allclose( + loss, np.array([expected_losses[i], 0]), atol=1e-7) + + def test_log_loss_from_logits(self): + """Test computing cross-entropy loss from logits.""" + + logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]]) + labels = np.array([0, 3, 1]) + expected_loss = np.array([1.4401897, 3.4401897, 0.11144278]) + + loss = utils.log_loss_from_logits(labels, logits) + np.testing.assert_allclose(expected_loss, loss, atol=1e-7) if __name__ == '__main__':