forked from 626_privacy/tensorflow_privacy
A training hook and a function to be called in the end of training for tf estimator to perform membership inference attack.
PiperOrigin-RevId: 321648371
This commit is contained in:
parent
51eb7c3712
commit
a0e1b72838
5 changed files with 497 additions and 0 deletions
|
@ -0,0 +1,166 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""A hook and a function in tf estimator for membership inference attack."""
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_losses(estimator, input_fn, labels):
  """Get predictions and per-sample cross-entropy losses for samples.

  The assumptions are 1) the loss is cross-entropy loss, and 2) the user has
  specified prediction mode to return predictions, e.g.,
  when mode == tf.estimator.ModeKeys.PREDICT, the model function returns
  tf.estimator.EstimatorSpec(mode=mode, predictions=tf.nn.softmax(logits)).

  Args:
    estimator: model to make prediction
    input_fn: input function to be used in estimator.predict; it has to yield
      exactly one prediction per entry of `labels`, in the same order
    labels: true labels of samples (array-like of class indices)

  Returns:
    pred: probability vector of each sample
    loss: cross entropy loss of each sample

  Raises:
    ValueError: if the number of predictions differs from the number of labels.
  """
  # log_loss reads `labels.size`, so plain Python sequences must be converted.
  labels = np.asarray(labels)
  pred = np.array(list(estimator.predict(input_fn=input_fn)))

  # A mismatch (e.g. an input_fn that repeats the data) would otherwise pair
  # predictions with the wrong labels without any error.
  if pred.shape[0] != labels.size:
    raise ValueError(
        'Number of predictions (%d) does not match number of labels (%d).' %
        (pred.shape[0], labels.size))

  loss = log_loss(labels, pred)
  return pred, loss
|
||||||
|
|
||||||
|
|
||||||
|
class MembershipInferenceTrainingHook(tf.estimator.SessionRunHook):
  """Training hook that runs a membership inference attack in `end`.

  When every `estimator.train` call covers one epoch, the attack therefore
  runs once per epoch.
  """

  def __init__(self,
               estimator,
               in_train,
               out_train,
               input_fn_constructor,
               attack_classifiers,
               writer=None):
    """Initializes the hook.

    Args:
      estimator: model to be tested
      in_train: (in_training samples, in_training labels)
      out_train: (out_training samples, out_training labels)
      input_fn_constructor: a function that receives sample, label and construct
        the input_fn for model prediction
      attack_classifiers: a list of classifiers to be used by attacker, must be
        a subset of ['lr', 'mlp', 'rf', 'knn']
      writer: summary writer for tensorboard
    """
    in_samples, self._in_train_labels = in_train
    out_samples, self._out_train_labels = out_train

    # Build the prediction input functions once; `end` reuses them.
    self._in_train_input_fn = input_fn_constructor(
        in_samples, self._in_train_labels)
    self._out_train_input_fn = input_fn_constructor(
        out_samples, self._out_train_labels)

    self._estimator = estimator
    self._attack_classifiers = attack_classifiers
    self._writer = writer
    if self._writer:
      logging.info('Will write to tensorboard.')

  def end(self, session):
    """Runs the attack, reports it, and optionally writes a summary."""
    attack_results = run_attack_helper(
        self._estimator,
        self._in_train_input_fn, self._out_train_input_fn,
        self._in_train_labels, self._out_train_labels,
        self._attack_classifiers)
    advantage = attack_results['all_thresh_loss_advantage']
    print('all_thresh_loss_advantage', advantage)
    logging.info(attack_results)

    if not self._writer:
      return

    # Record the advantage against the estimator's current global step.
    summary = tf.Summary()
    summary.value.add(tag='attack advantage', simple_value=advantage)
    step = self._estimator.get_variable_value('global_step')
    self._writer.add_summary(summary, step)
    self._writer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def run_attack_on_tf_estimator_model(estimator, in_train, out_train,
                                     input_fn_constructor, attack_classifiers):
  """Performs the membership inference attack at the end of training.

  Args:
    estimator: model to be tested
    in_train: (in_training samples, in_training labels)
    out_train: (out_training samples, out_training labels)
    input_fn_constructor: a function that receives sample, label and construct
      the input_fn for model prediction
    attack_classifiers: a list of classifiers to be used by attacker, must be
      a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  (in_samples, in_labels), (out_samples, out_labels) = in_train, out_train

  # Build the in/out-of-training input functions inline and run the attack.
  attack_results = run_attack_helper(
      estimator,
      input_fn_constructor(in_samples, in_labels),
      input_fn_constructor(out_samples, out_labels),
      in_labels, out_labels,
      attack_classifiers)

  print('all_thresh_loss_advantage',
        attack_results['all_thresh_loss_advantage'])
  logging.info('End of training attack:')
  logging.info(attack_results)
  return attack_results
