From 99afaed68e01b84285a4e22f537eac8b70b96575 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Mon, 10 Aug 2020 22:30:00 -0700 Subject: [PATCH] Refactor of common functions in binary and multiclass heads. PiperOrigin-RevId: 325957037 --- tensorflow_privacy/privacy/estimators/BUILD | 22 +++-- .../estimators/binary_class_head_test.py | 71 ++------------ .../estimators/multi_class_head_test.py | 86 +++-------------- .../privacy/estimators/test_utils.py | 93 +++++++++++++++++++ 4 files changed, 128 insertions(+), 144 deletions(-) create mode 100644 tensorflow_privacy/privacy/estimators/test_utils.py diff --git a/tensorflow_privacy/privacy/estimators/BUILD b/tensorflow_privacy/privacy/estimators/BUILD index 3d72ad6..73b0f5c 100644 --- a/tensorflow_privacy/privacy/estimators/BUILD +++ b/tensorflow_privacy/privacy/estimators/BUILD @@ -20,8 +20,7 @@ py_library( ], deps = [ "//third_party/py/tensorflow", - # TODO(b/163395075): Remove this dependency once necessary function is public. - "//third_party/tensorflow/python:keras_lib", + "//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed. "//third_party/tensorflow_estimator", ], ) @@ -33,12 +32,21 @@ py_library( ], deps = [ "//third_party/py/tensorflow", - # TODO(b/163395075): Remove this dependency once necessary function is public. - "//third_party/tensorflow/python:keras_lib", + "//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed. "//third_party/tensorflow_estimator", ], ) +py_library( + name = "test_utils", + srcs = [ + "test_utils.py", + ], + deps = [ + "//third_party/py/tensorflow", + ], +) + py_test( name = "binary_class_head_test", timeout = "long", @@ -46,8 +54,7 @@ py_test( python_version = "PY3", deps = [ ":binary_class_head", - "//third_party/py/absl/testing:parameterized", - "//third_party/py/six", + ":test_utils", "//third_party/py/tensorflow", "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", ], @@ -60,8 +67,7 @@ py_test( python_version = "PY3", deps = [ ":multi_class_head", - "//third_party/py/absl/testing:parameterized", - "//third_party/py/six", + ":test_utils", "//third_party/py/tensorflow", "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", ], diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py index c7024d6..ab56db7 100644 --- a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py @@ -21,64 +21,12 @@ from __future__ import print_function import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.estimators import binary_class_head +from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer class DPBinaryClassHeadTest(tf.test.TestCase): - """Tests for DP-enabled heads.""" - - def _make_input_data(self, size): - """Create raw input data.""" - feature_a = np.random.normal(4, 1, (size)) - feature_b = np.random.normal(5, 0.7, (size)) - feature_c = np.random.normal(6, 2, (size)) - noise = np.random.normal(0, 30, (size)) - features = { - 'feature_a': feature_a, - 'feature_b': feature_b, - 'feature_c': feature_c, - } - labels = np.array( - np.power(feature_a, 3) + np.power(feature_b, 2) + - np.power(feature_c, 1) + noise > 125).astype(int) - return features, labels - - def _make_input_fn(self, features, labels, training, batch_size=16): - - def input_fn(): - """An input function for training or evaluating.""" - # Convert the inputs to a Dataset. - dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) - - # Shuffle if in training mode. - if training: - dataset = dataset.shuffle(1000) - - return dataset.batch(batch_size) - - return input_fn - - def _make_model_fn(self, head, optimizer, feature_columns): - """Constructs and returns a model_fn using DPBinaryClassHead.""" - - def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument - feature_layer = tf.keras.layers.DenseFeatures(feature_columns) - inputs = feature_layer(features) - hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') - hidden_layer_values = hidden_layer(inputs) - logits_layer = tf.keras.layers.Dense( - units=head.logits_dimension, activation=None) - logits = logits_layer(hidden_layer_values) - return head.create_estimator_spec( - features=features, - labels=labels, - mode=mode, - logits=logits, - trainable_variables=hidden_layer.trainable_weights + - logits_layer.trainable_weights, - optimizer=optimizer) - - return model_fn + """Tests for DP-enabled binary class heads.""" def testLoss(self): """Tests loss() returns per-example losses.""" @@ -104,7 +52,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase): def testCreateTPUEstimatorSpec(self): """Tests that an Estimator built with this head works.""" - train_features, train_labels = self._make_input_data(256) + train_features, train_labels = test_utils.make_input_data(256, 2) feature_columns = [] for key in train_features: feature_columns.append(tf.feature_column.numeric_column(key=key)) @@ -115,21 +63,22 @@ class DPBinaryClassHeadTest(tf.test.TestCase): l2_norm_clip=1.0, noise_multiplier=0.0, num_microbatches=2) - model_fn = self._make_model_fn(head, optimizer, feature_columns) + model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) classifier = tf.estimator.Estimator(model_fn=model_fn) classifier.train( - input_fn=self._make_input_fn(train_features, train_labels, True), + input_fn=test_utils.make_input_fn(train_features, train_labels, True), steps=4) - test_features, test_labels = self._make_input_data(64) + test_features, test_labels = test_utils.make_input_data(64, 2) classifier.evaluate( - input_fn=self._make_input_fn(test_features, test_labels, False), + input_fn=test_utils.make_input_fn(test_features, test_labels, False), steps=4) - predict_features, predict_labels_ = self._make_input_data(64) + predict_features, predict_labels_ = test_utils.make_input_data(64, 2) classifier.predict( - input_fn=self._make_input_fn(predict_features, predict_labels_, False)) + input_fn=test_utils.make_input_fn(predict_features, predict_labels_, + False)) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py index ed5f98c..957cbf8 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py @@ -17,78 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# from absl.testing import parameterized import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.estimators import multi_class_head +from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer class DPMultiClassHeadTest(tf.test.TestCase): - """Tests for DP-enabled heads.""" - - def _make_input_data(self, size): - """Create raw input data.""" - feature_a = np.random.normal(4, 1, (size)) - feature_b = np.random.normal(5, 0.7, (size)) - feature_c = np.random.normal(6, 2, (size)) - noise = np.random.normal(0, 30, (size)) - features = { - 'feature_a': feature_a, - 'feature_b': feature_b, - 'feature_c': feature_c, - } - - def label_fn(x): - if x < 110.0: - return 0 - elif x < 140.0: - return 1 - else: - return 2 - - labels_list = map( - label_fn, - np.power(feature_a, 3) + np.power(feature_b, 2) + - np.power(feature_c, 1) + noise) - return features, list(labels_list) - - def _make_input_fn(self, features, labels, training, batch_size=16): - - def input_fn(): - """An input function for training or evaluating.""" - # Convert the inputs to a Dataset. - dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) - - # Shuffle if in training mode. - if training: - dataset = dataset.shuffle(1000) - - return dataset.batch(batch_size) - - return input_fn - - def _make_model_fn(self, head, optimizer, feature_columns): - """Constructs and returns a model_fn using DPBinaryClassHead.""" - - def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument - feature_layer = tf.keras.layers.DenseFeatures(feature_columns) - inputs = feature_layer(features) - hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') - hidden_layer_values = hidden_layer(inputs) - logits_layer = tf.keras.layers.Dense( - units=head.logits_dimension, activation=None) - logits = logits_layer(hidden_layer_values) - return head.create_estimator_spec( - features=features, - labels=labels, - mode=mode, - logits=logits, - trainable_variables=hidden_layer.trainable_weights + - logits_layer.trainable_weights, - optimizer=optimizer) - - return model_fn + """Tests for DP-enabled multiclass heads.""" def testLoss(self): """Tests loss() returns per-example losses.""" @@ -118,7 +55,7 @@ class DPMultiClassHeadTest(tf.test.TestCase): def testCreateTPUEstimatorSpec(self): """Tests that an Estimator built with this head works.""" - train_features, train_labels = self._make_input_data(256) + train_features, train_labels = test_utils.make_input_data(256, 3) feature_columns = [] for key in train_features: feature_columns.append(tf.feature_column.numeric_column(key=key)) @@ -129,23 +66,22 @@ class DPMultiClassHeadTest(tf.test.TestCase): l2_norm_clip=1.0, noise_multiplier=0.0, num_microbatches=2) - model_fn = self._make_model_fn(head, optimizer, feature_columns) + model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) classifier = tf.estimator.Estimator(model_fn=model_fn) classifier.train( - input_fn=self._make_input_fn(train_features, train_labels, True), + input_fn=test_utils.make_input_fn(train_features, train_labels, True), steps=4) - test_features, test_labels = self._make_input_data(64) + test_features, test_labels = test_utils.make_input_data(64, 3) classifier.evaluate( - input_fn=self._make_input_fn(test_features, test_labels, False), + input_fn=test_utils.make_input_fn(test_features, test_labels, False), steps=4) - predict_features, predict_labels_ = self._make_input_data(64) - predictions = classifier.predict( - input_fn=self._make_input_fn(predict_features, predict_labels_, False)) - for p in predictions: - print('schien p: ', p) + predict_features, predict_labels_ = test_utils.make_input_data(64, 3) + classifier.predict( + input_fn=test_utils.make_input_fn(predict_features, predict_labels_, + False)) if __name__ == '__main__': diff --git a/tensorflow_privacy/privacy/estimators/test_utils.py b/tensorflow_privacy/privacy/estimators/test_utils.py new file mode 100644 index 0000000..5b4ad75 --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/test_utils.py @@ -0,0 +1,93 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for unit tests for DP-enabled Estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +def make_input_data(size, classes): + """Create raw input data for testing.""" + feature_a = np.random.normal(4, 1, (size)) + feature_b = np.random.normal(5, 0.7, (size)) + feature_c = np.random.normal(6, 2, (size)) + noise = np.random.normal(0, 30, (size)) + features = { + 'feature_a': feature_a, + 'feature_b': feature_b, + 'feature_c': feature_c, + } + + if classes == 2: + labels = np.array( + np.power(feature_a, 3) + np.power(feature_b, 2) + + np.power(feature_c, 1) + noise > 125).astype(int) + else: + def label_fn(x): + if x < 110.0: + return 0 + elif x < 140.0: + return 1 + else: + return 2 + + labels = list(map( + label_fn, + np.power(feature_a, 3) + np.power(feature_b, 2) + + np.power(feature_c, 1) + noise)) + + return features, labels + + +def make_input_fn(features, labels, training, batch_size=16): + """Returns an input function suitable for an estimator.""" + + def input_fn(): + """An input function for training or evaluating.""" + # Convert the inputs to a Dataset. + dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) + + # Shuffle if in training mode. + if training: + dataset = dataset.shuffle(1000) + + return dataset.batch(batch_size) + return input_fn + + +def make_model_fn(head, optimizer, feature_columns): + """Constructs and returns a model_fn using supplied head.""" + + def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument + feature_layer = tf.keras.layers.DenseFeatures(feature_columns) + inputs = feature_layer(features) + hidden_layer = tf.keras.layers.Dense(units=3, activation='relu') + hidden_layer_values = hidden_layer(inputs) + logits_layer = tf.keras.layers.Dense( + units=head.logits_dimension, activation=None) + logits = logits_layer(hidden_layer_values) + return head.create_estimator_spec( + features=features, + labels=labels, + mode=mode, + logits=logits, + trainable_variables=hidden_layer.trainable_weights + + logits_layer.trainable_weights, + optimizer=optimizer) + + return model_fn