Add DP-enabled version of DNNClassifier.

PiperOrigin-RevId: 326482309
This commit is contained in:
Steve Chien 2020-08-13 11:03:07 -07:00 committed by A. Unique TensorFlower
parent 3240a71965
commit d72e3400b7
4 changed files with 170 additions and 2 deletions

View file

@ -37,6 +37,18 @@ py_library(
],
)
py_library(
name = "dnn",
srcs = [
"dnn.py",
],
deps = [
":head_utils",
"//third_party/py/tensorflow",
"//third_party/tensorflow_estimator",
],
)
py_library(
name = "test_utils",
srcs = [
@ -72,3 +84,17 @@ py_test(
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
],
)
py_test(
name = "dnn_test",
timeout = "long",
srcs = ["dnn_test.py"],
python_version = "PY3",
deps = [
":dnn",
":test_utils",
"//third_party/py/absl/testing:parameterized",
"//third_party/py/tensorflow",
"//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
],
)

View file

@ -0,0 +1,71 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Estimator heads that allow integration with TF Privacy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow_privacy.privacy.estimators import head_utils
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator.canned import dnn
class DNNClassifier(tf.estimator.Estimator):
"""DP version of DNNClassifier."""
def __init__(
self,
hidden_units,
feature_columns,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer=None,
activation_fn=tf.nn.relu,
dropout=None,
config=None,
warm_start_from=None,
loss_reduction=tf.keras.losses.Reduction.NONE,
batch_norm=False,
):
head = head_utils.binary_or_multi_class_head(
n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction)
estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')
def _model_fn(features, labels, mode, config):
return dnn.dnn_model_fn_v2(
features=features,
labels=labels,
mode=mode,
head=head,
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
activation_fn=activation_fn,
dropout=dropout,
config=config,
batch_norm=batch_norm)
super(DNNClassifier, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)

View file

@ -0,0 +1,71 @@
# Copyright 2020, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for DP-enabled binary class heads."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.estimators import dnn
from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for DP-enabled DNNClassifier."""
@parameterized.named_parameters(
('BinaryClassDNN 1', 2),
('MultiClassDNN 1', 3),
)
def testDNN(self, classes):
train_features, train_labels = test_utils.make_input_data(256, classes)
feature_columns = []
for key in train_features:
feature_columns.append(tf.feature_column.numeric_column(key=key))
optimizer = functools.partial(
DPKerasSGDOptimizer,
learning_rate=0.5,
l2_norm_clip=1.0,
noise_multiplier=0.0,
num_microbatches=1)
classifier = dnn.DNNClassifier(
hidden_units=[10],
activation_fn='relu',
feature_columns=feature_columns,
n_classes=classes,
optimizer=optimizer,
loss_reduction=tf.losses.Reduction.NONE)
classifier.train(
input_fn=test_utils.make_input_fn(train_features, train_labels, True,
16))
test_features, test_labels = test_utils.make_input_data(64, classes)
classifier.evaluate(
input_fn=test_utils.make_input_fn(test_features, test_labels, False,
16))
predict_features, predict_labels = test_utils.make_input_data(64, classes)
classifier.predict(
input_fn=test_utils.make_input_fn(predict_features, predict_labels,
False))
if __name__ == '__main__':
tf.test.main()

View file

@ -78,9 +78,9 @@ class DPMultiClassHeadTest(tf.test.TestCase):
input_fn=test_utils.make_input_fn(test_features, test_labels, False),
steps=4)
predict_features, predict_labels_ = test_utils.make_input_data(64, 3)
predict_features, predict_labels = test_utils.make_input_data(64, 3)
classifier.predict(
input_fn=test_utils.make_input_fn(predict_features, predict_labels_,
input_fn=test_utils.make_input_fn(predict_features, predict_labels,
False))