|
||||||
|
|
||||||
|
|
||||||
|
def run_attack_helper(estimator,
                      in_train_input_fn, out_train_input_fn,
                      in_train_labels, out_train_labels,
                      attack_classifiers):
  """Computes predictions and losses, then runs all attacks on them.

  Args:
    estimator: model to be tested
    in_train_input_fn: input_fn for in training data
    out_train_input_fn: input_fn for out of training data
    in_train_labels: in training labels
    out_train_labels: out of training labels
    attack_classifiers: a list of classifiers to be used by attacker, must be
      a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Results of the attack
  """
  in_pred, in_loss = calculate_losses(
      estimator, in_train_input_fn, in_train_labels)
  out_pred, out_loss = calculate_losses(
      estimator, out_train_input_fn, out_train_labels)

  return mia.run_all_attacks(
      in_loss, out_loss,
      in_pred, out_pred,
      in_train_labels, out_train_labels,
      attack_classifiers=attack_classifiers)
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""An example for using tf_estimator_evaluation."""
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import MembershipInferenceTrainingHook
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation import run_attack_on_tf_estimator_model
|
||||||
|
|
||||||
|
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||||
|
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||||
|
flags.DEFINE_integer('epochs', 10, 'Number of epochs')
|
||||||
|
flags.DEFINE_string('model_dir', None, 'Model directory.')
|
||||||
|
|
||||||
|
|
||||||
|
def cnn_model_fn(features, labels, mode):
  """Model function for a CNN.

  Args:
    features: dict of input tensors; features['x'] is reshaped to
      [-1, 28, 28, 1] before being fed to the network.
    labels: class-index labels; only used when mode is not PREDICT.
    mode: a tf.estimator.ModeKeys value.

  Returns:
    A tf.estimator.EstimatorSpec: softmax probabilities for PREDICT, loss and
    train_op for TRAIN, loss and an accuracy metric for EVAL.

  Raises:
    ValueError: if `mode` is not one of PREDICT, TRAIN or EVAL.
  """

  # Define CNN architecture using tf.keras.layers. Layers are called directly
  # because `Layer.apply` is only a deprecated alias of `__call__`.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(
      16, 8, strides=2, padding='same', activation='relu')(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1)(y)
  y = tf.keras.layers.Conv2D(
      32, 4, strides=2, padding='valid', activation='relu')(y)
  y = tf.keras.layers.MaxPool2D(2, 1)(y)
  y = tf.keras.layers.Flatten()(y)
  y = tf.keras.layers.Dense(32, activation='relu')(y)
  logits = tf.keras.layers.Dense(10)(y)

  # Output the prediction probability (for PREDICT mode). No loss is built
  # here because labels are not needed for prediction.
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = tf.nn.softmax(logits)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  scalar_loss = tf.reduce_mean(input_tensor=vector_loss)

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=scalar_loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=scalar_loss,
        train_op=train_op)

  # Add evaluation metrics (for EVAL mode).
  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(input=logits, axis=1))
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)

  # Previously an unknown mode fell through and returned None.
  raise ValueError('Unsupported mode: %s' % mode)
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist():
  """Loads MNIST via tf.keras and converts it to float32 data / int32 labels.

  Image arrays are cast to float32 and divided by 255; label arrays are cast
  to int32.

  Returns:
    A tuple (train_data, train_labels, test_data, test_labels).
  """
  train_split, test_split = tf.keras.datasets.mnist.load_data()
  raw_train_images, raw_train_labels = train_split
  raw_test_images, raw_test_labels = test_split

  return (np.array(raw_train_images, dtype=np.float32) / 255,
          np.array(raw_train_labels, dtype=np.int32),
          np.array(raw_test_images, dtype=np.float32) / 255,
          np.array(raw_test_labels, dtype=np.int32))
|
||||||
|
|
||||||
|
|
||||||
|
def main(unused_argv):
  """Trains the MNIST CNN and runs membership inference attacks on it.

  The attack runs through a training hook on every `train` call, and once
  more after the training loop has finished.

  Args:
    unused_argv: positional command-line arguments; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  logging.set_verbosity(logging.INFO)
  logging.set_stderrthreshold(logging.INFO)
  logging.get_absl_handler().use_absl_log_file()

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                            model_dir=FLAGS.model_dir)

  # A function to construct input_fn given (data, label), to be used by the
  # membership inference training hook.
  def input_fn_constructor(x, y):
    return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

  with tf.Graph().as_default():
    # Get a summary writer for the hook to write to tensorboard.
    # Can set summary_writer to None if not needed.
    if FLAGS.model_dir:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    else:
      summary_writer = None
    mia_hook = MembershipInferenceTrainingHook(mnist_classifier,
                                               (train_data, train_labels),
                                               (test_data, test_labels),
                                               input_fn_constructor,
                                               [],
                                               summary_writer)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)

  # Training loop. The epoch length comes from the loaded data rather than a
  # hard-coded sample count; at least one step is required by `train`.
  steps_per_epoch = max(1, train_data.shape[0] // FLAGS.batch_size)
  try:
    for epoch in range(1, FLAGS.epochs + 1):
      # Train the model, with the membership inference hook.
      mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch,
                             hooks=[mia_hook])

      # Evaluate the model and print results
      eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
      test_accuracy = eval_results['accuracy']
      print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))
  finally:
    # The hook only flushes the writer; release its file handle once the hook
    # is no longer going to run.
    if summary_writer is not None:
      summary_writer.close()

  print('End of training attack')
  run_attack_on_tf_estimator_model(mnist_classifier,
                                   (train_data, train_labels),
                                   (test_data, test_labels),
                                   input_fn_constructor,
                                   ['lr'])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.tf_estimator_evaluation."""
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack import tf_estimator_evaluation
|
||||||
|
|
||||||
|
|
||||||
|
class UtilsTest(absltest.TestCase):
  """Tests for the helpers in tf_estimator_evaluation."""

  def __init__(self, methodname):
    """Initialize the test class.

    Builds random data, a prediction-only estimator and the matching input
    functions that the test methods share.

    Args:
      methodname: name of the test method to run, forwarded to the base class.
    """
    super().__init__(methodname)

    self.ntrain, self.ntest = 50, 100
    self.nclass = 5
    self.ndim = 10

    # Generate random training and test data
    self.train_data = np.random.rand(self.ntrain, self.ndim)
    self.test_data = np.random.rand(self.ntest, self.ndim)
    self.train_labels = np.random.randint(self.nclass, size=self.ntrain)
    self.test_labels = np.random.randint(self.nclass, size=self.ntest)

    # Define a simple model function
    def model_fn(features, labels, mode):
      """Model function for logistic regression."""
      del labels

      input_layer = tf.reshape(features['x'], [-1, self.ndim])
      # Call the layer directly; `Layer.apply` is a deprecated alias.
      logits = tf.keras.layers.Dense(self.nclass)(input_layer)

      # Define the PREDICT mode because we only need that
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = tf.nn.softmax(logits)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Define the classifier, input_fn for training and test data
    self.classifier = tf.estimator.Estimator(model_fn=model_fn)
    self.input_fn_train = tf.estimator.inputs.numpy_input_fn(
        x={'x': self.train_data}, y=self.train_labels, num_epochs=1,
        shuffle=False)
    self.input_fn_test = tf.estimator.inputs.numpy_input_fn(
        x={'x': self.test_data}, y=self.test_labels, num_epochs=1,
        shuffle=False)

  def test_calculate_losses(self):
    """Test calculating the loss."""
    pred, loss = tf_estimator_evaluation.calculate_losses(self.classifier,
                                                          self.input_fn_train,
                                                          self.train_labels)
    self.assertEqual(pred.shape, (self.ntrain, self.nclass))
    self.assertEqual(loss.shape, (self.ntrain,))

    pred, loss = tf_estimator_evaluation.calculate_losses(self.classifier,
                                                          self.input_fn_test,
                                                          self.test_labels)
    self.assertEqual(pred.shape, (self.ntest, self.nclass))
    self.assertEqual(loss.shape, (self.ntest,))

  def test_run_attack_helper(self):
    """Test the attack."""
    results = tf_estimator_evaluation.run_attack_helper(self.classifier,
                                                        self.input_fn_train,
                                                        self.input_fn_test,
                                                        self.train_labels,
                                                        self.test_labels,
                                                        [])
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)

  def test_run_attack_on_tf_estimator_model(self):
    """Test the attack on the final models."""
    def input_fn_constructor(x, y):
      return tf.estimator.inputs.numpy_input_fn(x={'x': x}, y=y, shuffle=False)

    results = tf_estimator_evaluation.run_attack_on_tf_estimator_model(
        self.classifier,
        (self.train_data, self.train_labels),
        (self.test_data, self.test_labels),
        input_fn_constructor,
        [])
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -216,3 +216,23 @@ def compute_performance_metrics(true_labels: np.ndarray,
|
||||||
'advantage': advantage,
|
'advantage': advantage,
|
||||||
})
|
})
|
||||||
return ensure_1d(results)
|
return ensure_1d(results)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Loss functions
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def log_loss(y, pred, small_value=1e-8):
  """Compute the cross entropy loss.

  Args:
    y: array-like, y[i] is the true label (scalar class index) of the i-th
      sample; it is flattened, so shapes (n,) and (n, 1) are both accepted
    pred: array-like, pred[i] is the probability vector of the i-th sample
    small_value: np.log can become -inf if the probability is too close to 0,
      so the probability is clipped below by small_value.

  Returns:
    the cross-entropy loss of each sample

  Raises:
    ValueError: if pred does not have exactly one row per label.
  """
  # Flatten the labels: a column vector would otherwise broadcast against the
  # row index and select an (n, n) matrix instead of n values.
  y = np.asarray(y).reshape(-1)
  pred = np.asarray(pred)

  # Extra rows in pred used to be dropped silently, hiding misaligned inputs.
  if pred.ndim < 2 or pred.shape[0] != y.size:
    raise ValueError(
        'pred must have one row per label: got pred of shape %s for %d labels.'
        % (pred.shape, y.size))

  # Probability assigned to the true class of each sample.
  true_class_prob = pred[np.arange(y.size), y]
  return -np.log(np.maximum(true_class_prob, small_value))
|
||||||
|
|
|
@ -100,6 +100,48 @@ class UtilsTest(absltest.TestCase):
|
||||||
self.assertEqual(x_test.shape, (n_test, 11))
|
self.assertEqual(x_test.shape, (n_test, 11))
|
||||||
self.assertEqual(y_test.shape, (n_test,))
|
self.assertEqual(y_test.shape, (n_test,))
|
||||||
|
|
||||||
|
def test_log_loss(self):
  """Test computing cross-entropy loss."""
  # Binary case with a few normal values; every sample shares the same true
  # label, which is the key of the expected-loss table.
  binary_pred = np.array([[0.01, 0.99], [0.1, 0.9], [0.25, 0.75], [0.5, 0.5],
                          [0.75, 0.25], [0.9, 0.1], [0.99, 0.01]])
  binary_expected = {
      0: np.array([4.60517019, 2.30258509, 1.38629436, 0.69314718, 0.28768207,
                   0.10536052, 0.01005034]),
      1: np.array([0.01005034, 0.10536052, 0.28768207, 0.69314718, 1.38629436,
                   2.30258509, 4.60517019])
  }
  for true_label in [0, 1]:
    labels = np.full(binary_pred.shape[0], true_label, dtype=int)
    np.testing.assert_allclose(utils.log_loss(labels, binary_pred),
                               binary_expected[true_label], atol=1e-7)

  # Multiclass case with a few normal values
  # (values from http://bit.ly/RJJHWA)
  multi_pred = np.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3],
                         [0.99, 0.002, 0.008]])
  multi_expected = {
      0: np.array([1.60943791, 0.51082562, 0.51082562, 0.01005034]),
      1: np.array([0.35667494, 1.60943791, 2.30258509, 6.2146081]),
      2: np.array([2.30258509, 1.60943791, 1.2039728, 4.82831374])
  }
  for true_label in range(3):
    labels = np.full(multi_pred.shape[0], true_label, dtype=int)
    np.testing.assert_allclose(utils.log_loss(labels, multi_pred),
                               multi_expected[true_label], atol=1e-7)

  # Boundary values 0 and 1: the zero probability is clipped to small_value,
  # the probability of one gives a loss of 0.
  boundary_pred = np.array([[0, 1]] * 2)
  boundary_labels = np.array([0, 1])
  clip_values = [1e-8, 1e-20, 1e-50]
  clipped_losses = np.array([18.42068074, 46.05170186, 115.12925465])
  for clip_value, clipped_loss in zip(clip_values, clipped_losses):
    loss = utils.log_loss(boundary_labels, boundary_pred, clip_value)
    np.testing.assert_allclose(loss, np.array([clipped_loss, 0]),
                               atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue