Add DP-enabled binary-class head and multi-class heads for Estimator.
PiperOrigin-RevId: 325921076
This commit is contained in:
parent
43a0e4be8a
commit
3a641e077e
7 changed files with 713 additions and 1 deletions
68
tensorflow_privacy/privacy/estimators/BUILD
Normal file
68
tensorflow_privacy/privacy/estimators/BUILD
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "head_utils",
|
||||||
|
srcs = [
|
||||||
|
"head_utils.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":binary_class_head",
|
||||||
|
":multi_class_head",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "binary_class_head",
|
||||||
|
srcs = [
|
||||||
|
"binary_class_head.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
# TODO(b/163395075): Remove this dependency once necessary function is public.
|
||||||
|
"//third_party/tensorflow/python:keras_lib",
|
||||||
|
"//third_party/tensorflow_estimator",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "multi_class_head",
|
||||||
|
srcs = [
|
||||||
|
"multi_class_head.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
# TODO(b/163395075): Remove this dependency once necessary function is public.
|
||||||
|
"//third_party/tensorflow/python:keras_lib",
|
||||||
|
"//third_party/tensorflow_estimator",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "binary_class_head_test",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["binary_class_head_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":binary_class_head",
|
||||||
|
"//third_party/py/absl/testing:parameterized",
|
||||||
|
"//third_party/py/six",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "multi_class_head_test",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["multi_class_head_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":multi_class_head",
|
||||||
|
"//third_party/py/absl/testing:parameterized",
|
||||||
|
"//third_party/py/six",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||||
|
],
|
||||||
|
)
|
147
tensorflow_privacy/privacy/estimators/binary_class_head.py
Normal file
147
tensorflow_privacy/privacy/estimators/binary_class_head.py
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Binary class head for Estimator that allow integration with TF Privacy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||||
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
|
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
|
from tensorflow_estimator.python.estimator.head import base_head
|
||||||
|
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||||
|
|
||||||
|
|
||||||
|
class DPBinaryClassHead(tf.estimator.BinaryClassHead):
|
||||||
|
"""Creates a TF Privacy-enabled version of BinaryClassHead."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
weight_column=None,
|
||||||
|
thresholds=None,
|
||||||
|
label_vocabulary=None,
|
||||||
|
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||||
|
loss_fn=None,
|
||||||
|
name=None):
|
||||||
|
super(DPBinaryClassHead, self).__init__(
|
||||||
|
weight_column=weight_column,
|
||||||
|
thresholds=thresholds,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def loss(self,
|
||||||
|
labels,
|
||||||
|
logits,
|
||||||
|
features=None,
|
||||||
|
mode=None,
|
||||||
|
regularization_losses=None):
|
||||||
|
"""Returns regularized training loss. See `base_head.Head` for details."""
|
||||||
|
del mode # Unused for this head.
|
||||||
|
with tf.compat.v1.name_scope(
|
||||||
|
'losses', values=(logits, labels, regularization_losses, features)):
|
||||||
|
logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
|
||||||
|
labels = self._processed_labels(logits, labels)
|
||||||
|
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||||
|
logits, labels, features)
|
||||||
|
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||||
|
unweighted_loss,
|
||||||
|
sample_weight=weights,
|
||||||
|
reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
regularization_loss = tf.math.add_n(
|
||||||
|
regularization_losses) if regularization_losses is not None else None
|
||||||
|
vector_regularized_training_loss = (
|
||||||
|
tf.add(vector_training_loss, regularization_loss)
|
||||||
|
if regularization_loss is not None else vector_training_loss)
|
||||||
|
|
||||||
|
return vector_regularized_training_loss
|
||||||
|
|
||||||
|
def _create_tpu_estimator_spec(self,
|
||||||
|
features,
|
||||||
|
mode,
|
||||||
|
logits,
|
||||||
|
labels=None,
|
||||||
|
optimizer=None,
|
||||||
|
trainable_variables=None,
|
||||||
|
train_op_fn=None,
|
||||||
|
update_ops=None,
|
||||||
|
regularization_losses=None):
|
||||||
|
"""See superclass for description."""
|
||||||
|
|
||||||
|
with tf.compat.v1.name_scope(self._name, 'head'):
|
||||||
|
# Predict.
|
||||||
|
pred_keys = prediction_keys.PredictionKeys
|
||||||
|
predictions = self.predictions(logits)
|
||||||
|
if mode == ModeKeys.PREDICT:
|
||||||
|
probabilities = predictions[pred_keys.PROBABILITIES]
|
||||||
|
logistic = predictions[pred_keys.LOGISTIC]
|
||||||
|
classifier_output = base_head.classification_output(
|
||||||
|
scores=probabilities,
|
||||||
|
n_classes=2,
|
||||||
|
label_vocabulary=self._label_vocabulary)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.PREDICT,
|
||||||
|
predictions=predictions,
|
||||||
|
export_outputs={
|
||||||
|
base_head.DEFAULT_SERVING_KEY: classifier_output,
|
||||||
|
base_head.CLASSIFY_SERVING_KEY: classifier_output,
|
||||||
|
base_head.REGRESS_SERVING_KEY:
|
||||||
|
export_output.RegressionOutput(value=logistic),
|
||||||
|
base_head.PREDICT_SERVING_KEY:
|
||||||
|
export_output.PredictOutput(predictions)
|
||||||
|
})
|
||||||
|
regularized_training_loss = self.loss(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
features=features,
|
||||||
|
mode=mode,
|
||||||
|
regularization_losses=regularization_losses)
|
||||||
|
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
||||||
|
# Eval.
|
||||||
|
if mode == ModeKeys.EVAL:
|
||||||
|
eval_metrics = self.metrics(regularization_losses=regularization_losses)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.EVAL,
|
||||||
|
predictions=predictions,
|
||||||
|
loss=scalar_loss,
|
||||||
|
eval_metrics=base_head.create_eval_metrics_tuple(
|
||||||
|
self.update_metrics, {
|
||||||
|
'eval_metrics': eval_metrics,
|
||||||
|
'features': features,
|
||||||
|
'logits': logits,
|
||||||
|
'labels': labels,
|
||||||
|
'regularization_losses': regularization_losses
|
||||||
|
}))
|
||||||
|
# Train.
|
||||||
|
train_op = base_head.create_estimator_spec_train_op(
|
||||||
|
head_name=self._name,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_op_fn=train_op_fn,
|
||||||
|
update_ops=update_ops,
|
||||||
|
trainable_variables=trainable_variables,
|
||||||
|
regularized_training_loss=regularized_training_loss,
|
||||||
|
loss_reduction=self._loss_reduction)
|
||||||
|
# Create summary.
|
||||||
|
base_head.create_estimator_spec_summary(
|
||||||
|
regularized_training_loss=scalar_loss,
|
||||||
|
regularization_losses=regularization_losses,
|
||||||
|
summary_key_fn=self._summary_key)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.TRAIN,
|
||||||
|
predictions=predictions,
|
||||||
|
loss=scalar_loss,
|
||||||
|
train_op=train_op)
|
136
tensorflow_privacy/privacy/estimators/binary_class_head_test.py
Normal file
136
tensorflow_privacy/privacy/estimators/binary_class_head_test.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for DP-enabled binary class heads."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.estimators import binary_class_head
|
||||||
|
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||||
|
|
||||||
|
|
||||||
|
class DPBinaryClassHeadTest(tf.test.TestCase):
|
||||||
|
"""Tests for DP-enabled heads."""
|
||||||
|
|
||||||
|
def _make_input_data(self, size):
|
||||||
|
"""Create raw input data."""
|
||||||
|
feature_a = np.random.normal(4, 1, (size))
|
||||||
|
feature_b = np.random.normal(5, 0.7, (size))
|
||||||
|
feature_c = np.random.normal(6, 2, (size))
|
||||||
|
noise = np.random.normal(0, 30, (size))
|
||||||
|
features = {
|
||||||
|
'feature_a': feature_a,
|
||||||
|
'feature_b': feature_b,
|
||||||
|
'feature_c': feature_c,
|
||||||
|
}
|
||||||
|
labels = np.array(
|
||||||
|
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||||
|
np.power(feature_c, 1) + noise > 125).astype(int)
|
||||||
|
return features, labels
|
||||||
|
|
||||||
|
def _make_input_fn(self, features, labels, training, batch_size=16):
|
||||||
|
|
||||||
|
def input_fn():
|
||||||
|
"""An input function for training or evaluating."""
|
||||||
|
# Convert the inputs to a Dataset.
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||||
|
|
||||||
|
# Shuffle if in training mode.
|
||||||
|
if training:
|
||||||
|
dataset = dataset.shuffle(1000)
|
||||||
|
|
||||||
|
return dataset.batch(batch_size)
|
||||||
|
|
||||||
|
return input_fn
|
||||||
|
|
||||||
|
def _make_model_fn(self, head, optimizer, feature_columns):
|
||||||
|
"""Constructs and returns a model_fn using DPBinaryClassHead."""
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
|
||||||
|
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
|
||||||
|
inputs = feature_layer(features)
|
||||||
|
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
|
||||||
|
hidden_layer_values = hidden_layer(inputs)
|
||||||
|
logits_layer = tf.keras.layers.Dense(
|
||||||
|
units=head.logits_dimension, activation=None)
|
||||||
|
logits = logits_layer(hidden_layer_values)
|
||||||
|
return head.create_estimator_spec(
|
||||||
|
features=features,
|
||||||
|
labels=labels,
|
||||||
|
mode=mode,
|
||||||
|
logits=logits,
|
||||||
|
trainable_variables=hidden_layer.trainable_weights +
|
||||||
|
logits_layer.trainable_weights,
|
||||||
|
optimizer=optimizer)
|
||||||
|
|
||||||
|
return model_fn
|
||||||
|
|
||||||
|
def testLoss(self):
|
||||||
|
"""Tests loss() returns per-example losses."""
|
||||||
|
|
||||||
|
head = binary_class_head.DPBinaryClassHead()
|
||||||
|
features = {'feature_a': np.full((4), 1.0)}
|
||||||
|
labels = np.array([[1.0], [1.0], [1.0], [0.0]])
|
||||||
|
logits = np.full((4, 1), 0.5)
|
||||||
|
|
||||||
|
actual_loss = head.loss(labels, logits, features)
|
||||||
|
expected_loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||||
|
labels=labels, logits=logits)
|
||||||
|
|
||||||
|
self.assertEqual(actual_loss.shape, [4, 1])
|
||||||
|
|
||||||
|
if tf.executing_eagerly():
|
||||||
|
self.assertEqual(actual_loss.shape, [4, 1])
|
||||||
|
self.assertAllClose(actual_loss, expected_loss)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.assertAllClose(expected_loss, self.evaluate(actual_loss))
|
||||||
|
|
||||||
|
def testCreateTPUEstimatorSpec(self):
|
||||||
|
"""Tests that an Estimator built with this head works."""
|
||||||
|
|
||||||
|
train_features, train_labels = self._make_input_data(256)
|
||||||
|
feature_columns = []
|
||||||
|
for key in train_features:
|
||||||
|
feature_columns.append(tf.feature_column.numeric_column(key=key))
|
||||||
|
|
||||||
|
head = binary_class_head.DPBinaryClassHead()
|
||||||
|
optimizer = DPKerasSGDOptimizer(
|
||||||
|
learning_rate=0.5,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=2)
|
||||||
|
model_fn = self._make_model_fn(head, optimizer, feature_columns)
|
||||||
|
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||||
|
|
||||||
|
classifier.train(
|
||||||
|
input_fn=self._make_input_fn(train_features, train_labels, True),
|
||||||
|
steps=4)
|
||||||
|
|
||||||
|
test_features, test_labels = self._make_input_data(64)
|
||||||
|
classifier.evaluate(
|
||||||
|
input_fn=self._make_input_fn(test_features, test_labels, False),
|
||||||
|
steps=4)
|
||||||
|
|
||||||
|
predict_features, predict_labels_ = self._make_input_data(64)
|
||||||
|
classifier.predict(
|
||||||
|
input_fn=self._make_input_fn(predict_features, predict_labels_, False))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
60
tensorflow_privacy/privacy/estimators/head_utils.py
Normal file
60
tensorflow_privacy/privacy/estimators/head_utils.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Estimator heads that allow integration with TF Privacy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.estimators.binary_class_head import DPBinaryClassHead
|
||||||
|
from tensorflow_privacy.privacy.estimators.multi_class_head import DPMultiClassHead
|
||||||
|
|
||||||
|
|
||||||
|
def binary_or_multi_class_head(n_classes, weight_column, label_vocabulary,
|
||||||
|
loss_reduction):
|
||||||
|
"""Creates either binary or multi-class head.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes: Number of label classes.
|
||||||
|
weight_column: A string or a `NumericColumn` created by
|
||||||
|
`tf.feature_column.numeric_column` defining feature column representing
|
||||||
|
weights. It is used to down weight or boost examples during training. It
|
||||||
|
will be multiplied by the loss of the example. If it is a string, it is
|
||||||
|
used as a key to fetch weight tensor from the `features`. If it is a
|
||||||
|
`NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
|
||||||
|
weight_column.normalizer_fn is applied on it to get weight tensor.
|
||||||
|
label_vocabulary: A list of strings represents possible label values. If
|
||||||
|
given, labels must be string type and have any value in
|
||||||
|
`label_vocabulary`. If it is not given, that means labels are already
|
||||||
|
encoded as integer or float within [0, 1] for `n_classes=2` and encoded as
|
||||||
|
integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there
|
||||||
|
will be errors if vocabulary is not provided and labels are string.
|
||||||
|
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Defines how to
|
||||||
|
reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Head` instance.
|
||||||
|
"""
|
||||||
|
if n_classes == 2:
|
||||||
|
head = DPBinaryClassHead(
|
||||||
|
weight_column=weight_column,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
loss_reduction=loss_reduction)
|
||||||
|
else:
|
||||||
|
head = DPMultiClassHead(
|
||||||
|
n_classes,
|
||||||
|
weight_column=weight_column,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
loss_reduction=loss_reduction)
|
||||||
|
return head
|
146
tensorflow_privacy/privacy/estimators/multi_class_head.py
Normal file
146
tensorflow_privacy/privacy/estimators/multi_class_head.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||||
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
|
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
|
from tensorflow_estimator.python.estimator.head import base_head
|
||||||
|
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||||
|
|
||||||
|
|
||||||
|
class DPMultiClassHead(tf.estimator.MultiClassHead):
|
||||||
|
"""Creates a TF Privacy-enabled version of MultiClassHead."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_classes,
|
||||||
|
weight_column=None,
|
||||||
|
label_vocabulary=None,
|
||||||
|
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
loss_fn=None,
|
||||||
|
name=None):
|
||||||
|
super(DPMultiClassHead, self).__init__(
|
||||||
|
n_classes=n_classes,
|
||||||
|
weight_column=weight_column,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def loss(self,
|
||||||
|
labels,
|
||||||
|
logits,
|
||||||
|
features=None,
|
||||||
|
mode=None,
|
||||||
|
regularization_losses=None):
|
||||||
|
"""Returns regularized training loss. See `base_head.Head` for details."""
|
||||||
|
del mode # Unused for this head.
|
||||||
|
with tf.compat.v1.name_scope(
|
||||||
|
'losses', values=(logits, labels, regularization_losses, features)):
|
||||||
|
logits = base_head.check_logits_final_dim(logits, self.logits_dimension)
|
||||||
|
labels = self._processed_labels(logits, labels)
|
||||||
|
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||||
|
logits, labels, features)
|
||||||
|
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||||
|
unweighted_loss,
|
||||||
|
sample_weight=weights,
|
||||||
|
reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
regularization_loss = tf.math.add_n(
|
||||||
|
regularization_losses) if regularization_losses is not None else None
|
||||||
|
vector_regularized_training_loss = (
|
||||||
|
tf.add(vector_training_loss, regularization_loss)
|
||||||
|
if regularization_loss is not None else vector_training_loss)
|
||||||
|
|
||||||
|
return vector_regularized_training_loss
|
||||||
|
|
||||||
|
def _create_tpu_estimator_spec(self,
|
||||||
|
features,
|
||||||
|
mode,
|
||||||
|
logits,
|
||||||
|
labels=None,
|
||||||
|
optimizer=None,
|
||||||
|
trainable_variables=None,
|
||||||
|
train_op_fn=None,
|
||||||
|
update_ops=None,
|
||||||
|
regularization_losses=None):
|
||||||
|
"""See superclass for description."""
|
||||||
|
|
||||||
|
with tf.compat.v1.name_scope(self._name, 'head'):
|
||||||
|
# Predict.
|
||||||
|
pred_keys = prediction_keys.PredictionKeys
|
||||||
|
predictions = self.predictions(logits)
|
||||||
|
if mode == ModeKeys.PREDICT:
|
||||||
|
probabilities = predictions[pred_keys.PROBABILITIES]
|
||||||
|
classifier_output = base_head.classification_output(
|
||||||
|
scores=probabilities,
|
||||||
|
n_classes=self._n_classes,
|
||||||
|
label_vocabulary=self._label_vocabulary)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.PREDICT,
|
||||||
|
predictions=predictions,
|
||||||
|
export_outputs={
|
||||||
|
base_head.DEFAULT_SERVING_KEY:
|
||||||
|
classifier_output,
|
||||||
|
base_head.CLASSIFY_SERVING_KEY:
|
||||||
|
classifier_output,
|
||||||
|
base_head.PREDICT_SERVING_KEY:
|
||||||
|
export_output.PredictOutput(predictions)
|
||||||
|
})
|
||||||
|
regularized_training_loss = self.loss(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
features=features,
|
||||||
|
mode=mode,
|
||||||
|
regularization_losses=regularization_losses)
|
||||||
|
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
||||||
|
# Eval.
|
||||||
|
if mode == ModeKeys.EVAL:
|
||||||
|
eval_metrics = self.metrics(regularization_losses=regularization_losses)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.EVAL,
|
||||||
|
predictions=predictions,
|
||||||
|
loss=scalar_loss,
|
||||||
|
eval_metrics=base_head.create_eval_metrics_tuple(
|
||||||
|
self.update_metrics, {
|
||||||
|
'eval_metrics': eval_metrics,
|
||||||
|
'features': features,
|
||||||
|
'logits': logits,
|
||||||
|
'labels': labels,
|
||||||
|
'regularization_losses': regularization_losses
|
||||||
|
}))
|
||||||
|
# Train.
|
||||||
|
train_op = base_head.create_estimator_spec_train_op(
|
||||||
|
head_name=self._name,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_op_fn=train_op_fn,
|
||||||
|
update_ops=update_ops,
|
||||||
|
trainable_variables=trainable_variables,
|
||||||
|
regularized_training_loss=regularized_training_loss,
|
||||||
|
loss_reduction=self._loss_reduction)
|
||||||
|
# Create summary.
|
||||||
|
base_head.create_estimator_spec_summary(
|
||||||
|
regularized_training_loss=scalar_loss,
|
||||||
|
regularization_losses=regularization_losses,
|
||||||
|
summary_key_fn=self._summary_key)
|
||||||
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
|
mode=ModeKeys.TRAIN,
|
||||||
|
predictions=predictions,
|
||||||
|
loss=scalar_loss,
|
||||||
|
train_op=train_op)
|
152
tensorflow_privacy/privacy/estimators/multi_class_head_test.py
Normal file
152
tensorflow_privacy/privacy/estimators/multi_class_head_test.py
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for DP-enabled binary class heads."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.estimators import multi_class_head
|
||||||
|
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||||
|
|
||||||
|
|
||||||
|
class DPMultiClassHeadTest(tf.test.TestCase):
|
||||||
|
"""Tests for DP-enabled heads."""
|
||||||
|
|
||||||
|
def _make_input_data(self, size):
|
||||||
|
"""Create raw input data."""
|
||||||
|
feature_a = np.random.normal(4, 1, (size))
|
||||||
|
feature_b = np.random.normal(5, 0.7, (size))
|
||||||
|
feature_c = np.random.normal(6, 2, (size))
|
||||||
|
noise = np.random.normal(0, 30, (size))
|
||||||
|
features = {
|
||||||
|
'feature_a': feature_a,
|
||||||
|
'feature_b': feature_b,
|
||||||
|
'feature_c': feature_c,
|
||||||
|
}
|
||||||
|
|
||||||
|
def label_fn(x):
|
||||||
|
if x < 110.0:
|
||||||
|
return 0
|
||||||
|
elif x < 140.0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return 2
|
||||||
|
|
||||||
|
labels_list = map(
|
||||||
|
label_fn,
|
||||||
|
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||||
|
np.power(feature_c, 1) + noise)
|
||||||
|
return features, list(labels_list)
|
||||||
|
|
||||||
|
def _make_input_fn(self, features, labels, training, batch_size=16):
|
||||||
|
|
||||||
|
def input_fn():
|
||||||
|
"""An input function for training or evaluating."""
|
||||||
|
# Convert the inputs to a Dataset.
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||||
|
|
||||||
|
# Shuffle if in training mode.
|
||||||
|
if training:
|
||||||
|
dataset = dataset.shuffle(1000)
|
||||||
|
|
||||||
|
return dataset.batch(batch_size)
|
||||||
|
|
||||||
|
return input_fn
|
||||||
|
|
||||||
|
def _make_model_fn(self, head, optimizer, feature_columns):
|
||||||
|
"""Constructs and returns a model_fn using DPBinaryClassHead."""
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
|
||||||
|
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
|
||||||
|
inputs = feature_layer(features)
|
||||||
|
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
|
||||||
|
hidden_layer_values = hidden_layer(inputs)
|
||||||
|
logits_layer = tf.keras.layers.Dense(
|
||||||
|
units=head.logits_dimension, activation=None)
|
||||||
|
logits = logits_layer(hidden_layer_values)
|
||||||
|
return head.create_estimator_spec(
|
||||||
|
features=features,
|
||||||
|
labels=labels,
|
||||||
|
mode=mode,
|
||||||
|
logits=logits,
|
||||||
|
trainable_variables=hidden_layer.trainable_weights +
|
||||||
|
logits_layer.trainable_weights,
|
||||||
|
optimizer=optimizer)
|
||||||
|
|
||||||
|
return model_fn
|
||||||
|
|
||||||
|
def testLoss(self):
|
||||||
|
"""Tests loss() returns per-example losses."""
|
||||||
|
|
||||||
|
head = multi_class_head.DPMultiClassHead(3)
|
||||||
|
features = {'feature_a': np.full((4), 1.0)}
|
||||||
|
labels = np.array([[2], [1], [1], [0]])
|
||||||
|
logits = np.array([[2.0, 1.5, 4.1], [2.0, 1.5, 4.1], [2.0, 1.5, 4.1],
|
||||||
|
[2.0, 1.5, 4.1]])
|
||||||
|
|
||||||
|
actual_loss = head.loss(labels, logits, features)
|
||||||
|
expected_loss = tf.expand_dims(
|
||||||
|
tf.compat.v1.losses.sparse_softmax_cross_entropy(
|
||||||
|
labels=labels,
|
||||||
|
logits=logits,
|
||||||
|
reduction=tf.keras.losses.Reduction.NONE), -1)
|
||||||
|
|
||||||
|
self.assertEqual(actual_loss.shape, [4, 1])
|
||||||
|
|
||||||
|
if tf.executing_eagerly():
|
||||||
|
self.assertEqual(actual_loss.shape, [4, 1])
|
||||||
|
self.assertAllClose(actual_loss, expected_loss)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.assertAllClose(expected_loss, self.evaluate(actual_loss))
|
||||||
|
|
||||||
|
def testCreateTPUEstimatorSpec(self):
|
||||||
|
"""Tests that an Estimator built with this head works."""
|
||||||
|
|
||||||
|
train_features, train_labels = self._make_input_data(256)
|
||||||
|
feature_columns = []
|
||||||
|
for key in train_features:
|
||||||
|
feature_columns.append(tf.feature_column.numeric_column(key=key))
|
||||||
|
|
||||||
|
head = multi_class_head.DPMultiClassHead(3)
|
||||||
|
optimizer = DPKerasSGDOptimizer(
|
||||||
|
learning_rate=0.5,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=2)
|
||||||
|
model_fn = self._make_model_fn(head, optimizer, feature_columns)
|
||||||
|
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||||
|
|
||||||
|
classifier.train(
|
||||||
|
input_fn=self._make_input_fn(train_features, train_labels, True),
|
||||||
|
steps=4)
|
||||||
|
|
||||||
|
test_features, test_labels = self._make_input_data(64)
|
||||||
|
classifier.evaluate(
|
||||||
|
input_fn=self._make_input_fn(test_features, test_labels, False),
|
||||||
|
steps=4)
|
||||||
|
|
||||||
|
predict_features, predict_labels_ = self._make_input_data(64)
|
||||||
|
predictions = classifier.predict(
|
||||||
|
input_fn=self._make_input_fn(predict_features, predict_labels_, False))
|
||||||
|
for p in predictions:
|
||||||
|
print('schien p: ', p)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -60,7 +60,7 @@ def make_keras_optimizer_class(cls):
|
||||||
self._num_microbatches = num_microbatches
|
self._num_microbatches = num_microbatches
|
||||||
self._dp_sum_query = gaussian_query.GaussianSumQuery(
|
self._dp_sum_query = gaussian_query.GaussianSumQuery(
|
||||||
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
||||||
self._global_state = self._dp_sum_query.initial_global_state()
|
self._global_state = None
|
||||||
|
|
||||||
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
||||||
"""DP version of superclass method."""
|
"""DP version of superclass method."""
|
||||||
|
@ -119,6 +119,9 @@ def make_keras_optimizer_class(cls):
|
||||||
def get_gradients(self, loss, params):
|
def get_gradients(self, loss, params):
|
||||||
"""DP version of superclass method."""
|
"""DP version of superclass method."""
|
||||||
|
|
||||||
|
if self._global_state is None:
|
||||||
|
self._global_state = self._dp_sum_query.initial_global_state()
|
||||||
|
|
||||||
# This code mostly follows the logic in the original DPOptimizerClass
|
# This code mostly follows the logic in the original DPOptimizerClass
|
||||||
# in dp_optimizer.py, except that this returns only the gradients,
|
# in dp_optimizer.py, except that this returns only the gradients,
|
||||||
# not the gradients and variables.
|
# not the gradients and variables.
|
||||||
|
|
Loading…
Reference in a new issue