forked from 626_privacy/tensorflow_privacy
Refactor of common functions in binary and multiclass heads.
PiperOrigin-RevId: 325957037
This commit is contained in:
parent
3a641e077e
commit
99afaed68e
4 changed files with 128 additions and 144 deletions
|
@ -20,8 +20,7 @@ py_library(
|
|||
],
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
# TODO(b/163395075): Remove this dependency once necessary function is public.
|
||||
"//third_party/tensorflow/python:keras_lib",
|
||||
"//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed.
|
||||
"//third_party/tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
@ -33,12 +32,21 @@ py_library(
|
|||
],
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
# TODO(b/163395075): Remove this dependency once necessary function is public.
|
||||
"//third_party/tensorflow/python:keras_lib",
|
||||
"//third_party/tensorflow/python:keras_lib", # TODO(b/163395075): Remove when fixed.
|
||||
"//third_party/tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_utils",
|
||||
srcs = [
|
||||
"test_utils.py",
|
||||
],
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "binary_class_head_test",
|
||||
timeout = "long",
|
||||
|
@ -46,8 +54,7 @@ py_test(
|
|||
python_version = "PY3",
|
||||
deps = [
|
||||
":binary_class_head",
|
||||
"//third_party/py/absl/testing:parameterized",
|
||||
"//third_party/py/six",
|
||||
":test_utils",
|
||||
"//third_party/py/tensorflow",
|
||||
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
],
|
||||
|
@ -60,8 +67,7 @@ py_test(
|
|||
python_version = "PY3",
|
||||
deps = [
|
||||
":multi_class_head",
|
||||
"//third_party/py/absl/testing:parameterized",
|
||||
"//third_party/py/six",
|
||||
":test_utils",
|
||||
"//third_party/py/tensorflow",
|
||||
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
],
|
||||
|
|
|
@ -21,64 +21,12 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.estimators import binary_class_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
||||
|
||||
class DPBinaryClassHeadTest(tf.test.TestCase):
|
||||
"""Tests for DP-enabled heads."""
|
||||
|
||||
def _make_input_data(self, size):
|
||||
"""Create raw input data."""
|
||||
feature_a = np.random.normal(4, 1, (size))
|
||||
feature_b = np.random.normal(5, 0.7, (size))
|
||||
feature_c = np.random.normal(6, 2, (size))
|
||||
noise = np.random.normal(0, 30, (size))
|
||||
features = {
|
||||
'feature_a': feature_a,
|
||||
'feature_b': feature_b,
|
||||
'feature_c': feature_c,
|
||||
}
|
||||
labels = np.array(
|
||||
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||
np.power(feature_c, 1) + noise > 125).astype(int)
|
||||
return features, labels
|
||||
|
||||
def _make_input_fn(self, features, labels, training, batch_size=16):
|
||||
|
||||
def input_fn():
|
||||
"""An input function for training or evaluating."""
|
||||
# Convert the inputs to a Dataset.
|
||||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||
|
||||
# Shuffle if in training mode.
|
||||
if training:
|
||||
dataset = dataset.shuffle(1000)
|
||||
|
||||
return dataset.batch(batch_size)
|
||||
|
||||
return input_fn
|
||||
|
||||
def _make_model_fn(self, head, optimizer, feature_columns):
|
||||
"""Constructs and returns a model_fn using DPBinaryClassHead."""
|
||||
|
||||
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
|
||||
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
|
||||
inputs = feature_layer(features)
|
||||
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
|
||||
hidden_layer_values = hidden_layer(inputs)
|
||||
logits_layer = tf.keras.layers.Dense(
|
||||
units=head.logits_dimension, activation=None)
|
||||
logits = logits_layer(hidden_layer_values)
|
||||
return head.create_estimator_spec(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
logits=logits,
|
||||
trainable_variables=hidden_layer.trainable_weights +
|
||||
logits_layer.trainable_weights,
|
||||
optimizer=optimizer)
|
||||
|
||||
return model_fn
|
||||
"""Tests for DP-enabled binary class heads."""
|
||||
|
||||
def testLoss(self):
|
||||
"""Tests loss() returns per-example losses."""
|
||||
|
@ -104,7 +52,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
|
|||
def testCreateTPUEstimatorSpec(self):
|
||||
"""Tests that an Estimator built with this head works."""
|
||||
|
||||
train_features, train_labels = self._make_input_data(256)
|
||||
train_features, train_labels = test_utils.make_input_data(256, 2)
|
||||
feature_columns = []
|
||||
for key in train_features:
|
||||
feature_columns.append(tf.feature_column.numeric_column(key=key))
|
||||
|
@ -115,21 +63,22 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
|
|||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = self._make_model_fn(head, optimizer, feature_columns)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=self._make_input_fn(train_features, train_labels, True),
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
steps=4)
|
||||
|
||||
test_features, test_labels = self._make_input_data(64)
|
||||
test_features, test_labels = test_utils.make_input_data(64, 2)
|
||||
classifier.evaluate(
|
||||
input_fn=self._make_input_fn(test_features, test_labels, False),
|
||||
input_fn=test_utils.make_input_fn(test_features, test_labels, False),
|
||||
steps=4)
|
||||
|
||||
predict_features, predict_labels_ = self._make_input_data(64)
|
||||
predict_features, predict_labels_ = test_utils.make_input_data(64, 2)
|
||||
classifier.predict(
|
||||
input_fn=self._make_input_fn(predict_features, predict_labels_, False))
|
||||
input_fn=test_utils.make_input_fn(predict_features, predict_labels_,
|
||||
False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -17,78 +17,15 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.estimators import multi_class_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
||||
|
||||
class DPMultiClassHeadTest(tf.test.TestCase):
|
||||
"""Tests for DP-enabled heads."""
|
||||
|
||||
def _make_input_data(self, size):
|
||||
"""Create raw input data."""
|
||||
feature_a = np.random.normal(4, 1, (size))
|
||||
feature_b = np.random.normal(5, 0.7, (size))
|
||||
feature_c = np.random.normal(6, 2, (size))
|
||||
noise = np.random.normal(0, 30, (size))
|
||||
features = {
|
||||
'feature_a': feature_a,
|
||||
'feature_b': feature_b,
|
||||
'feature_c': feature_c,
|
||||
}
|
||||
|
||||
def label_fn(x):
|
||||
if x < 110.0:
|
||||
return 0
|
||||
elif x < 140.0:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
labels_list = map(
|
||||
label_fn,
|
||||
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||
np.power(feature_c, 1) + noise)
|
||||
return features, list(labels_list)
|
||||
|
||||
def _make_input_fn(self, features, labels, training, batch_size=16):
|
||||
|
||||
def input_fn():
|
||||
"""An input function for training or evaluating."""
|
||||
# Convert the inputs to a Dataset.
|
||||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||
|
||||
# Shuffle if in training mode.
|
||||
if training:
|
||||
dataset = dataset.shuffle(1000)
|
||||
|
||||
return dataset.batch(batch_size)
|
||||
|
||||
return input_fn
|
||||
|
||||
def _make_model_fn(self, head, optimizer, feature_columns):
|
||||
"""Constructs and returns a model_fn using DPBinaryClassHead."""
|
||||
|
||||
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
|
||||
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
|
||||
inputs = feature_layer(features)
|
||||
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
|
||||
hidden_layer_values = hidden_layer(inputs)
|
||||
logits_layer = tf.keras.layers.Dense(
|
||||
units=head.logits_dimension, activation=None)
|
||||
logits = logits_layer(hidden_layer_values)
|
||||
return head.create_estimator_spec(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
logits=logits,
|
||||
trainable_variables=hidden_layer.trainable_weights +
|
||||
logits_layer.trainable_weights,
|
||||
optimizer=optimizer)
|
||||
|
||||
return model_fn
|
||||
"""Tests for DP-enabled multiclass heads."""
|
||||
|
||||
def testLoss(self):
|
||||
"""Tests loss() returns per-example losses."""
|
||||
|
@ -118,7 +55,7 @@ class DPMultiClassHeadTest(tf.test.TestCase):
|
|||
def testCreateTPUEstimatorSpec(self):
|
||||
"""Tests that an Estimator built with this head works."""
|
||||
|
||||
train_features, train_labels = self._make_input_data(256)
|
||||
train_features, train_labels = test_utils.make_input_data(256, 3)
|
||||
feature_columns = []
|
||||
for key in train_features:
|
||||
feature_columns.append(tf.feature_column.numeric_column(key=key))
|
||||
|
@ -129,23 +66,22 @@ class DPMultiClassHeadTest(tf.test.TestCase):
|
|||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = self._make_model_fn(head, optimizer, feature_columns)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=self._make_input_fn(train_features, train_labels, True),
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
steps=4)
|
||||
|
||||
test_features, test_labels = self._make_input_data(64)
|
||||
test_features, test_labels = test_utils.make_input_data(64, 3)
|
||||
classifier.evaluate(
|
||||
input_fn=self._make_input_fn(test_features, test_labels, False),
|
||||
input_fn=test_utils.make_input_fn(test_features, test_labels, False),
|
||||
steps=4)
|
||||
|
||||
predict_features, predict_labels_ = self._make_input_data(64)
|
||||
predictions = classifier.predict(
|
||||
input_fn=self._make_input_fn(predict_features, predict_labels_, False))
|
||||
for p in predictions:
|
||||
print('schien p: ', p)
|
||||
predict_features, predict_labels_ = test_utils.make_input_data(64, 3)
|
||||
classifier.predict(
|
||||
input_fn=test_utils.make_input_fn(predict_features, predict_labels_,
|
||||
False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
93
tensorflow_privacy/privacy/estimators/test_utils.py
Normal file
93
tensorflow_privacy/privacy/estimators/test_utils.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2020, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helper functions for unit tests for DP-enabled Estimators."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def make_input_data(size, classes):
|
||||
"""Create raw input data for testing."""
|
||||
feature_a = np.random.normal(4, 1, (size))
|
||||
feature_b = np.random.normal(5, 0.7, (size))
|
||||
feature_c = np.random.normal(6, 2, (size))
|
||||
noise = np.random.normal(0, 30, (size))
|
||||
features = {
|
||||
'feature_a': feature_a,
|
||||
'feature_b': feature_b,
|
||||
'feature_c': feature_c,
|
||||
}
|
||||
|
||||
if classes == 2:
|
||||
labels = np.array(
|
||||
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||
np.power(feature_c, 1) + noise > 125).astype(int)
|
||||
else:
|
||||
def label_fn(x):
|
||||
if x < 110.0:
|
||||
return 0
|
||||
elif x < 140.0:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
labels = list(map(
|
||||
label_fn,
|
||||
np.power(feature_a, 3) + np.power(feature_b, 2) +
|
||||
np.power(feature_c, 1) + noise))
|
||||
|
||||
return features, labels
|
||||
|
||||
|
||||
def make_input_fn(features, labels, training, batch_size=16):
|
||||
"""Returns an input function suitable for an estimator."""
|
||||
|
||||
def input_fn():
|
||||
"""An input function for training or evaluating."""
|
||||
# Convert the inputs to a Dataset.
|
||||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||
|
||||
# Shuffle if in training mode.
|
||||
if training:
|
||||
dataset = dataset.shuffle(1000)
|
||||
|
||||
return dataset.batch(batch_size)
|
||||
return input_fn
|
||||
|
||||
|
||||
def make_model_fn(head, optimizer, feature_columns):
|
||||
"""Constructs and returns a model_fn using supplied head."""
|
||||
|
||||
def model_fn(features, labels, mode, params, config=None): # pylint: disable=unused-argument
|
||||
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
|
||||
inputs = feature_layer(features)
|
||||
hidden_layer = tf.keras.layers.Dense(units=3, activation='relu')
|
||||
hidden_layer_values = hidden_layer(inputs)
|
||||
logits_layer = tf.keras.layers.Dense(
|
||||
units=head.logits_dimension, activation=None)
|
||||
logits = logits_layer(hidden_layer_values)
|
||||
return head.create_estimator_spec(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
logits=logits,
|
||||
trainable_variables=hidden_layer.trainable_weights +
|
||||
logits_layer.trainable_weights,
|
||||
optimizer=optimizer)
|
||||
|
||||
return model_fn
|
Loading…
Reference in a new issue