Refactor of common functions in binary and multiclass heads.

PiperOrigin-RevId: 325957037
This commit is contained in:
Steve Chien 2020-08-10 22:30:00 -07:00 committed by A. Unique TensorFlower
parent 3a641e077e
commit 99afaed68e
4 changed files with 128 additions and 144 deletions

View file

@ -20,8 +20,7 @@ py_library(
], ],
deps = [ deps = [
"//third_party/py/tensorflow", "//third_party/py/tensorflow",
# TODO(b/163395075): Remove this dependency once necessary function is public. "//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed.
"//third_party/tensorflow/python:keras_lib",
"//third_party/tensorflow_estimator", "//third_party/tensorflow_estimator",
], ],
) )
@ -33,12 +32,21 @@ py_library(
], ],
deps = [ deps = [
"//third_party/py/tensorflow", "//third_party/py/tensorflow",
# TODO(b/163395075): Remove this dependency once necessary function is public. "//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed.
"//third_party/tensorflow/python:keras_lib",
"//third_party/tensorflow_estimator", "//third_party/tensorflow_estimator",
], ],
) )
# Library target for test_utils.py; the binary_class_head_test and
# multi_class_head_test targets in this file list ":test_utils" in their deps.
py_library(
name = "test_utils",
srcs = [
"test_utils.py",
],
deps = [
"//third_party/py/tensorflow",
],
)
py_test( py_test(
name = "binary_class_head_test", name = "binary_class_head_test",
timeout = "long", timeout = "long",
@ -46,8 +54,7 @@ py_test(
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":binary_class_head", ":binary_class_head",
"//third_party/py/absl/testing:parameterized", ":test_utils",
"//third_party/py/six",
"//third_party/py/tensorflow", "//third_party/py/tensorflow",
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
], ],
@ -60,8 +67,7 @@ py_test(
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":multi_class_head", ":multi_class_head",
"//third_party/py/absl/testing:parameterized", ":test_utils",
"//third_party/py/six",
"//third_party/py/tensorflow", "//third_party/py/tensorflow",
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
], ],

View file

