Add DP-SGD version of v1 LinearClassifier.

PiperOrigin-RevId: 551350685
This commit is contained in:
Steve Chien 2023-07-26 16:38:09 -07:00 committed by A. Unique TensorFlower
parent 225355258c
commit 134c898ded
4 changed files with 171 additions and 1 deletions

View file

@ -26,6 +26,15 @@ py_library(
deps = [":head"],
)
py_library(
name = "linear",
srcs = [
"linear.py",
],
srcs_version = "PY3",
deps = [":head"],
)
py_test(
name = "head_test",
timeout = "long",
@ -51,3 +60,16 @@ py_test(
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
],
)
py_test(
name = "linear_test",
timeout = "long",
srcs = ["linear_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":linear",
"//tensorflow_privacy/privacy/estimators:test_utils",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
],
)

View file

@ -38,7 +38,7 @@ class DNNClassifier(estimator.Estimator):
input_layer_partitioner=None,
config=None,
warm_start_from=None,
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
batch_norm=False,
):
"""See `tf.compat.v1.estimator.DNNClassifier`."""

View file

@ -0,0 +1,72 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DP version of LinearClassifier v1."""
import tensorflow as tf
from tensorflow_privacy.privacy.estimators.v1 import head as head_lib
from tensorflow_estimator.python.estimator import estimator # pylint: disable=g-deprecated-tf-checker
from tensorflow_estimator.python.estimator.canned import linear # pylint: disable=g-deprecated-tf-checker
class LinearClassifier(estimator.Estimator):
"""DP version of `tf.compat.v1.estimator.LinearClassifier`."""
def __init__(
self,
feature_columns,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Ftrl',
config=None,
partitioner=None,
warm_start_from=None,
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
sparse_combiner='sum',
):
"""See `tf.compat.v1.estimator.LinearClassifier`."""
linear._validate_linear_sdca_optimizer_for_linear_classifier( # pylint: disable=protected-access
feature_columns=feature_columns,
n_classes=n_classes,
optimizer=optimizer,
sparse_combiner=sparse_combiner,
)
estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear') # pylint: disable=protected-access
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction
)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _linear_model_fn."""
return linear._linear_model_fn( # pylint: disable=protected-access
features=features,
labels=labels,
mode=mode,
head=head,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
config=config,
sparse_combiner=sparse_combiner,
)
super(LinearClassifier, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from,
)

View file

@ -0,0 +1,76 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for LinearClassifier."""
import functools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.estimators.v1 import linear
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
class DPLinearClassifierClassifierTest(
tf.test.TestCase, parameterized.TestCase
):
"""Tests for DP-enabled LinearClassifier."""
@parameterized.named_parameters(
('BinaryClassLinear', 2),
('MultiClassLinear 3', 3),
('MultiClassLinear 4', 4),
)
def testLinearClassifier(self, n_classes):
train_features, train_labels = test_utils.make_input_data(256, n_classes)
feature_columns = []
for key in train_features:
feature_columns.append(tf.feature_column.numeric_column(key=key)) # pylint: disable=g-deprecated-tf-checker
optimizer = functools.partial(
DPGradientDescentGaussianOptimizer,
learning_rate=0.5,
l2_norm_clip=1.0,
noise_multiplier=0.0,
num_microbatches=1,
)
classifier = linear.LinearClassifier(
feature_columns=feature_columns,
n_classes=n_classes,
optimizer=optimizer,
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
)
classifier.train(
input_fn=test_utils.make_input_fn(
train_features, train_labels, True, 16
)
)
test_features, test_labels = test_utils.make_input_data(64, n_classes)
classifier.evaluate(
input_fn=test_utils.make_input_fn(test_features, test_labels, False, 16)
)
predict_features, predict_labels = test_utils.make_input_data(64, n_classes)
classifier.predict(
input_fn=test_utils.make_input_fn(
predict_features, predict_labels, False
)
)
if __name__ == '__main__':
tf.test.main()