From d703168de2db15936e71b8c7de152b649d3b337b Mon Sep 17 00:00:00 2001
From: Steve Chien
Date: Wed, 7 Oct 2020 14:31:31 -0700
Subject: [PATCH] Add TF1-compatible version of DP canned estimators, and some
 small cleanup.

PiperOrigin-RevId: 335954269
---
 .../privacy/estimators/dnn_test.py     |   2 +-
 .../privacy/estimators/v1/BUILD        |  52 ++
 .../privacy/estimators/v1/dnn.py       |  68 +++
 .../privacy/estimators/v1/dnn_test.py  |  72 +++
 .../privacy/estimators/v1/head.py      | 458 ++++++++++++++++++
 .../privacy/estimators/v1/head_test.py |  95 ++++
 .../privacy/optimizers/dp_optimizer.py |   5 +-
 7 files changed, 750 insertions(+), 2 deletions(-)
 create mode 100644 tensorflow_privacy/privacy/estimators/v1/BUILD
 create mode 100644 tensorflow_privacy/privacy/estimators/v1/dnn.py
 create mode 100644 tensorflow_privacy/privacy/estimators/v1/dnn_test.py
 create mode 100644 tensorflow_privacy/privacy/estimators/v1/head.py
 create mode 100644 tensorflow_privacy/privacy/estimators/v1/head_test.py

diff --git a/tensorflow_privacy/privacy/estimators/dnn_test.py b/tensorflow_privacy/privacy/estimators/dnn_test.py
index bb7c1ad..19b994e 100644
--- a/tensorflow_privacy/privacy/estimators/dnn_test.py
+++ b/tensorflow_privacy/privacy/estimators/dnn_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for DP-enabled binary class heads."""
+"""Tests for DP-enabled DNNClassifier."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow_privacy/privacy/estimators/v1/BUILD b/tensorflow_privacy/privacy/estimators/v1/BUILD
new file mode 100644
index 0000000..dda365e
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/BUILD
@@ -0,0 +1,52 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+py_library(
+    name = "head",
+    srcs = [
+        "head.py",
+    ],
+    deps = [
+        "//third_party/py/tensorflow",
+        "//third_party/tensorflow_estimator",
+    ],
+)
+
+py_library(
+    name = "dnn",
+    srcs = [
+        "dnn.py",
+    ],
+    deps = [
+        ":head",
+        "//third_party/py/tensorflow",
+        "//third_party/tensorflow_estimator",
+    ],
+)
+
+py_test(
+    name = "head_test",
+    timeout = "long",
+    srcs = ["head_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":head",
+        "//third_party/py/tensorflow",
+        "//third_party/py/tensorflow_privacy/privacy/estimators:test_utils",
+        "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer",
+    ],
+)
+
+py_test(
+    name = "dnn_test",
+    timeout = "long",
+    srcs = ["dnn_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":dnn",
+        "//third_party/py/tensorflow",
+        "//third_party/py/tensorflow_privacy/privacy/estimators:test_utils",
+        "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer",
+    ],
+)
diff --git a/tensorflow_privacy/privacy/estimators/v1/dnn.py b/tensorflow_privacy/privacy/estimators/v1/dnn.py
new file mode 100644
index 0000000..8ad1b94
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/dnn.py
@@ -0,0 +1,68 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DP version of DNNClassifier v1."""
+
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.estimators.v1 import head as head_lib
+from tensorflow_estimator.python.estimator import estimator
+from tensorflow_estimator.python.estimator.canned import dnn
+
+
+class DNNClassifier(tf.estimator.Estimator):
+  """DP version of tf.estimator.DNNClassifier."""
+
+  def __init__(
+      self,
+      hidden_units,
+      feature_columns,
+      model_dir=None,
+      n_classes=2,
+      weight_column=None,
+      label_vocabulary=None,
+      optimizer='Adagrad',
+      activation_fn=tf.nn.relu,
+      dropout=None,
+      input_layer_partitioner=None,
+      config=None,
+      warm_start_from=None,
+      loss_reduction=tf.compat.v1.losses.Reduction.SUM,
+      batch_norm=False,
+  ):
+    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
+        n_classes, weight_column, label_vocabulary, loss_reduction)
+    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')
+
+    def _model_fn(features, labels, mode, config):
+      """Call the defined shared dnn_model_fn."""
+      return dnn._dnn_model_fn(  # pylint: disable=protected-access
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          hidden_units=hidden_units,
+          feature_columns=tuple(feature_columns or []),
+          optimizer=optimizer,
+          activation_fn=activation_fn,
+          dropout=dropout,
+          input_layer_partitioner=input_layer_partitioner,
+          config=config,
+          batch_norm=batch_norm)
+
+    super(DNNClassifier, self).__init__(
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        warm_start_from=warm_start_from)
diff --git a/tensorflow_privacy/privacy/estimators/v1/dnn_test.py b/tensorflow_privacy/privacy/estimators/v1/dnn_test.py
new file mode 100644
index 0000000..68a856e
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/dnn_test.py
@@ -0,0 +1,72 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for DP-enabled DNNClassifier.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.estimators import test_utils +from tensorflow_privacy.privacy.estimators.v1 import dnn +from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer + + +class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase): + """Tests for DP-enabled DNNClassifier.""" + + @parameterized.named_parameters( + ('BinaryClassDNN', 2), + ('MultiClassDNN 3', 3), + ('MultiClassDNN 4', 4), + ) + def testDNN(self, n_classes): + train_features, train_labels = test_utils.make_input_data(256, n_classes) + feature_columns = [] + for key in train_features: + feature_columns.append(tf.feature_column.numeric_column(key=key)) + + optimizer = functools.partial( + DPGradientDescentGaussianOptimizer, + learning_rate=0.5, + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=1) + + classifier = dnn.DNNClassifier( + hidden_units=[10], + activation_fn='relu', + feature_columns=feature_columns, + n_classes=n_classes, + optimizer=optimizer, + loss_reduction=tf.losses.Reduction.NONE) + + classifier.train( + input_fn=test_utils.make_input_fn(train_features, train_labels, True, + 16)) + + test_features, test_labels = test_utils.make_input_data(64, n_classes) + classifier.evaluate( + input_fn=test_utils.make_input_fn(test_features, test_labels, False, + 16)) + + predict_features, predict_labels = test_utils.make_input_data(64, n_classes) + classifier.predict( + input_fn=test_utils.make_input_fn(predict_features, predict_labels, + False)) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/estimators/v1/head.py b/tensorflow_privacy/privacy/estimators/v1/head.py new file mode 100644 index 0000000..994ca6b --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/v1/head.py @@ -0,0 +1,458 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Estimator v1 heads that allow integration with TF Privacy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.ops import lookup_ops # pylint: disable=g-direct-tensorflow-import +from tensorflow_estimator.python.estimator import model_fn +from tensorflow_estimator.python.estimator.canned import head as head_lib +from tensorflow_estimator.python.estimator.canned import metric_keys +from tensorflow_estimator.python.estimator.canned import prediction_keys +from tensorflow_estimator.python.estimator.export import export_output +from tensorflow_estimator.python.estimator.mode_keys import ModeKeys + +# Collect together all protected access items needed from base head. 
+# pylint: disable=protected-access
+_DEFAULT_SERVING_KEY = head_lib._DEFAULT_SERVING_KEY
+_CLASSIFY_SERVING_KEY = head_lib._CLASSIFY_SERVING_KEY
+_REGRESS_SERVING_KEY = head_lib._REGRESS_SERVING_KEY
+_PREDICT_SERVING_KEY = head_lib._PREDICT_SERVING_KEY
+
+_all_class_ids = head_lib._all_class_ids
+_all_classes = head_lib._all_classes
+_append_update_ops = head_lib._append_update_ops
+_check_logits_final_dim = head_lib._check_logits_final_dim
+_classification_output = head_lib._classification_output
+_create_eval_metrics_tuple = head_lib._create_eval_metrics_tuple
+_summary_key = head_lib._summary_key
+_validate_loss_fn_args = head_lib._validate_loss_fn_args
+
+_BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss = head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss
+_BaseMultiClassHeadWithSoftmaxCrossEntropyLoss = head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss
+# pylint: enable=protected-access
+
+
+def _multi_class_head_with_softmax_cross_entropy_loss(
+    n_classes,
+    weight_column=None,
+    label_vocabulary=None,
+    loss_reduction=tf.compat.v1.losses.Reduction.SUM,
+    loss_fn=None,
+    name=None):
+  """See `tensorflow_estimator/python/estimator/canned/head.py`."""
+
+  if label_vocabulary is not None and not isinstance(label_vocabulary,
+                                                     (list, tuple)):
+    raise ValueError(
+        'label_vocabulary should be a list or a tuple. Given type: {}'.format(
+            type(label_vocabulary)))
+  if loss_reduction not in tf.compat.v1.losses.Reduction.all():
+    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+  if loss_fn:
+    _validate_loss_fn_args(loss_fn)
+  return _MultiClassHeadWithSoftmaxCrossEntropyLoss(
+      n_classes=n_classes,
+      weight_column=weight_column,
+      label_vocabulary=label_vocabulary,
+      loss_reduction=loss_reduction,
+      loss_fn=loss_fn,
+      name=name)
+
+
+class _MultiClassHeadWithSoftmaxCrossEntropyLoss(
+    _BaseMultiClassHeadWithSoftmaxCrossEntropyLoss):
+  """See `_multi_class_head_with_softmax_cross_entropy_loss`."""
+
+  def _create_tpu_estimator_spec(self,
+                                 features,
+                                 mode,
+                                 logits,
+                                 labels=None,
+                                 optimizer=None,
+                                 train_op_fn=None,
+                                 regularization_losses=None):
+    """Returns a `model_fn._TPUEstimatorSpec`.
+
+    Args:
+      features: Input `dict` of `Tensor` or `SparseTensor` objects.
+      mode: Estimator's `ModeKeys`.
+      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
+        For many applications, the shape is `[batch_size, logits_dimension]`.
+      labels: Labels integer or string `Tensor` with shape matching `logits`,
+        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is a
+        required argument when `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
+      train_op_fn: Function that takes a scalar loss `Tensor` and returns
+        `train_op`. Used if `optimizer` is `None`.
+      regularization_losses: A list of additional scalar losses to be added to
+        the training loss, such as regularization losses. These losses are
+        usually expressed as a batch average, so for best results users need
+        to set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to
+        avoid scaling errors.
+
+    Returns:
+      A `model_fn._TPUEstimatorSpec` instance.
+
+    Raises:
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
+    """
+    with tf.compat.v1.name_scope(self._name, 'head'):
+      logits = _check_logits_final_dim(logits, self.logits_dimension)
+
+      # Predict.
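+      # The prediction branch matches the base head: class ids are the argmax
+      # over logits, probabilities come from a softmax, and the optional
+      # label vocabulary maps class ids back to string labels.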
+      pred_keys = prediction_keys.PredictionKeys
+      with tf.compat.v1.name_scope(None, 'predictions', (logits,)):
+        all_class_ids = _all_class_ids(logits, self._n_classes)
+        all_classes = _all_classes(
+            logits, self._n_classes, label_vocabulary=self._label_vocabulary)
+        # class_ids's shape is [D0, D1, ... DN].
+        class_ids = tf.compat.v1.math.argmax(
+            logits, axis=-1, name=pred_keys.CLASS_IDS)
+        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)
+        if self._label_vocabulary:
+          table = lookup_ops.index_to_string_table_from_tensor(
+              vocabulary_list=self._label_vocabulary,
+              name='class_string_lookup')
+          classes = table.lookup(class_ids)
+        else:
+          classes = tf.strings.as_string(class_ids, name='str_classes')
+
+        probabilities = tf.compat.v1.nn.softmax(
+            logits, name=pred_keys.PROBABILITIES)
+        predictions = {
+            pred_keys.LOGITS: logits,
+            pred_keys.PROBABILITIES: probabilities,
+            # Expand to [batch_size, 1]
+            pred_keys.CLASS_IDS: class_ids,
+            pred_keys.CLASSES: classes,
+            pred_keys.ALL_CLASS_IDS: all_class_ids,
+            pred_keys.ALL_CLASSES: all_classes,
+        }
+      if mode == ModeKeys.PREDICT:
+        classifier_output = _classification_output(
+            scores=probabilities,
+            n_classes=self._n_classes,
+            label_vocabulary=self._label_vocabulary)
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+            mode=ModeKeys.PREDICT,
+            predictions=predictions,
+            export_outputs={
+                _DEFAULT_SERVING_KEY: classifier_output,
+                _CLASSIFY_SERVING_KEY: classifier_output,
+                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
+            })
+
+      training_loss, unreduced_loss, weights, label_ids = self.create_loss(
+          features=features, mode=mode, logits=logits, labels=labels)
+      if regularization_losses:
+        regularization_loss = tf.math.add_n(regularization_losses)
+        regularized_training_loss = tf.math.add_n(
+            [training_loss, regularization_loss])
+      else:
+        regularization_loss = None
+        regularized_training_loss = training_loss
+
+      if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
+        scalar_loss = tf.reduce_mean(regularized_training_loss)
+      else:
+        scalar_loss = regularized_training_loss
+
+      # Eval.
+      if mode == ModeKeys.EVAL:
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+            mode=ModeKeys.EVAL,
+            predictions=predictions,
+            loss=scalar_loss,
+            eval_metrics=_create_eval_metrics_tuple(
+                self._eval_metric_ops, {
+                    'labels': label_ids,
+                    'class_ids': class_ids,
+                    'weights': weights,
+                    'unreduced_loss': unreduced_loss,
+                    'regularization_loss': regularization_loss
+                }))
+
+      # Train.
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=tf.compat.v1.train.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
+      train_op = _append_update_ops(train_op)
+      # Only summarize mean_loss for SUM reduction to preserve backwards
+      # compatibility. Otherwise skip it to avoid unnecessary computation.
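+      # Under SUM reduction, training_loss is the weighted sum of example
+      # losses, so dividing by the summed example weights below yields the
+      # mean loss reported in the summary.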
+      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:
+        example_weight_sum = tf.math.reduce_sum(
+            weights * tf.compat.v1.ones_like(unreduced_loss))
+        mean_loss = training_loss / example_weight_sum
+      else:
+        mean_loss = None
+    with tf.compat.v1.name_scope(''):
+      keys = metric_keys.MetricKeys
+      tf.compat.v1.summary.scalar(
+          _summary_key(self._name, keys.LOSS), scalar_loss)
+      if mean_loss is not None:
+        tf.compat.v1.summary.scalar(
+            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
+      if regularization_loss is not None:
+        tf.compat.v1.summary.scalar(
+            _summary_key(self._name, keys.LOSS_REGULARIZATION),
+            regularization_loss)
+    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+        mode=ModeKeys.TRAIN,
+        predictions=predictions,
+        loss=scalar_loss,
+        train_op=train_op)
+
+
+def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
+    weight_column=None,
+    thresholds=None,
+    label_vocabulary=None,
+    loss_reduction=tf.compat.v1.losses.Reduction.SUM,
+    loss_fn=None,
+    name=None):
+  """See `tensorflow_estimator/python/estimator/canned/head.py`."""
+
+  thresholds = tuple(thresholds) if thresholds else tuple()
+  if label_vocabulary is not None and not isinstance(label_vocabulary,
+                                                     (list, tuple)):
+    raise TypeError(
+        'label_vocabulary should be a list or tuple. Given type: {}'.format(
+            type(label_vocabulary)))
+
+  for threshold in thresholds:
+    if (threshold <= 0.0) or (threshold >= 1.0):
+      raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,)))
+  if loss_reduction not in tf.compat.v1.losses.Reduction.all():
+    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+  if loss_fn:
+    _validate_loss_fn_args(loss_fn)
+  return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
+      weight_column=weight_column,
+      thresholds=thresholds,
+      label_vocabulary=label_vocabulary,
+      loss_reduction=loss_reduction,
+      loss_fn=loss_fn,
+      name=name)
+
+
+class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
+    _BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss):
+  """DP version of `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`."""
+
+  def _create_tpu_estimator_spec(self,
+                                 features,
+                                 mode,
+                                 logits,
+                                 labels=None,
+                                 optimizer=None,
+                                 train_op_fn=None,
+                                 regularization_losses=None):
+    """Returns an `EstimatorSpec`.
+
+    Args:
+      features: Input `dict` of `Tensor` or `SparseTensor` objects.
+      mode: Estimator's `ModeKeys`.
+      logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
+        applications, the shape is `[batch_size, 1]`.
+      labels: Labels integer or string `Tensor` with shape matching `logits`,
+        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is a
+        required argument when `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
+      train_op_fn: Function that takes a scalar loss `Tensor` and returns
+        `train_op`. Used if `optimizer` is `None`.
+      regularization_losses: A list of additional scalar losses to be added to
+        the training loss, such as regularization losses. These losses are
+        usually expressed as a batch average, so for best results users need
+        to set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to
+        avoid scaling errors.
+
+    Returns:
+      An `EstimatorSpec`.
+
+    Raises:
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
+    """
+    # Predict.
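+    # Note: the single logit z is expanded to two-class logits [0, z];
+    # softmax([0, z]) equals [1 - sigmoid(z), sigmoid(z)], so the softmax
+    # probabilities below agree with the `logistic` sigmoid output.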
+    with tf.compat.v1.name_scope(self._name, 'head'):
+      with tf.compat.v1.name_scope(None, 'predictions', (logits,)):
+        pred_keys = prediction_keys.PredictionKeys
+        logits = _check_logits_final_dim(logits, self.logits_dimension)
+        logistic = tf.math.sigmoid(logits, name=pred_keys.LOGISTIC)
+        two_class_logits = tf.concat((tf.compat.v1.zeros_like(logits), logits),
+                                     axis=-1,
+                                     name='two_class_logits')
+        probabilities = tf.compat.v1.nn.softmax(
+            two_class_logits, name=pred_keys.PROBABILITIES)
+        class_ids = tf.compat.v1.math.argmax(
+            two_class_logits, axis=-1, name=pred_keys.CLASS_IDS)
+        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)
+        all_class_ids = _all_class_ids(logits, n_classes=2)
+        all_classes = _all_classes(
+            logits, n_classes=2, label_vocabulary=self._label_vocabulary)
+
+        if self._label_vocabulary:
+          table = lookup_ops.index_to_string_table_from_tensor(
+              vocabulary_list=self._label_vocabulary,
+              name='class_string_lookup')
+          classes = table.lookup(class_ids)
+        else:
+          classes = tf.strings.as_string(class_ids, name='str_classes')
+        predictions = {
+            pred_keys.LOGITS: logits,
+            pred_keys.LOGISTIC: logistic,
+            pred_keys.PROBABILITIES: probabilities,
+            pred_keys.CLASS_IDS: class_ids,
+            pred_keys.CLASSES: classes,
+            pred_keys.ALL_CLASS_IDS: all_class_ids,
+            pred_keys.ALL_CLASSES: all_classes,
+        }
+      if mode == ModeKeys.PREDICT:
+        classifier_output = _classification_output(
+            scores=probabilities,
+            n_classes=2,
+            label_vocabulary=self._label_vocabulary)
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+            mode=ModeKeys.PREDICT,
+            predictions=predictions,
+            export_outputs={
+                _DEFAULT_SERVING_KEY: classifier_output,
+                _CLASSIFY_SERVING_KEY: classifier_output,
+                _REGRESS_SERVING_KEY: export_output.RegressionOutput(
+                    value=logistic),
+                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
+            })
+
+      (training_loss, unreduced_loss, weights, processed_labels) = (
+          self.create_loss(
+              features=features, mode=mode, logits=logits, labels=labels))
+      if regularization_losses:
+        regularization_loss = tf.math.add_n(regularization_losses)
+        regularized_training_loss = tf.math.add_n(
+            [training_loss, regularization_loss])
+      else:
+        regularization_loss = None
+        regularized_training_loss = training_loss
+
+      if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
+        scalar_loss = tf.reduce_mean(regularized_training_loss)
+      else:
+        scalar_loss = regularized_training_loss
+
+      # Eval.
+      if mode == ModeKeys.EVAL:
+        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+            mode=ModeKeys.EVAL,
+            predictions=predictions,
+            loss=scalar_loss,
+            eval_metrics=_create_eval_metrics_tuple(
+                self._eval_metric_ops, {
+                    'labels': processed_labels,
+                    'logits': logits,
+                    'logistic': logistic,
+                    'class_ids': class_ids,
+                    'weights': weights,
+                    'unreduced_loss': unreduced_loss,
+                    'regularization_loss': regularization_loss
+                }))
+
+      # Train.
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=tf.compat.v1.train.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
+      train_op = _append_update_ops(train_op)
+      # Only summarize mean_loss for SUM reduction to preserve backwards
+      # compatibility. Otherwise skip it to avoid unnecessary computation.
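+      # As in the multi-class head above: the summarized mean loss is the
+      # summed training loss divided by the total example weight.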
+      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:
+        example_weight_sum = tf.math.reduce_sum(
+            weights * tf.compat.v1.ones_like(unreduced_loss))
+        mean_loss = training_loss / example_weight_sum
+      else:
+        mean_loss = None
+    with tf.compat.v1.name_scope(''):
+      keys = metric_keys.MetricKeys
+      tf.compat.v1.summary.scalar(
+          _summary_key(self._name, keys.LOSS), scalar_loss)
+      if mean_loss is not None:
+        tf.compat.v1.summary.scalar(
+            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
+      if regularization_loss is not None:
+        tf.compat.v1.summary.scalar(
+            _summary_key(self._name, keys.LOSS_REGULARIZATION),
+            regularization_loss)
+    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
+        mode=ModeKeys.TRAIN,
+        predictions=predictions,
+        loss=scalar_loss,
+        train_op=train_op)
+
+
+def _binary_logistic_or_multi_class_head(n_classes, weight_column,
+                                         label_vocabulary, loss_reduction):
+  """Creates either a binary or a multi-class head.
+
+  Args:
+    n_classes: Number of label classes.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining the feature column that
+      represents weights. It is used to down-weight or boost examples during
+      training, and is multiplied by the loss of the example. If it is a
+      string, it is used as a key to fetch the weight tensor from `features`.
+      If it is a `_NumericColumn`, the raw tensor is fetched by key
+      `weight_column.key`, then `weight_column.normalizer_fn` is applied to
+      it to produce the weight tensor.
+    label_vocabulary: A list of strings representing possible label values.
+      If given, labels must be of string type and take values in
+      `label_vocabulary`. If not given, labels must already be encoded as an
+      integer or float within [0, 1] for `n_classes=2`, or as integer values
+      in {0, 1, ..., n_classes-1} for `n_classes>2`. An error is raised if
+      the vocabulary is not provided and the labels are strings.
+    loss_reduction: Describes how to reduce training loss over the batch.
+      Defaults to `SUM`.
+
+  Returns:
+    A `head._Head` instance.
+  """
+  if n_classes == 2:
+    head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        weight_column=weight_column,
+        label_vocabulary=label_vocabulary,
+        loss_reduction=loss_reduction)
+  else:
+    head = _multi_class_head_with_softmax_cross_entropy_loss(
+        n_classes,
+        weight_column=weight_column,
+        label_vocabulary=label_vocabulary,
+        loss_reduction=loss_reduction)
+  return head
diff --git a/tensorflow_privacy/privacy/estimators/v1/head_test.py b/tensorflow_privacy/privacy/estimators/v1/head_test.py
new file mode 100644
index 0000000..db3d4ae
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/head_test.py
@@ -0,0 +1,95 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for DP-enabled binary class heads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import tensorflow as tf +from tensorflow_privacy.privacy.estimators import test_utils +from tensorflow_privacy.privacy.estimators.v1 import head as head_lib +from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer + + +def make_model_fn(head, optimizer, feature_columns): + """Constructs and returns a model_fn using supplied head.""" + + def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument + feature_layer = tf.keras.layers.DenseFeatures(feature_columns) + inputs = feature_layer(features) + hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') + hidden_layer_values = hidden_layer(inputs) + logits_layer = tf.keras.layers.Dense( + units=head.logits_dimension, activation=None) + logits = logits_layer(hidden_layer_values) + return head.create_estimator_spec( + features=features, + labels=labels, + mode=mode, + logits=logits, + optimizer=optimizer) + + return model_fn + + +class DPHeadTest(tf.test.TestCase, parameterized.TestCase): + """Tests for DP-enabled heads.""" + + # Parameters for testing: n_classes. + @parameterized.named_parameters( + ('Binary', 2), + ('MultiClass 3', 3), + ('MultiClass 4', 4), + ) + def testCreateTPUEstimatorSpec(self, n_classes): + """Tests that an Estimator built with a binary head works.""" + + train_features, train_labels = test_utils.make_input_data(256, n_classes) + feature_columns = [] + for key in train_features: + feature_columns.append(tf.feature_column.numeric_column(key=key)) + + head = head_lib._binary_logistic_or_multi_class_head( + n_classes=n_classes, + weight_column=None, + label_vocabulary=None, + loss_reduction=tf.compat.v1.losses.Reduction.NONE) + optimizer = DPGradientDescentGaussianOptimizer( + learning_rate=0.5, + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=2) + model_fn = make_model_fn(head, optimizer, feature_columns) + classifier = tf.estimator.Estimator(model_fn=model_fn) + + classifier.train( + input_fn=test_utils.make_input_fn(train_features, train_labels, True), + steps=4) + + test_features, test_labels = test_utils.make_input_data(64, n_classes) + classifier.evaluate( + input_fn=test_utils.make_input_fn(test_features, test_labels, False), + steps=4) + + predict_features, predict_labels = test_utils.make_input_data(64, n_classes) + classifier.predict( + input_fn=test_utils.make_input_fn(predict_features, predict_labels, + False)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py index b970c4a..1e4c281 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py @@ -65,7 +65,7 @@ def make_optimizer_class(cls): super(DPOptimizerClass, self).__init__(*args, **kwargs) self._dp_sum_query = dp_sum_query self._num_microbatches = num_microbatches - self._global_state = self._dp_sum_query.initial_global_state() + self._global_state = None # TODO(b/122613513): Set unroll_microbatches=True to avoid this bug. # Beware: When num_microbatches is large (>100), enabling this parameter # may cause an OOM error. 
@@ -81,6 +81,9 @@ def make_optimizer_class(cls):
                           grad_loss=None,
                           gradient_tape=None):
       self._was_compute_gradients_called = True
+      if self._global_state is None:
+        self._global_state = self._dp_sum_query.initial_global_state()
+
       if callable(loss):
         # TF is running in Eager mode, check we received a vanilla tape.
         if not gradient_tape:
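
Usage sketch (illustrative, not part of the diff above): the snippet below
shows how the new TF1-compatible DNNClassifier is meant to be driven with a
DP optimizer, mirroring dnn_test.py in this patch. The feature name
`example_feature` and the toy `input_fn` are hypothetical stand-ins; the
estimator and optimizer arguments come from the tests.

    import functools
    import tensorflow as tf

    from tensorflow_privacy.privacy.estimators.v1 import dnn
    from tensorflow_privacy.privacy.optimizers.dp_optimizer import (
        DPGradientDescentGaussianOptimizer)

    # The DP optimizer is passed as a constructor (via functools.partial) so
    # the canned estimator can build it inside its own graph. DP-SGD needs
    # per-example losses, hence loss_reduction=NONE. noise_multiplier=0.0
    # disables noise, as in the tests; set it > 0 for actual privacy.
    optimizer = functools.partial(
        DPGradientDescentGaussianOptimizer,
        learning_rate=0.5,
        l2_norm_clip=1.0,
        noise_multiplier=0.0,
        num_microbatches=1)

    classifier = dnn.DNNClassifier(
        hidden_units=[10],
        feature_columns=[tf.feature_column.numeric_column('example_feature')],
        n_classes=2,
        optimizer=optimizer,
        loss_reduction=tf.compat.v1.losses.Reduction.NONE)

    def input_fn():
      # Hypothetical two-example input pipeline; any (features, labels)
      # source matching the feature columns works.
      features = {'example_feature': tf.constant([[0.1], [0.9]])}
      labels = tf.constant([[0], [1]])
      return tf.data.Dataset.from_tensor_slices(
          (features, labels)).repeat().batch(2)

    classifier.train(input_fn=input_fn, steps=10)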