diff --git a/tensorflow_privacy/privacy/estimators/BUILD b/tensorflow_privacy/privacy/estimators/BUILD index 73b0f5c..31cec22 100644 --- a/tensorflow_privacy/privacy/estimators/BUILD +++ b/tensorflow_privacy/privacy/estimators/BUILD @@ -37,6 +37,18 @@ py_library( ], ) +py_library( + name = "dnn", + srcs = [ + "dnn.py", + ], + deps = [ + ":head_utils", + "//third_party/py/tensorflow", + "//third_party/tensorflow_estimator", + ], +) + py_library( name = "test_utils", srcs = [ @@ -72,3 +84,17 @@ py_test( "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", ], ) + +py_test( + name = "dnn_test", + timeout = "long", + srcs = ["dnn_test.py"], + python_version = "PY3", + deps = [ + ":dnn", + ":test_utils", + "//third_party/py/absl/testing:parameterized", + "//third_party/py/tensorflow", + "//third_party/py/tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + ], +) diff --git a/tensorflow_privacy/privacy/estimators/dnn.py b/tensorflow_privacy/privacy/estimators/dnn.py new file mode 100644 index 0000000..cac6034 --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/dnn.py @@ -0,0 +1,71 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Estimator heads that allow integration with TF Privacy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.estimators import head_utils +from tensorflow_estimator.python.estimator import estimator +from tensorflow_estimator.python.estimator.canned import dnn + + +class DNNClassifier(tf.estimator.Estimator): + """DP version of DNNClassifier.""" + + def __init__( + self, + hidden_units, + feature_columns, + model_dir=None, + n_classes=2, + weight_column=None, + label_vocabulary=None, + optimizer=None, + activation_fn=tf.nn.relu, + dropout=None, + config=None, + warm_start_from=None, + loss_reduction=tf.keras.losses.Reduction.NONE, + batch_norm=False, + ): + head = head_utils.binary_or_multi_class_head( + n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN') + + def _model_fn(features, labels, mode, config): + return dnn.dnn_model_fn_v2( + features=features, + labels=labels, + mode=mode, + head=head, + hidden_units=hidden_units, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + activation_fn=activation_fn, + dropout=dropout, + config=config, + batch_norm=batch_norm) + + super(DNNClassifier, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow_privacy/privacy/estimators/dnn_test.py b/tensorflow_privacy/privacy/estimators/dnn_test.py new file mode 100644 index 0000000..bb7c1ad --- /dev/null +++ b/tensorflow_privacy/privacy/estimators/dnn_test.py @@ -0,0 +1,71 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for DP-enabled binary class heads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.estimators import dnn +from tensorflow_privacy.privacy.estimators import test_utils +from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer + + +class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase): + """Tests for DP-enabled DNNClassifier.""" + + @parameterized.named_parameters( + ('BinaryClassDNN 1', 2), + ('MultiClassDNN 1', 3), + ) + def testDNN(self, classes): + train_features, train_labels = test_utils.make_input_data(256, classes) + feature_columns = [] + for key in train_features: + feature_columns.append(tf.feature_column.numeric_column(key=key)) + + optimizer = functools.partial( + DPKerasSGDOptimizer, + learning_rate=0.5, + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=1) + + classifier = dnn.DNNClassifier( + hidden_units=[10], + activation_fn='relu', + feature_columns=feature_columns, + n_classes=classes, + optimizer=optimizer, + loss_reduction=tf.losses.Reduction.NONE) + + classifier.train( + input_fn=test_utils.make_input_fn(train_features, train_labels, True, + 16)) + + test_features, test_labels = test_utils.make_input_data(64, classes) + classifier.evaluate( + input_fn=test_utils.make_input_fn(test_features, test_labels, False, + 16)) + + predict_features, predict_labels = test_utils.make_input_data(64, classes) + classifier.predict( + input_fn=test_utils.make_input_fn(predict_features, predict_labels, + False)) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py index 957cbf8..2eca0da 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py @@ -78,9 +78,9 @@ class DPMultiClassHeadTest(tf.test.TestCase): input_fn=test_utils.make_input_fn(test_features, test_labels, False), steps=4) - predict_features, predict_labels_ = test_utils.make_input_data(64, 3) + predict_features, predict_labels = test_utils.make_input_data(64, 3) classifier.predict( - input_fn=test_utils.make_input_fn(predict_features, predict_labels_, + input_fn=test_utils.make_input_fn(predict_features, predict_labels, False))