forked from 626_privacy/tensorflow_privacy
Fixing calculating loss on logits.
PiperOrigin-RevId: 329966058
This commit is contained in:
parent
f4fc9b2623
commit
8f3a61b50d
10 changed files with 144 additions and 109 deletions
|
@ -22,7 +22,7 @@ from dataclasses import dataclass
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn import metrics
|
||||
|
||||
import tensorflow_privacy.privacy.membership_inference_attack.utils as utils
|
||||
|
||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
||||
|
||||
|
@ -173,21 +173,19 @@ class AttackInputData:
|
|||
'Please set labels_train and labels_test')
|
||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||
|
||||
@staticmethod
|
||||
def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
|
||||
return logits[range(logits.shape[0]), true_labels]
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates cross-entropy losses for the training set."""
|
||||
if self.loss_train is not None:
|
||||
"""Calculates (if needed) cross-entropy losses for the training set."""
|
||||
if self.loss_train is None:
|
||||
self.loss_train = utils.log_loss_from_logits(self.labels_train,
|
||||
self.logits_train)
|
||||
return self.loss_train
|
||||
return self._get_loss(self.logits_train, self.labels_train)
|
||||
|
||||
def get_loss_test(self):
|
||||
"""Calculates cross-entropy losses for the test set."""
|
||||
if self.loss_test is not None:
|
||||
"""Calculates (if needed) cross-entropy losses for the test set."""
|
||||
if self.loss_test is None:
|
||||
self.loss_test = utils.log_loss_from_logits(self.labels_test,
|
||||
self.logits_test)
|
||||
return self.loss_test
|
||||
return self._get_loss(self.logits_test, self.labels_test)
|
||||
|
||||
def get_train_size(self):
|
||||
"""Returns size of the training set."""
|
||||
|
@ -365,11 +363,13 @@ class AttackResults:
|
|||
advantages.append(float(attack_result.get_attacker_advantage()))
|
||||
aucs.append(float(attack_result.get_auc()))
|
||||
|
||||
df = pd.DataFrame({'slice feature': slice_features,
|
||||
df = pd.DataFrame({
|
||||
'slice feature': slice_features,
|
||||
'slice value': slice_values,
|
||||
'attack type': attack_types,
|
||||
'attack advantage': advantages,
|
||||
'roc auc': aucs})
|
||||
'roc auc': aucs
|
||||
})
|
||||
return df
|
||||
|
||||
def summary(self, by_slices=False) -> str:
|
||||
|
@ -452,3 +452,27 @@ class AttackResults:
|
|||
"""Loads AttackResults from a pickle file."""
|
||||
with open(filepath, 'rb') as inp:
|
||||
return pickle.load(inp)
|
||||
|
||||
|
||||
def get_flattened_attack_metrics(results: AttackResults):
|
||||
"""Get flattened attack metrics.
|
||||
|
||||
Args:
|
||||
results: membership inference attack results.
|
||||
|
||||
Returns:
|
||||
properties: a list of (slice, attack_type, metric name)
|
||||
values: a list of metric values, i-th element correspond to properties[i]
|
||||
"""
|
||||
properties = []
|
||||
values = []
|
||||
for attack_result in results.single_attack_results:
|
||||
slice_spec = attack_result.slice_spec
|
||||
prop = [str(slice_spec), str(attack_result.attack_type)]
|
||||
properties += [prop + ['adv'], prop + ['auc']]
|
||||
values += [
|
||||
float(attack_result.get_attacker_advantage()),
|
||||
float(attack_result.get_auc())
|
||||
]
|
||||
|
||||
return properties, values
|
||||
|
|
|
@ -47,13 +47,15 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
|
||||
def test_get_loss(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]),
|
||||
logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]),
|
||||
logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
|
||||
logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
|
||||
labels_train=np.array([1, 0]),
|
||||
labels_test=np.array([0, 1]))
|
||||
labels_test=np.array([0, 2]))
|
||||
|
||||
np.testing.assert_equal(attack_input.get_loss_train().tolist(), [0.5, 0.2])
|
||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(), [0.2, 0.5])
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_train(), [0.36313551, 1.37153903], atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
|
||||
|
||||
def test_get_loss_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
|
@ -237,8 +239,9 @@ class AttackResultsTest(absltest.TestCase):
|
|||
self.assertEqual(repr(results), repr(loaded_results))
|
||||
|
||||
def test_calculate_pd_dataframe(self):
|
||||
single_results = [self.perfect_classifier_result,
|
||||
self.random_classifier_result]
|
||||
single_results = [
|
||||
self.perfect_classifier_result, self.random_classifier_result
|
||||
]
|
||||
results = AttackResults(single_results)
|
||||
df = results.calculate_pd_dataframe()
|
||||
df_expected = pd.DataFrame({
|
||||
|
|
|
@ -24,8 +24,8 @@ import tensorflow.compat.v1 as tf
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||
|
||||
|
@ -86,7 +86,7 @@ class MembershipInferenceCallback(tf.keras.callbacks.Callback):
|
|||
self._attack_types)
|
||||
logging.info(results)
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(results)
|
||||
print('Attack result:')
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
|
|
@ -21,11 +21,10 @@ from absl import flags
|
|||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
|
||||
|
@ -40,11 +39,16 @@ flags.DEFINE_string('model_dir', None, 'Model directory.')
|
|||
def cnn_model():
|
||||
"""Define a CNN model."""
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2D(16, 8, strides=2, padding='same',
|
||||
activation='relu', input_shape=(28, 28, 1)),
|
||||
tf.keras.layers.Conv2D(
|
||||
16,
|
||||
8,
|
||||
strides=2,
|
||||
padding='same',
|
||||
activation='relu',
|
||||
input_shape=(28, 28, 1)),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Conv2D(32, 4, strides=2, padding='valid',
|
||||
activation='relu'),
|
||||
tf.keras.layers.Conv2D(
|
||||
32, 4, strides=2, padding='valid', activation='relu'),
|
||||
tf.keras.layers.MaxPool2D(2, 1),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(32, activation='relu'),
|
||||
|
@ -83,13 +87,14 @@ def main(unused_argv):
|
|||
|
||||
# Get callback for membership inference attack.
|
||||
mia_callback = MembershipInferenceCallback(
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
(train_data, train_labels), (test_data, test_labels),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||
tensorboard_dir=FLAGS.model_dir)
|
||||
|
||||
# Train model with Keras
|
||||
model.fit(train_data, train_labels,
|
||||
model.fit(
|
||||
train_data,
|
||||
train_labels,
|
||||
epochs=FLAGS.epochs,
|
||||
validation_data=(test_data, test_labels),
|
||||
batch_size=FLAGS.batch_size,
|
||||
|
@ -98,16 +103,18 @@ def main(unused_argv):
|
|||
|
||||
print('End of training attack:')
|
||||
attack_results = run_attack_on_keras_model(
|
||||
model,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
model, (train_data, train_labels), (test_data, test_labels),
|
||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||
)
|
||||
attack_types=[
|
||||
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
|
||||
])
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(
|
||||
attack_results)
|
||||
print('\n'.join([
|
||||
' %s: %.4f' % (', '.join(p), r)
|
||||
for p, r in zip(attack_properties, attack_values)
|
||||
]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
@ -67,7 +67,7 @@ class UtilsTest(absltest.TestCase):
|
|||
(self.test_data, self.test_labels),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ import tensorflow.compat.v1 as tf
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack_new as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard
|
||||
|
||||
|
@ -101,7 +101,7 @@ class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
|
|||
self._attack_types)
|
||||
logging.info(results)
|
||||
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(results)
|
||||
print('Attack result:')
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
|
|
|
@ -22,10 +22,10 @@ from absl import logging
|
|||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
|
||||
|
@ -63,9 +63,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
global_step = tf.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
|
||||
return tf.estimator.EstimatorSpec(
|
||||
mode=mode,
|
||||
loss=scalar_loss,
|
||||
train_op=train_op)
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
|
||||
# Add evaluation metrics (for EVAL mode).
|
||||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
|
@ -108,8 +106,8 @@ def main(unused_argv):
|
|||
train_data, train_labels, test_data, test_labels = load_mnist()
|
||||
|
||||
# Instantiate the tf.Estimator.
|
||||
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
|
||||
model_dir=FLAGS.model_dir)
|
||||
mnist_classifier = tf.estimator.Estimator(
|
||||
model_fn=cnn_model_fn, model_dir=FLAGS.model_dir)
|
||||
|
||||
# A function to construct input_fn given (data, label), to be used by the
|
||||
# membership inference training hook.
|
||||
|
@ -124,9 +122,7 @@ def main(unused_argv):
|
|||
else:
|
||||
summary_writer = None
|
||||
mia_hook = MembershipInferenceTrainingHook(
|
||||
mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
mnist_classifier, (train_data, train_labels), (test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK],
|
||||
writer=summary_writer)
|
||||
|
@ -145,8 +141,8 @@ def main(unused_argv):
|
|||
steps_per_epoch = 60000 // FLAGS.batch_size
|
||||
for epoch in range(1, FLAGS.epochs + 1):
|
||||
# Train the model, with the membership inference hook.
|
||||
mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch,
|
||||
hooks=[mia_hook])
|
||||
mnist_classifier.train(
|
||||
input_fn=train_input_fn, steps=steps_per_epoch, hooks=[mia_hook])
|
||||
|
||||
# Evaluate the model and print results
|
||||
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
|
||||
|
@ -155,16 +151,18 @@ def main(unused_argv):
|
|||
|
||||
print('End of training attack')
|
||||
attack_results = run_attack_on_tf_estimator_model(
|
||||
mnist_classifier,
|
||||
(train_data, train_labels),
|
||||
(test_data, test_labels),
|
||||
mnist_classifier, (train_data, train_labels), (test_data, test_labels),
|
||||
input_fn_constructor,
|
||||
slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS]
|
||||
)
|
||||
attack_properties, attack_values = get_all_attack_results(attack_results)
|
||||
print('\n'.join([' %s: %.4f' % (', '.join(p), r) for p, r in
|
||||
zip(attack_properties, attack_values)]))
|
||||
attack_types=[
|
||||
AttackType.THRESHOLD_ATTACK, AttackType.K_NEAREST_NEIGHBORS
|
||||
])
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(
|
||||
attack_results)
|
||||
print('\n'.join([
|
||||
' %s: %.4f' % (', '.join(p), r)
|
||||
for p, r in zip(attack_properties, attack_values)
|
||||
]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -23,7 +23,7 @@ import tensorflow.compat.v1 as tf
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.utils import get_all_attack_results
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
@ -88,7 +88,7 @@ class UtilsTest(absltest.TestCase):
|
|||
self.test_labels,
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
|
@ -104,7 +104,7 @@ class UtilsTest(absltest.TestCase):
|
|||
input_fn_constructor,
|
||||
attack_types=[AttackType.THRESHOLD_ATTACK])
|
||||
self.assertIsInstance(results, AttackResults)
|
||||
attack_properties, attack_values = get_all_attack_results(results)
|
||||
attack_properties, attack_values = get_flattened_attack_metrics(results)
|
||||
self.assertLen(attack_properties, 2)
|
||||
self.assertLen(attack_values, 2)
|
||||
|
||||
|
|
|
@ -18,10 +18,9 @@
|
|||
from typing import Text, Dict, Union, List, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy.special
|
||||
from sklearn import metrics
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
|
||||
|
||||
ArrayDict = Dict[Text, np.ndarray]
|
||||
Dataset = Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||
|
@ -74,25 +73,6 @@ def prepend_to_keys(in_dict: Dict[Text, Any], prefix: Text) -> Dict[Text, Any]:
|
|||
return {prefix + k: v for k, v in in_dict.items()}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Utilities for managing result.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_all_attack_results(results: AttackResults):
|
||||
"""Get all results as a list of attack properties and a list of attack result."""
|
||||
properties = []
|
||||
values = []
|
||||
for attack_result in results.single_attack_results:
|
||||
slice_spec = attack_result.slice_spec
|
||||
prop = [str(slice_spec), str(attack_result.attack_type)]
|
||||
properties += [prop + ['adv'], prop + ['auc']]
|
||||
values += [float(attack_result.get_attacker_advantage()),
|
||||
float(attack_result.get_auc())]
|
||||
|
||||
return properties, values
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Subsampling and data selection functionality
|
||||
# ------------------------------------------------------------------------------
|
||||
|
@ -245,19 +225,24 @@ def compute_performance_metrics(true_labels: np.ndarray,
|
|||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def log_loss(y, pred, small_value=1e-8):
|
||||
def log_loss(labels: np.ndarray, pred: np.ndarray, small_value=1e-8):
|
||||
"""Compute the cross entropy loss.
|
||||
|
||||
Args:
|
||||
y: numpy array, y[i] is the true label (scalar) of the i-th sample
|
||||
labels: numpy array, labels[i] is the true label (scalar) of the i-th sample
|
||||
pred: numpy array, pred[i] is the probability vector of the i-th sample
|
||||
small_value: np.log can become -inf if the probability is too close to 0,
|
||||
so the probability is clipped below by small_value.
|
||||
small_value: np.log can become -inf if the probability is too close to 0, so
|
||||
the probability is clipped below by small_value.
|
||||
|
||||
Returns:
|
||||
the cross-entropy loss of each sample
|
||||
"""
|
||||
return -np.log(np.maximum(pred[range(y.size), y], small_value))
|
||||
return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
|
||||
|
||||
|
||||
def log_loss_from_logits(labels: np.ndarray, logits: np.ndarray):
|
||||
"""Compute the cross entropy loss from logits."""
|
||||
return log_loss(labels, scipy.special.softmax(logits, axis=-1))
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
|
|
@ -39,8 +39,10 @@ class UtilsTest(absltest.TestCase):
|
|||
|
||||
results = utils.compute_performance_metrics(true, pred, threshold=0.5)
|
||||
|
||||
for k in ['precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
|
||||
'thresholds', 'auc', 'advantage']:
|
||||
for k in [
|
||||
'precision', 'recall', 'accuracy', 'f1_score', 'fpr', 'tpr',
|
||||
'thresholds', 'auc', 'advantage'
|
||||
]:
|
||||
self.assertIn(k, results)
|
||||
|
||||
np.testing.assert_almost_equal(results['accuracy'], 1. / 2.)
|
||||
|
@ -107,10 +109,16 @@ class UtilsTest(absltest.TestCase):
|
|||
[0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
|
||||
# Test the cases when true label (for all samples) is 0 and 1
|
||||
expected_losses = {
|
||||
0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
|
||||
0.10536052, 0.01005034]),
|
||||
1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
|
||||
2.30258509, 4.60517019])
|
||||
0:
|
||||
np.array([
|
||||
4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
|
||||
0.10536052, 0.01005034
|
||||
]),
|
||||
1:
|
||||
np.array([
|
||||
0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
|
||||
2.30258509, 4.60517019
|
||||
])
|
||||
}
|
||||
for c in [0, 1]: # true label
|
||||
y = np.ones(shape=pred.shape[0], dtype=int) * c
|
||||
|
@ -139,8 +147,18 @@ class UtilsTest(absltest.TestCase):
|
|||
expected_losses = np.array([18.42068074, 46.05170186, 115.12925465])
|
||||
for i, small_value in enumerate(small_values):
|
||||
loss = utils.log_loss(y, pred, small_value)
|
||||
np.testing.assert_allclose(loss, np.array([expected_losses[i], 0]),
|
||||
atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
loss, np.array([expected_losses[i], 0]), atol=1e-7)
|
||||
|
||||
def test_log_loss_from_logits(self):
|
||||
"""Test computing cross-entropy loss from logits."""
|
||||
|
||||
logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
|
||||
labels = np.array([0, 3, 1])
|
||||
expected_loss = np.array([1.4401897, 3.4401897, 0.11144278])
|
||||
|
||||
loss = utils.log_loss_from_logits(labels, logits)
|
||||
np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in a new issue