@ -21,64 +21,12 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.estimators import binary_class_head from tensorflow_privacy.privacy.estimators import binary_class_head
from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
class DPBinaryClassHeadTest(tf.test.TestCase): class DPBinaryClassHeadTest(tf.test.TestCase):
"""Tests for DP-enabled heads.""" """Tests for DP-enabled binary class heads."""
def _make_input_data(self, size):
"""Create raw input data."""
feature_a = np.random.normal(4, 1, (size))
feature_b = np.random.normal(5, 0.7, (size))
feature_c = np.random.normal(6, 2, (size))
noise = np.random.normal(0, 30, (size))
features = {
'feature_a': feature_a,
'feature_b': feature_b,
'feature_c': feature_c,
}
labels = np.array(
np.power(feature_a, 3) + np.power(feature_b, 2) +
np.power(feature_c, 1) + noise > 125).astype(int)
return features, labels
def _make_input_fn(self, features, labels, training, batch_size=16):
def input_fn():
"""An input function for training or evaluating."""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle if in training mode.
if training:
dataset = dataset.shuffle(1000)
return dataset.batch(batch_size)
return input_fn
def _make_model_fn(self, head, optimizer, feature_columns):
"""Constructs and returns a model_fn using DPBinaryClassHead."""
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
inputs = feature_layer(features)
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
hidden_layer_values = hidden_layer(inputs)
logits_layer = tf.keras.layers.Dense(
units=head.logits_dimension, activation=None)
logits = logits_layer(hidden_layer_values)
return head.create_estimator_spec(
features=features,
labels=labels,
mode=mode,
logits=logits,
trainable_variables=hidden_layer.trainable_weights +
logits_layer.trainable_weights,
optimizer=optimizer)
return model_fn
def testLoss(self): def testLoss(self):
"""Tests loss() returns per-example losses.""" """Tests loss() returns per-example losses."""
@ -104,7 +52,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
def testCreateTPUEstimatorSpec(self): def testCreateTPUEstimatorSpec(self):
"""Tests that an Estimator built with this head works.""" """Tests that an Estimator built with this head works."""
train_features, train_labels = self._make_input_data(256) train_features, train_labels = test_utils.make_input_data(256, 2)
feature_columns = [] feature_columns = []
for key in train_features: for key in train_features:
feature_columns.append(tf.feature_column.numeric_column(key=key)) feature_columns.append(tf.feature_column.numeric_column(key=key))
@ -115,21 +63,22 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=2) num_microbatches=2)
model_fn = self._make_model_fn(head, optimizer, feature_columns) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
classifier = tf.estimator.Estimator(model_fn=model_fn) classifier = tf.estimator.Estimator(model_fn=model_fn)
classifier.train( classifier.train(
input_fn=self._make_input_fn(train_features, train_labels, True), input_fn=test_utils.make_input_fn(train_features, train_labels, True),
steps=4) steps=4)
test_features, test_labels = self._make_input_data(64) test_features, test_labels = test_utils.make_input_data(64, 2)
classifier.evaluate( classifier.evaluate(
input_fn=self._make_input_fn(test_features, test_labels, False), input_fn=test_utils.make_input_fn(test_features, test_labels, False),
steps=4) steps=4)
predict_features, predict_labels_ = self._make_input_data(64) predict_features, predict_labels_ = test_utils.make_input_data(64, 2)
classifier.predict( classifier.predict(
input_fn=self._make_input_fn(predict_features, predict_labels_, False)) input_fn=test_utils.make_input_fn(predict_features, predict_labels_,
False))
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -17,78 +17,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
# from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.estimators import multi_class_head from tensorflow_privacy.privacy.estimators import multi_class_head
from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
class DPMultiClassHeadTest(tf.test.TestCase): class DPMultiClassHeadTest(tf.test.TestCase):
"""Tests for DP-enabled heads.""" """Tests for DP-enabled multiclass heads."""
def _make_input_data(self, size):
"""Create raw input data."""
feature_a = np.random.normal(4, 1, (size))
feature_b = np.random.normal(5, 0.7, (size))
feature_c = np.random.normal(6, 2, (size))
noise = np.random.normal(0, 30, (size))
features = {
'feature_a': feature_a,
'feature_b': feature_b,
'feature_c': feature_c,
}
def label_fn(x):
if x < 110.0:
return 0
elif x < 140.0:
return 1
else:
return 2
labels_list = map(
label_fn,
np.power(feature_a, 3) + np.power(feature_b, 2) +
np.power(feature_c, 1) + noise)
return features, list(labels_list)
def _make_input_fn(self, features, labels, training, batch_size=16):
def input_fn():
"""An input function for training or evaluating."""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle if in training mode.
if training:
dataset = dataset.shuffle(1000)
return dataset.batch(batch_size)
return input_fn
def _make_model_fn(self, head, optimizer, feature_columns):
"""Constructs and returns a model_fn using DPBinaryClassHead."""
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
inputs = feature_layer(features)
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
hidden_layer_values = hidden_layer(inputs)
logits_layer = tf.keras.layers.Dense(
units=head.logits_dimension, activation=None)
logits = logits_layer(hidden_layer_values)
return head.create_estimator_spec(
features=features,
labels=labels,
mode=mode,
logits=logits,
trainable_variables=hidden_layer.trainable_weights +
logits_layer.trainable_weights,
optimizer=optimizer)
return model_fn
def testLoss(self): def testLoss(self):
"""Tests loss() returns per-example losses.""" """Tests loss() returns per-example losses."""
@ -118,7 +55,7 @@ class DPMultiClassHeadTest(tf.test.TestCase):
def testCreateTPUEstimatorSpec(self): def testCreateTPUEstimatorSpec(self):
"""Tests that an Estimator built with this head works.""" """Tests that an Estimator built with this head works."""
train_features, train_labels = self._make_input_data(256) train_features, train_labels = test_utils.make_input_data(256, 3)
feature_columns = [] feature_columns = []
for key in train_features: for key in train_features:
feature_columns.append(tf.feature_column.numeric_column(key=key)) feature_columns.append(tf.feature_column.numeric_column(key=key))
@ -129,23 +66,22 @@ class DPMultiClassHeadTest(tf.test.TestCase):
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=2) num_microbatches=2)
model_fn = self._make_model_fn(head, optimizer, feature_columns) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
classifier = tf.estimator.Estimator(model_fn=model_fn) classifier = tf.estimator.Estimator(model_fn=model_fn)
classifier.train( classifier.train(
input_fn=self._make_input_fn(train_features, train_labels, True), input_fn=test_utils.make_input_fn(train_features, train_labels, True),
steps=4) steps=4)
test_features, test_labels = self._make_input_data(64) test_features, test_labels = test_utils.make_input_data(64, 3)
classifier.evaluate( classifier.evaluate(
input_fn=self._make_input_fn(test_features, test_labels, False), input_fn=test_utils.make_input_fn(test_features, test_labels, False),
steps=4) steps=4)
predict_features, predict_labels_ = self._make_input_data(64) predict_features, predict_labels_ = test_utils.make_input_data(64, 3)
predictions = classifier.predict( classifier.predict(
input_fn=self._make_input_fn(predict_features, predict_labels_, False)) input_fn=test_utils.make_input_fn(predict_features, predict_labels_,
for p in predictions: False))
print('schien p: ', p)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -0,0 +1,93 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for unit tests for DP-enabled Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def make_input_data(size, classes):
  """Generates a random synthetic classification dataset for tests.

  Three Gaussian features are drawn and combined into a noisy polynomial
  score (feature_a**3 + feature_b**2 + feature_c + noise). The score is then
  thresholded to obtain integer class labels.

  Args:
    size: Output shape handed to `np.random.normal` for every draw, i.e. the
      number of examples when an int is given.
    classes: Number of label classes. The value 2 yields binary labels
      (score > 125); any other value yields the three-way labelling 0 / 1 / 2
      with cut points at scores 110 and 140.

  Returns:
    A `(features, labels)` tuple. `features` maps 'feature_a', 'feature_b'
    and 'feature_c' to NumPy arrays. `labels` is an integer NumPy array in
    the binary case and a list of Python ints otherwise.
  """
  # The draw order (a, b, c, noise) is kept fixed so that a seeded global RNG
  # always produces the same dataset.
  feat_a = np.random.normal(4, 1, size)
  feat_b = np.random.normal(5, 0.7, size)
  feat_c = np.random.normal(6, 2, size)
  jitter = np.random.normal(0, 30, size)

  features = dict(feature_a=feat_a, feature_b=feat_b, feature_c=feat_c)

  # Both labelling schemes threshold the same noisy polynomial score.
  score = (
      np.power(feat_a, 3) + np.power(feat_b, 2) + np.power(feat_c, 1) + jitter)

  if classes == 2:
    return features, (score > 125).astype(int)

  labels = [
      0 if value < 110.0 else (1 if value < 140.0 else 2) for value in score
  ]
  return features, labels
