From 3a641e077eb5a68abfb87553b2eccbd991cca8c3 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Mon, 10 Aug 2020 17:19:05 -0700 Subject: [PATCH] Add DP-enabled binary-class head and multi-class heads for Estimator. PiperOrigin-RevId: 325921076 --- tensorflow_privacy/privacy/estimators/BUILD | 68 ++++++++ .../privacy/estimators/binary_class_head.py | 147 +++++++++++++++++ .../estimators/binary_class_head_test.py | 136 ++++++++++++++++ .../privacy/estimators/head_utils.py | 60 +++++++ .../privacy/estimators/multi_class_head.py | 146 +++++++++++++++++ .../estimators/multi_class_head_test.py | 152 ++++++++++++++++++ .../privacy/optimizers/dp_optimizer_keras.py | 5 +- 7 files changed, 713 insertions(+), 1 deletion(-) create mode 100644 tensorflow_privacy/privacy/estimators/BUILD create mode 100644 tensorflow_privacy/privacy/estimators/binary_class_head.py create mode 100644 tensorflow_privacy/privacy/estimators/binary_class_head_test.py create mode 100644 tensorflow_privacy/privacy/estimators/head_utils.py create mode 100644 tensorflow_privacy/privacy/estimators/multi_class_head.py create mode 100644 tensorflow_privacy/privacy/estimators/multi_class_head_test.py diff --git a/tensorflow_privacy/privacy/estimators/BUILD b/tensorflow_privacy/privacy/estimators/BUILD new file mode 100644 index 0000000..3d72ad6 --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/BUILD @@ -0,0 +1,68 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "head_utils", + srcs = [ + "head_utils.py", + ], + deps = [ + ":binary_class_head", + ":multi_class_head", + ], +) + +py_library( + name = "binary_class_head", + srcs = [ + "binary_class_head.py", + ], + deps = [ + "//third_party/py/tensorflow", + # TODO(b/163395075): Remove this dependency once necessary function is public. + "//third_party/tensorflow/python:keras_lib", + "//third_party/tensorflow_estimator", + ], +) + +py_library( + name = "multi_class_head", + srcs = [ + "multi_class_head.py", + ], + deps = [ + "//third_party/py/tensorflow", + # TODO(b/163395075): Remove this dependency once necessary function is public. + "//third_party/tensorflow/python:keras_lib", + "//third_party/tensorflow_estimator", + ], +) + +py_test( + name = "binary_class_head_test", + timeout = "long", + srcs = ["binary_class_head_test.py"], + python_version = "PY3", + deps = [ + ":binary_class_head", + "//third_party/py/absl/testing:parameterized", + "//third_party/py/six", + "//third_party/py/tensorflow", + "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + ], +) + +py_test( + name = "multi_class_head_test", + timeout = "long", + srcs = ["multi_class_head_test.py"], + python_version = "PY3", + deps = [ + ":multi_class_head", + "//third_party/py/absl/testing:parameterized", + "//third_party/py/six", + "//third_party/py/tensorflow", + "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + ], +) diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head.py b/tensorflow_privacy/privacy/estimators/binary_class_head.py new file mode 100644 index 0000000..259b553 --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/binary_class_head.py @@ -0,0 +1,147 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Binary class head for Estimator that allow integration with TF Privacy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import +from tensorflow_estimator.python.estimator import model_fn +from tensorflow_estimator.python.estimator.canned import prediction_keys +from tensorflow_estimator.python.estimator.export import export_output +from tensorflow_estimator.python.estimator.head import base_head +from tensorflow_estimator.python.estimator.mode_keys import ModeKeys + + +class DPBinaryClassHead(tf.estimator.BinaryClassHead): + """Creates a TF Privacy-enabled version of BinaryClassHead.""" + + def __init__(self, + weight_column=None, + thresholds=None, + label_vocabulary=None, + loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, + loss_fn=None, + name=None): + super(DPBinaryClassHead, self).__init__( + weight_column=weight_column, + thresholds=thresholds, + label_vocabulary=label_vocabulary, + loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, + loss_fn=loss_fn, + name=name) + + def loss(self, + labels, + logits, + features=None, + mode=None, + regularization_losses=None): + """Returns regularized training loss. See `base_head.Head` for details.""" + del mode # Unused for this head. + with tf.compat.v1.name_scope( + 'losses', values=(logits, labels, regularization_losses, features)): + logits = base_head.check_logits_final_dim(logits, self.logits_dimension) + labels = self._processed_labels(logits, labels) + unweighted_loss, weights = self._unweighted_loss_and_weights( + logits, labels, features) + vector_training_loss = losses_utils.compute_weighted_loss( + unweighted_loss, + sample_weight=weights, + reduction=tf.keras.losses.Reduction.NONE) + regularization_loss = tf.math.add_n( + regularization_losses) if regularization_losses is not None else None + vector_regularized_training_loss = ( + tf.add(vector_training_loss, regularization_loss) + if regularization_loss is not None else vector_training_loss) + + return vector_regularized_training_loss + + def _create_tpu_estimator_spec(self, + features, + mode, + logits, + labels=None, + optimizer=None, + trainable_variables=None, + train_op_fn=None, + update_ops=None, + regularization_losses=None): + """See superclass for description.""" + + with tf.compat.v1.name_scope(self._name, 'head'): + # Predict. 
+ pred_keys = prediction_keys.PredictionKeys + predictions = self.predictions(logits) + if mode == ModeKeys.PREDICT: + probabilities = predictions[pred_keys.PROBABILITIES] + logistic = predictions[pred_keys.LOGISTIC] + classifier_output = base_head.classification_output( + scores=probabilities, + n_classes=2, + label_vocabulary=self._label_vocabulary) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.PREDICT, + predictions=predictions, + export_outputs={ + base_head.DEFAULT_SERVING_KEY: classifier_output, + base_head.CLASSIFY_SERVING_KEY: classifier_output, + base_head.REGRESS_SERVING_KEY: + export_output.RegressionOutput(value=logistic), + base_head.PREDICT_SERVING_KEY: + export_output.PredictOutput(predictions) + }) + regularized_training_loss = self.loss( + logits=logits, + labels=labels, + features=features, + mode=mode, + regularization_losses=regularization_losses) + scalar_loss = tf.reduce_mean(regularized_training_loss) + # Eval. + if mode == ModeKeys.EVAL: + eval_metrics = self.metrics(regularization_losses=regularization_losses) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.EVAL, + predictions=predictions, + loss=scalar_loss, + eval_metrics=base_head.create_eval_metrics_tuple( + self.update_metrics, { + 'eval_metrics': eval_metrics, + 'features': features, + 'logits': logits, + 'labels': labels, + 'regularization_losses': regularization_losses + })) + # Train. + train_op = base_head.create_estimator_spec_train_op( + head_name=self._name, + optimizer=optimizer, + train_op_fn=train_op_fn, + update_ops=update_ops, + trainable_variables=trainable_variables, + regularized_training_loss=regularized_training_loss, + loss_reduction=self._loss_reduction) + # Create summary. + base_head.create_estimator_spec_summary( + regularized_training_loss=scalar_loss, + regularization_losses=regularization_losses, + summary_key_fn=self._summary_key) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.TRAIN, + predictions=predictions, + loss=scalar_loss, + train_op=train_op) diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py new file mode 100644 index 0000000..c7024d6 --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py @@ -0,0 +1,136 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for DP-enabled binary class heads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.estimators import binary_class_head +from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer + + +class DPBinaryClassHeadTest(tf.test.TestCase): + """Tests for DP-enabled heads.""" + + def _make_input_data(self, size): + """Create raw input data.""" + feature_a = np.random.normal(4, 1, (size)) + feature_b = np.random.normal(5, 0.7, (size)) + feature_c = np.random.normal(6, 2, (size)) + noise = np.random.normal(0, 30, (size)) + features = { + 'feature_a': feature_a, + 'feature_b': feature_b, + 'feature_c': feature_c, + } + labels = np.array( + np.power(feature_a, 3) + np.power(feature_b, 2) + + np.power(feature_c, 1) + noise > 125).astype(int) + return features, labels + + def _make_input_fn(self, features, labels, training, batch_size=16): + + def input_fn(): + """An input function for training or evaluating.""" + # Convert the inputs to a Dataset. + dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) + + # Shuffle if in training mode. + if training: + dataset = dataset.shuffle(1000) + + return dataset.batch(batch_size) + + return input_fn + + def _make_model_fn(self, head, optimizer, feature_columns): + """Constructs and returns a model_fn using DPBinaryClassHead.""" + + def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument + feature_layer = tf.keras.layers.DenseFeatures(feature_columns) + inputs = feature_layer(features) + hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') + hidden_layer_values = hidden_layer(inputs) + logits_layer = tf.keras.layers.Dense( + units=head.logits_dimension, activation=None) + logits = logits_layer(hidden_layer_values) + return head.create_estimator_spec( + features=features, + labels=labels, + mode=mode, + logits=logits, + trainable_variables=hidden_layer.trainable_weights + + logits_layer.trainable_weights, + optimizer=optimizer) + + return model_fn + + def testLoss(self): + """Tests loss() returns per-example losses.""" + + head = binary_class_head.DPBinaryClassHead() + features = {'feature_a': np.full((4), 1.0)} + labels = np.array([[1.0], [1.0], [1.0], [0.0]]) + logits = np.full((4, 1), 0.5) + + actual_loss = head.loss(labels, logits, features) + expected_loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + + self.assertEqual(actual_loss.shape, [4, 1]) + + if tf.executing_eagerly(): + self.assertEqual(actual_loss.shape, [4, 1]) + self.assertAllClose(actual_loss, expected_loss) + return + + self.assertAllClose(expected_loss, self.evaluate(actual_loss)) + + def testCreateTPUEstimatorSpec(self): + """Tests that an Estimator built with this head works.""" + + train_features, train_labels = self._make_input_data(256) + feature_columns = [] + for key in train_features: + feature_columns.append(tf.feature_column.numeric_column(key=key)) + + head = binary_class_head.DPBinaryClassHead() + optimizer = DPKerasSGDOptimizer( + learning_rate=0.5, + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=2) + model_fn = self._make_model_fn(head, optimizer, feature_columns) + classifier = tf.estimator.Estimator(model_fn=model_fn) + + classifier.train( + input_fn=self._make_input_fn(train_features, train_labels, True), + steps=4) + + test_features, 
test_labels = self._make_input_data(64) + classifier.evaluate( + input_fn=self._make_input_fn(test_features, test_labels, False), + steps=4) + + predict_features, predict_labels_ = self._make_input_data(64) + classifier.predict( + input_fn=self._make_input_fn(predict_features, predict_labels_, False)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/estimators/head_utils.py b/tensorflow_privacy/privacy/estimators/head_utils.py new file mode 100644 index 0000000..0b4723d --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/head_utils.py @@ -0,0 +1,60 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Estimator heads that allow integration with TF Privacy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow_privacy.privacy.estimators.binary_class_head import DPBinaryClassHead +from tensorflow_privacy.privacy.estimators.multi_class_head import DPMultiClassHead + + +def binary_or_multi_class_head(n_classes, weight_column, label_vocabulary, + loss_reduction): + """Creates either binary or multi-class head. + + Args: + n_classes: Number of label classes. + weight_column: A string or a `NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then + weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are already + encoded as integer or float within [0, 1] for `n_classes=2` and encoded as + integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there + will be errors if vocabulary is not provided and labels are string. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Defines how to + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`. + + Returns: + A `Head` instance. + """ + if n_classes == 2: + head = DPBinaryClassHead( + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + else: + head = DPMultiClassHead( + n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + return head diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head.py b/tensorflow_privacy/privacy/estimators/multi_class_head.py new file mode 100644 index 0000000..2bbbe2d --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/multi_class_head.py @@ -0,0 +1,146 @@ +# Copyright 2020, The TensorFlow Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multiclass head for Estimator that allow integration with TF Privacy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import +from tensorflow_estimator.python.estimator import model_fn +from tensorflow_estimator.python.estimator.canned import prediction_keys +from tensorflow_estimator.python.estimator.export import export_output +from tensorflow_estimator.python.estimator.head import base_head +from tensorflow_estimator.python.estimator.mode_keys import ModeKeys + + +class DPMultiClassHead(tf.estimator.MultiClassHead): + """Creates a TF Privacy-enabled version of MultiClassHead.""" + + def __init__(self, + n_classes, + weight_column=None, + label_vocabulary=None, + loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + loss_fn=None, + name=None): + super(DPMultiClassHead, self).__init__( + n_classes=n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + loss_fn=loss_fn, + name=name) + + def loss(self, + labels, + logits, + features=None, + mode=None, + regularization_losses=None): + """Returns regularized training loss. See `base_head.Head` for details.""" + del mode # Unused for this head. + with tf.compat.v1.name_scope( + 'losses', values=(logits, labels, regularization_losses, features)): + logits = base_head.check_logits_final_dim(logits, self.logits_dimension) + labels = self._processed_labels(logits, labels) + unweighted_loss, weights = self._unweighted_loss_and_weights( + logits, labels, features) + vector_training_loss = losses_utils.compute_weighted_loss( + unweighted_loss, + sample_weight=weights, + reduction=tf.keras.losses.Reduction.NONE) + regularization_loss = tf.math.add_n( + regularization_losses) if regularization_losses is not None else None + vector_regularized_training_loss = ( + tf.add(vector_training_loss, regularization_loss) + if regularization_loss is not None else vector_training_loss) + + return vector_regularized_training_loss + + def _create_tpu_estimator_spec(self, + features, + mode, + logits, + labels=None, + optimizer=None, + trainable_variables=None, + train_op_fn=None, + update_ops=None, + regularization_losses=None): + """See superclass for description.""" + + with tf.compat.v1.name_scope(self._name, 'head'): + # Predict. 
+ pred_keys = prediction_keys.PredictionKeys + predictions = self.predictions(logits) + if mode == ModeKeys.PREDICT: + probabilities = predictions[pred_keys.PROBABILITIES] + classifier_output = base_head.classification_output( + scores=probabilities, + n_classes=self._n_classes, + label_vocabulary=self._label_vocabulary) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.PREDICT, + predictions=predictions, + export_outputs={ + base_head.DEFAULT_SERVING_KEY: + classifier_output, + base_head.CLASSIFY_SERVING_KEY: + classifier_output, + base_head.PREDICT_SERVING_KEY: + export_output.PredictOutput(predictions) + }) + regularized_training_loss = self.loss( + logits=logits, + labels=labels, + features=features, + mode=mode, + regularization_losses=regularization_losses) + scalar_loss = tf.reduce_mean(regularized_training_loss) + # Eval. + if mode == ModeKeys.EVAL: + eval_metrics = self.metrics(regularization_losses=regularization_losses) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.EVAL, + predictions=predictions, + loss=scalar_loss, + eval_metrics=base_head.create_eval_metrics_tuple( + self.update_metrics, { + 'eval_metrics': eval_metrics, + 'features': features, + 'logits': logits, + 'labels': labels, + 'regularization_losses': regularization_losses + })) + # Train. + train_op = base_head.create_estimator_spec_train_op( + head_name=self._name, + optimizer=optimizer, + train_op_fn=train_op_fn, + update_ops=update_ops, + trainable_variables=trainable_variables, + regularized_training_loss=regularized_training_loss, + loss_reduction=self._loss_reduction) + # Create summary. + base_head.create_estimator_spec_summary( + regularized_training_loss=scalar_loss, + regularization_losses=regularization_losses, + summary_key_fn=self._summary_key) + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access + mode=ModeKeys.TRAIN, + predictions=predictions, + loss=scalar_loss, + train_op=train_op) diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py new file mode 100644 index 0000000..ed5f98c --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py @@ -0,0 +1,152 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for DP-enabled binary class heads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from absl.testing import parameterized +import numpy as np +import tensorflow as tf +from tensorflow_privacy.privacy.estimators import multi_class_head +from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer + + +class DPMultiClassHeadTest(tf.test.TestCase): + """Tests for DP-enabled heads.""" + + def _make_input_data(self, size): + """Create raw input data.""" + feature_a = np.random.normal(4, 1, (size)) + feature_b = np.random.normal(5, 0.7, (size)) + feature_c = np.random.normal(6, 2, (size)) + noise = np.random.normal(0, 30, (size)) + features = { + 'feature_a': feature_a, + 'feature_b': feature_b, + 'feature_c': feature_c, + } + + def label_fn(x): + if x < 110.0: + return 0 + elif x < 140.0: + return 1 + else: + return 2 + + labels_list = map( + label_fn, + np.power(feature_a, 3) + np.power(feature_b, 2) + + np.power(feature_c, 1) + noise) + return features, list(labels_list) + + def _make_input_fn(self, features, labels, training, batch_size=16): + + def input_fn(): + """An input function for training or evaluating.""" + # Convert the inputs to a Dataset. + dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) + + # Shuffle if in training mode. + if training: + dataset = dataset.shuffle(1000) + + return dataset.batch(batch_size) + + return input_fn + + def _make_model_fn(self, head, optimizer, feature_columns): + """Constructs and returns a model_fn using DPBinaryClassHead.""" + + def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument + feature_layer = tf.keras.layers.DenseFeatures(feature_columns) + inputs = feature_layer(features) + hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') + hidden_layer_values = hidden_layer(inputs) + logits_layer = tf.keras.layers.Dense( + units=head.logits_dimension, activation=None) + logits = logits_layer(hidden_layer_values) + return head.create_estimator_spec( + features=features, + labels=labels, + mode=mode, + logits=logits, + trainable_variables=hidden_layer.trainable_weights + + logits_layer.trainable_weights, + optimizer=optimizer) + + return model_fn + + def testLoss(self): + """Tests loss() returns per-example losses.""" + + head = multi_class_head.DPMultiClassHead(3) + features = {'feature_a': np.full((4), 1.0)} + labels = np.array([[2], [1], [1], [0]]) + logits = np.array([[2.0, 1.5, 4.1], [2.0, 1.5, 4.1], [2.0, 1.5, 4.1], + [2.0, 1.5, 4.1]]) + + actual_loss = head.loss(labels, logits, features) + expected_loss = tf.expand_dims( + tf.compat.v1.losses.sparse_softmax_cross_entropy( + labels=labels, + logits=logits, + reduction=tf.keras.losses.Reduction.NONE), -1) + + self.assertEqual(actual_loss.shape, [4, 1]) + + if tf.executing_eagerly(): + self.assertEqual(actual_loss.shape, [4, 1]) + self.assertAllClose(actual_loss, expected_loss) + return + + self.assertAllClose(expected_loss, self.evaluate(actual_loss)) + + def testCreateTPUEstimatorSpec(self): + """Tests that an Estimator built with this head works.""" + + train_features, train_labels = self._make_input_data(256) + feature_columns = [] + for key in train_features: + feature_columns.append(tf.feature_column.numeric_column(key=key)) + + head = multi_class_head.DPMultiClassHead(3) + optimizer = DPKerasSGDOptimizer( + learning_rate=0.5, + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=2) + model_fn = 
self._make_model_fn(head, optimizer, feature_columns) + classifier = tf.estimator.Estimator(model_fn=model_fn) + + classifier.train( + input_fn=self._make_input_fn(train_features, train_labels, True), + steps=4) + + test_features, test_labels = self._make_input_data(64) + classifier.evaluate( + input_fn=self._make_input_fn(test_features, test_labels, False), + steps=4) + + predict_features, predict_labels_ = self._make_input_data(64) + predictions = classifier.predict( + input_fn=self._make_input_fn(predict_features, predict_labels_, False)) + for p in predictions: + print('schien p: ', p) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py index 2fff9c2..157043d 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py @@ -60,7 +60,7 @@ def make_keras_optimizer_class(cls): self._num_microbatches = num_microbatches self._dp_sum_query = gaussian_query.GaussianSumQuery( l2_norm_clip, l2_norm_clip * noise_multiplier) - self._global_state = self._dp_sum_query.initial_global_state() + self._global_state = None def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): """DP version of superclass method.""" @@ -119,6 +119,9 @@ def make_keras_optimizer_class(cls): def get_gradients(self, loss, params): """DP version of superclass method.""" + if self._global_state is None: + self._global_state = self._dp_sum_query.initial_global_state() + # This code mostly follows the logic in the original DPOptimizerClass # in dp_optimizer.py, except that this returns only the gradients, # not the gradients and variables.
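-- 
Usage sketch, distilled from the tests above: it wires one of the new DP heads
into a custom model_fn with DPKerasSGDOptimizer. The feature name, layer shape,
and hyperparameter values below are illustrative assumptions, not values required
by this change.

    import numpy as np
    import tensorflow as tf

    from tensorflow_privacy.privacy.estimators import multi_class_head
    from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer

    head = multi_class_head.DPMultiClassHead(3)
    # Illustrative DP hyperparameters; the batch size fed to the Estimator must
    # be divisible by num_microbatches.
    optimizer = DPKerasSGDOptimizer(
        learning_rate=0.5,
        l2_norm_clip=1.0,
        noise_multiplier=1.1,
        num_microbatches=2)
    feature_columns = [tf.feature_column.numeric_column('feature_a')]


    def model_fn(features, labels, mode, params, config=None):  # pylint: disable=unused-argument
      inputs = tf.keras.layers.DenseFeatures(feature_columns)(features)
      logits_layer = tf.keras.layers.Dense(units=head.logits_dimension)
      logits = logits_layer(inputs)
      # The DP head keeps the training loss as a per-example vector, so the DP
      # optimizer can clip and noise per-microbatch gradients.
      return head.create_estimator_spec(
          features=features,
          labels=labels,
          mode=mode,
          logits=logits,
          trainable_variables=logits_layer.trainable_weights,
          optimizer=optimizer)


    def input_fn():
      features = {'feature_a': np.random.normal(size=(64,)).astype(np.float32)}
      labels = np.random.randint(0, 3, size=(64,))
      return tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)


    classifier = tf.estimator.Estimator(model_fn=model_fn)
    classifier.train(input_fn=input_fn, steps=4)

The same wiring applies to DPBinaryClassHead (or head_utils.binary_or_multi_class_head);
only the head construction and the label encoding change.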