From 134c898ded0e6ad351738421b27cd65bb0cb4060 Mon Sep 17 00:00:00 2001 From: Steve Chien Date: Wed, 26 Jul 2023 16:38:09 -0700 Subject: [PATCH] Add DP-SGD version of v1 LinearClassifier. PiperOrigin-RevId: 551350685 --- .../privacy/estimators/v1/BUILD | 22 ++++++ .../privacy/estimators/v1/dnn.py | 2 +- .../privacy/estimators/v1/linear.py | 72 ++++++++++++++++++ .../privacy/estimators/v1/linear_test.py | 76 +++++++++++++++++++ 4 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 tensorflow_privacy/privacy/estimators/v1/linear.py create mode 100644 tensorflow_privacy/privacy/estimators/v1/linear_test.py diff --git a/tensorflow_privacy/privacy/estimators/v1/BUILD b/tensorflow_privacy/privacy/estimators/v1/BUILD index 0522c75..7d9420e 100644 --- a/tensorflow_privacy/privacy/estimators/v1/BUILD +++ b/tensorflow_privacy/privacy/estimators/v1/BUILD @@ -26,6 +26,15 @@ py_library( deps = [":head"], ) +py_library( + name = "linear", + srcs = [ + "linear.py", + ], + srcs_version = "PY3", + deps = [":head"], +) + py_test( name = "head_test", timeout = "long", @@ -51,3 +60,16 @@ py_test( "//tensorflow_privacy/privacy/optimizers:dp_optimizer", ], ) + +py_test( + name = "linear_test", + timeout = "long", + srcs = ["linear_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":linear", + "//tensorflow_privacy/privacy/estimators:test_utils", + "//tensorflow_privacy/privacy/optimizers:dp_optimizer", + ], +) diff --git a/tensorflow_privacy/privacy/estimators/v1/dnn.py b/tensorflow_privacy/privacy/estimators/v1/dnn.py index c819593..bc57de5 100644 --- a/tensorflow_privacy/privacy/estimators/v1/dnn.py +++ b/tensorflow_privacy/privacy/estimators/v1/dnn.py @@ -38,7 +38,7 @@ class DNNClassifier(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, - loss_reduction=tf.compat.v1.losses.Reduction.SUM, + loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary. 
batch_norm=False, ): """See `tf.compat.v1.estimator.DNNClassifier`."""
diff --git a/tensorflow_privacy/privacy/estimators/v1/linear.py b/tensorflow_privacy/privacy/estimators/v1/linear.py
new file mode 100644
index 0000000..d62279f
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/linear.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DP version of LinearClassifier v1."""
+
+import tensorflow as tf
+from tensorflow_privacy.privacy.estimators.v1 import head as head_lib
+from tensorflow_estimator.python.estimator import estimator  # pylint: disable=g-deprecated-tf-checker
+from tensorflow_estimator.python.estimator.canned import linear  # pylint: disable=g-deprecated-tf-checker
+
+
+class LinearClassifier(estimator.Estimator):
+  """DP version of `tf.compat.v1.estimator.LinearClassifier`."""
+
+  def __init__(
+      self,
+      feature_columns,
+      model_dir=None,
+      n_classes=2,
+      weight_column=None,
+      label_vocabulary=None,
+      optimizer='Ftrl',
+      config=None,
+      partitioner=None,
+      warm_start_from=None,
+      loss_reduction=tf.compat.v1.losses.Reduction.SUM,  # For scalar summary.
+      sparse_combiner='sum',
+  ):
+    """See `tf.compat.v1.estimator.LinearClassifier`."""
+    linear._validate_linear_sdca_optimizer_for_linear_classifier(  # pylint: disable=protected-access
+        feature_columns=feature_columns,
+        n_classes=n_classes,
+        optimizer=optimizer,
+        sparse_combiner=sparse_combiner,
+    )
+    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear')  # pylint: disable=protected-access
+
+    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
+        n_classes, weight_column, label_vocabulary, loss_reduction
+    )
+
+    def _model_fn(features, labels, mode, config):
+      """Runs the canned linear model_fn with the tensorflow_privacy head."""
+      return linear._linear_model_fn(  # pylint: disable=protected-access
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          feature_columns=tuple(feature_columns or []),
+          optimizer=optimizer,
+          partitioner=partitioner,
+          config=config,
+          sparse_combiner=sparse_combiner,
+      )
+
+    super().__init__(
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        warm_start_from=warm_start_from,
+    )
diff --git a/tensorflow_privacy/privacy/estimators/v1/linear_test.py b/tensorflow_privacy/privacy/estimators/v1/linear_test.py
new file mode 100644
index 0000000..adf9586
--- /dev/null
+++ b/tensorflow_privacy/privacy/estimators/v1/linear_test.py
@@ -0,0 +1,76 @@
+# Copyright 2023, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for LinearClassifier."""
+
+import functools
+
+from absl.testing import parameterized
+import tensorflow as tf
+from tensorflow_privacy.privacy.estimators import test_utils
+from tensorflow_privacy.privacy.estimators.v1 import linear
+from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
+
+
+class DPLinearClassifierClassifierTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+  """Tests for DP-enabled LinearClassifier."""
+
+  @parameterized.named_parameters(
+      ('BinaryClassLinear', 2),
+      ('MultiClassLinear 3', 3),
+      ('MultiClassLinear 4', 4),
+  )
+  def testLinearClassifier(self, n_classes):
+    """Smoke-tests train, evaluate and predict using a DP-SGD optimizer."""
+    train_features, train_labels = test_utils.make_input_data(256, n_classes)
+    feature_columns = []
+    for key in train_features:
+      feature_columns.append(tf.feature_column.numeric_column(key=key))  # pylint: disable=g-deprecated-tf-checker
+
+    optimizer = functools.partial(
+        DPGradientDescentGaussianOptimizer,
+        learning_rate=0.5,
+        l2_norm_clip=1.0,
+        noise_multiplier=0.0,
+        num_microbatches=1,
+    )
+
+    classifier = linear.LinearClassifier(
+        feature_columns=feature_columns,
+        n_classes=n_classes,
+        optimizer=optimizer,
+        loss_reduction=tf.compat.v1.losses.Reduction.SUM,
+    )
+
+    train_input_fn = test_utils.make_input_fn(
+        train_features, train_labels, True, 16
+    )
+    classifier.train(input_fn=train_input_fn)
+
+    test_features, test_labels = test_utils.make_input_data(64, n_classes)
+    classifier.evaluate(
+        input_fn=test_utils.make_input_fn(test_features, test_labels, False, 16)
+    )
+
+    predict_features, predict_labels = test_utils.make_input_data(64, n_classes)
+    predict_input_fn = test_utils.make_input_fn(
+        predict_features, predict_labels, False
+    )
+    # `predict` returns a lazy generator; drain it so inference actually runs.
+    list(classifier.predict(input_fn=predict_input_fn))
+
+
+if __name__ == '__main__':
+  tf.test.main()