Add head for multi-label estimators in TF estimator framework.
PiperOrigin-RevId: 327048185
This commit is contained in:
parent
d939b22463
commit
a69b013390
4 changed files with 290 additions and 0 deletions
|
@ -37,6 +37,18 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "multi_label_head",
|
||||
srcs = [
|
||||
"multi_label_head.py",
|
||||
],
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
"//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed.
|
||||
"//third_party/tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dnn",
|
||||
srcs = [
|
||||
|
@ -85,6 +97,19 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "multi_label_head_test",
|
||||
timeout = "long",
|
||||
srcs = ["multi_label_head_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":multi_label_head",
|
||||
":test_utils",
|
||||
"//third_party/py/tensorflow",
|
||||
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dnn_test",
|
||||
timeout = "long",
|
||||
|
|
152
tensorflow_privacy/privacy/estimators/multi_label_head.py
Normal file
152
tensorflow_privacy/privacy/estimators/multi_label_head.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
from tensorflow_estimator.python.estimator.export import export_output
|
||||
from tensorflow_estimator.python.estimator.head import base_head
|
||||
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||
|
||||
|
||||
class DPMultiLabelHead(tf.estimator.MultiLabelHead):
|
||||
"""Creates a TF Privacy-enabled version of MultiLabelHead."""
|
||||
|
||||
def __init__(self,
|
||||
n_classes,
|
||||
weight_column=None,
|
||||
thresholds=None,
|
||||
label_vocabulary=None,
|
||||
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
loss_fn=None,
|
||||
classes_for_class_based_metrics=None,
|
||||
name=None):
|
||||
if loss_reduction == tf.keras.losses.Reduction.NONE:
|
||||
loss_reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
|
||||
super(DPMultiLabelHead, self).__init__(
|
||||
n_classes=n_classes,
|
||||
weight_column=weight_column,
|
||||
thresholds=thresholds,
|
||||
label_vocabulary=label_vocabulary,
|
||||
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
loss_fn=loss_fn,
|
||||
classes_for_class_based_metrics=classes_for_class_based_metrics,
|
||||
name=name)
|
||||
|
||||
def loss(self,
|
||||
labels,
|
||||
logits,
|
||||
features=None,
|
||||
mode=None,
|
||||
regularization_losses=None):
|
||||
"""Returns regularized training loss. See `base_head.Head` for details."""
|
||||
del mode # Unused for this head.
|
||||
with tf.compat.v1.name_scope(
|
||||
'losses', values=(logits, labels, regularization_losses, features)):
|
||||
logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
|
||||
labels = self._processed_labels(logits, labels)
|
||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||
logits, labels, features)
|
||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||
unweighted_loss,
|
||||
sample_weight=weights,
|
||||
reduction=tf.keras.losses.Reduction.NONE)
|
||||
regularization_loss = tf.math.add_n(
|
||||
regularization_losses) if regularization_losses is not None else None
|
||||
vector_regularized_training_loss = (
|
||||
tf.add(vector_training_loss, regularization_loss)
|
||||
if regularization_loss is not None else vector_training_loss)
|
||||
|
||||
return vector_regularized_training_loss
|
||||
|
||||
def _create_tpu_estimator_spec(self,
|
||||
features,
|
||||
mode,
|
||||
logits,
|
||||
labels=None,
|
||||
optimizer=None,
|
||||
trainable_variables=None,
|
||||
train_op_fn=None,
|
||||
update_ops=None,
|
||||
regularization_losses=None):
|
||||
"""See superclass for description."""
|
||||
|
||||
with tf.compat.v1.name_scope(self._name, 'head'):
|
||||
# Predict.
|
||||
pred_keys = prediction_keys.PredictionKeys
|
||||
predictions = self.predictions(logits)
|
||||
if mode == ModeKeys.PREDICT:
|
||||
probabilities = predictions[pred_keys.PROBABILITIES]
|
||||
classifier_output = base_head.classification_output(
|
||||
scores=probabilities,
|
||||
n_classes=self._n_classes,
|
||||
label_vocabulary=self._label_vocabulary)
|
||||
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||
mode=ModeKeys.PREDICT,
|
||||
predictions=predictions,
|
||||
export_outputs={
|
||||
base_head.DEFAULT_SERVING_KEY:
|
||||
classifier_output,
|
||||
base_head.CLASSIFY_SERVING_KEY:
|
||||
classifier_output,
|
||||
base_head.PREDICT_SERVING_KEY:
|
||||
export_output.PredictOutput(predictions)
|
||||
})
|
||||
regularized_training_loss = self.loss(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
features=features,
|
||||
mode=mode,
|
||||
regularization_losses=regularization_losses)
|
||||
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
||||
# Eval.
|
||||
if mode == ModeKeys.EVAL:
|
||||
eval_metrics = self.metrics(regularization_losses=regularization_losses)
|
||||
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||
mode=ModeKeys.EVAL,
|
||||
predictions=predictions,
|
||||
loss=scalar_loss,
|
||||
eval_metrics=base_head.create_eval_metrics_tuple(
|
||||
self.update_metrics, {
|
||||
'eval_metrics': eval_metrics,
|
||||
'features': features,
|
||||
'logits': logits,
|
||||
'labels': labels,
|
||||
'regularization_losses': regularization_losses
|
||||
}))
|
||||
# Train.
|
||||
train_op = base_head.create_estimator_spec_train_op(
|
||||
head_name=self._name,
|
||||
optimizer=optimizer,
|
||||
train_op_fn=train_op_fn,
|
||||
update_ops=update_ops,
|
||||
trainable_variables=trainable_variables,
|
||||
regularized_training_loss=regularized_training_loss,
|
||||
loss_reduction=self._loss_reduction)
|
||||
# Create summary.
|
||||
base_head.create_estimator_spec_summary(
|
||||
regularized_training_loss=scalar_loss,
|
||||
regularization_losses=regularization_losses,
|
||||
summary_key_fn=self._summary_key)
|
||||
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||
mode=ModeKeys.TRAIN,
|
||||
predictions=predictions,
|
||||
loss=scalar_loss,
|
||||
train_op=train_op)
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for DP-enabled binary class heads."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.estimators import multi_label_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
||||
|
||||
class DPMultiLabelHeadTest(tf.test.TestCase):
|
||||
"""Tests for DP-enabled multilabel heads."""
|
||||
|
||||
def testLoss(self):
|
||||
"""Tests loss() returns per-example losses."""
|
||||
|
||||
head = multi_label_head.DPMultiLabelHead(3)
|
||||
features = {'feature_a': np.full((4), 1.0)}
|
||||
labels = np.array([[0, 1, 1], [1, 1, 0], [0, 1, 0], [1, 1, 1]])
|
||||
logits = np.array([[2.0, 1.5, 4.1], [2.0, 1.5, 4.1], [2.0, 1.5, 4.1],
|
||||
[2.0, 1.5, 4.1]])
|
||||
|
||||
actual_loss = head.loss(labels, logits, features)
|
||||
expected_loss = tf.reduce_mean(
|
||||
tf.compat.v1.losses.sigmoid_cross_entropy(
|
||||
multi_class_labels=labels,
|
||||
logits=logits,
|
||||
reduction=tf.keras.losses.Reduction.NONE),
|
||||
axis=-1,
|
||||
keepdims=True)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
self.assertEqual(actual_loss.shape, [4, 1])
|
||||
self.assertAllClose(actual_loss, expected_loss)
|
||||
return
|
||||
|
||||
self.assertEqual(actual_loss.shape, [4, 1])
|
||||
self.assertAllClose(expected_loss, self.evaluate(actual_loss))
|
||||
|
||||
def testCreateTPUEstimatorSpec(self):
|
||||
"""Tests that an Estimator built with this head works."""
|
||||
|
||||
train_features, train_labels = test_utils.make_multilabel_input_data(256)
|
||||
feature_columns = []
|
||||
for key in train_features:
|
||||
feature_columns.append(tf.feature_column.numeric_column(key=key))
|
||||
|
||||
head = multi_label_head.DPMultiLabelHead(3)
|
||||
optimizer = DPKerasSGDOptimizer(
|
||||
learning_rate=0.5,
|
||||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
steps=4)
|
||||
|
||||
test_features, test_labels = test_utils.make_multilabel_input_data(64)
|
||||
classifier.evaluate(
|
||||
input_fn=test_utils.make_input_fn(test_features, test_labels, False),
|
||||
steps=4)
|
||||
|
||||
predict_features, predict_labels = test_utils.make_multilabel_input_data(64)
|
||||
classifier.predict(
|
||||
input_fn=test_utils.make_input_fn(predict_features, predict_labels,
|
||||
False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -54,6 +54,30 @@ def make_input_data(size, classes):
|
|||
return features, labels
|
||||
|
||||
|
||||
def make_multilabel_input_data(size):
|
||||
"""Create raw input data for testing."""
|
||||
feature_a = np.random.normal(4, 1, (size))
|
||||
feature_b = np.random.normal(5, 0.7, (size))
|
||||
feature_c = np.random.normal(6, 2, (size))
|
||||
noise_a = np.random.normal(0, 1, (size))
|
||||
noise_b = np.random.normal(0, 1, (size))
|
||||
noise_c = np.random.normal(0, 1, (size))
|
||||
features = {
|
||||
'feature_a': feature_a,
|
||||
'feature_b': feature_b,
|
||||
'feature_c': feature_c,
|
||||
}
|
||||
|
||||
def label_fn(a, b, c):
|
||||
return [int(a > 4), int(b > 5), int(c > 6)]
|
||||
|
||||
labels = list(
|
||||
map(label_fn, feature_a + noise_a, feature_b + noise_b,
|
||||
feature_c + noise_c))
|
||||
|
||||
return features, labels
|
||||
|
||||
|
||||
def make_input_fn(features, labels, training, batch_size=16):
|
||||
"""Returns an input function suitable for an estimator."""
|
||||
|
||||
|
|
Loading…
Reference in a new issue