def make_input_fn(features, labels, training, batch_size=16):
  """Wraps in-memory data in an input function for an estimator.

  Args:
    features: Mapping from feature name to per-example values; it is copied
      into a plain dict before being sliced.
    labels: Per-example labels, sliced together with `features`.
    training: If truthy, the dataset is shuffled (buffer size 1000) before
      batching.
    batch_size: Number of examples per batch.

  Returns:
    A zero-argument callable that builds and returns the batched
    `tf.data.Dataset`.
  """

  def input_fn():
    """Builds the dataset of (features, labels) batches."""
    slices = (dict(features), labels)
    dataset = tf.data.Dataset.from_tensor_slices(slices)
    if not training:
      return dataset.batch(batch_size)
    # Only training data is shuffled; evaluation order stays unchanged.
    return dataset.shuffle(1000).batch(batch_size)

  return input_fn
def make_model_fn(head, optimizer, feature_columns):
  """Builds a model_fn for a small dense network driven by `head`.

  The network is a `DenseFeatures` input layer, one 3-unit ReLU hidden layer
  and a linear logits layer with `head.logits_dimension` units.

  Args:
    head: Object providing `logits_dimension` and `create_estimator_spec`.
    optimizer: Optimizer forwarded unchanged to `head.create_estimator_spec`.
    feature_columns: Feature columns consumed by the `DenseFeatures` layer.

  Returns:
    A `model_fn(features, labels, mode, params, config=None)` callable that
    returns whatever `head.create_estimator_spec` returns.
  """

  def model_fn(features, labels, mode, params, config=None):  # pylint: disable=unused-argument
    """Builds the network and delegates spec creation to the head."""
    dense_inputs = tf.keras.layers.DenseFeatures(feature_columns)(features)

    hidden = tf.keras.layers.Dense(units=3, activation='relu')
    hidden_out = hidden(dense_inputs)

    output_layer = tf.keras.layers.Dense(
        units=head.logits_dimension, activation=None)
    logits = output_layer(hidden_out)

    # Only the two Dense layers contribute trainable variables to the head.
    weights = hidden.trainable_weights + output_layer.trainable_weights
    return head.create_estimator_spec(
        features=features,
        labels=labels,
        mode=mode,
        logits=logits,
        trainable_variables=weights,
        optimizer=optimizer)

  return model_fn