Add DP-SGD version of v1 LinearClassifier.
PiperOrigin-RevId: 551350685
This commit is contained in:
parent
225355258c
commit
134c898ded
4 changed files with 171 additions and 1 deletions
|
@ -26,6 +26,15 @@ py_library(
|
|||
deps = [":head"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear",
|
||||
srcs = [
|
||||
"linear.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [":head"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "head_test",
|
||||
timeout = "long",
|
||||
|
@ -51,3 +60,16 @@ py_test(
|
|||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "linear_test",
|
||||
timeout = "long",
|
||||
srcs = ["linear_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":linear",
|
||||
"//tensorflow_privacy/privacy/estimators:test_utils",
|
||||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -38,7 +38,7 @@ class DNNClassifier(estimator.Estimator):
|
|||
input_layer_partitioner=None,
|
||||
config=None,
|
||||
warm_start_from=None,
|
||||
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
|
||||
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
|
||||
batch_norm=False,
|
||||
):
|
||||
"""See `tf.compat.v1.estimator.DNNClassifier`."""
|
||||
|
|
72
tensorflow_privacy/privacy/estimators/v1/linear.py
Normal file
72
tensorflow_privacy/privacy/estimators/v1/linear.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""DP version of LinearClassifier v1."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.estimators.v1 import head as head_lib
|
||||
from tensorflow_estimator.python.estimator import estimator # pylint: disable=g-deprecated-tf-checker
|
||||
from tensorflow_estimator.python.estimator.canned import linear # pylint: disable=g-deprecated-tf-checker
|
||||
|
||||
|
||||
class LinearClassifier(estimator.Estimator):
|
||||
"""DP version of `tf.compat.v1.estimator.LinearClassifier`."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_columns,
|
||||
model_dir=None,
|
||||
n_classes=2,
|
||||
weight_column=None,
|
||||
label_vocabulary=None,
|
||||
optimizer='Ftrl',
|
||||
config=None,
|
||||
partitioner=None,
|
||||
warm_start_from=None,
|
||||
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
|
||||
sparse_combiner='sum',
|
||||
):
|
||||
"""See `tf.compat.v1.estimator.LinearClassifier`."""
|
||||
linear._validate_linear_sdca_optimizer_for_linear_classifier( # pylint: disable=protected-access
|
||||
feature_columns=feature_columns,
|
||||
n_classes=n_classes,
|
||||
optimizer=optimizer,
|
||||
sparse_combiner=sparse_combiner,
|
||||
)
|
||||
estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear') # pylint: disable=protected-access
|
||||
|
||||
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
|
||||
n_classes, weight_column, label_vocabulary, loss_reduction
|
||||
)
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
"""Call the defined shared _linear_model_fn."""
|
||||
return linear._linear_model_fn( # pylint: disable=protected-access
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head,
|
||||
feature_columns=tuple(feature_columns or []),
|
||||
optimizer=optimizer,
|
||||
partitioner=partitioner,
|
||||
config=config,
|
||||
sparse_combiner=sparse_combiner,
|
||||
)
|
||||
|
||||
super(LinearClassifier, self).__init__(
|
||||
model_fn=_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
warm_start_from=warm_start_from,
|
||||
)
|
76
tensorflow_privacy/privacy/estimators/v1/linear_test.py
Normal file
76
tensorflow_privacy/privacy/estimators/v1/linear_test.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2023, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for LinearClassifier."""
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.estimators.v1 import linear
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||
|
||||
|
||||
class DPLinearClassifierClassifierTest(
|
||||
tf.test.TestCase, parameterized.TestCase
|
||||
):
|
||||
"""Tests for DP-enabled LinearClassifier."""
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('BinaryClassLinear', 2),
|
||||
('MultiClassLinear 3', 3),
|
||||
('MultiClassLinear 4', 4),
|
||||
)
|
||||
def testLinearClassifier(self, n_classes):
|
||||
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
||||
feature_columns = []
|
||||
for key in train_features:
|
||||
feature_columns.append(tf.feature_column.numeric_column(key=key)) # pylint: disable=g-deprecated-tf-checker
|
||||
|
||||
optimizer = functools.partial(
|
||||
DPGradientDescentGaussianOptimizer,
|
||||
learning_rate=0.5,
|
||||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
)
|
||||
|
||||
classifier = linear.LinearClassifier(
|
||||
feature_columns=feature_columns,
|
||||
n_classes=n_classes,
|
||||
optimizer=optimizer,
|
||||
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
|
||||
)
|
||||
|
||||
classifier.train(
|
||||
input_fn=test_utils.make_input_fn(
|
||||
train_features, train_labels, True, 16
|
||||
)
|
||||
)
|
||||
|
||||
test_features, test_labels = test_utils.make_input_data(64, n_classes)
|
||||
classifier.evaluate(
|
||||
input_fn=test_utils.make_input_fn(test_features, test_labels, False, 16)
|
||||
)
|
||||
|
||||
predict_features, predict_labels = test_utils.make_input_data(64, n_classes)
|
||||
classifier.predict(
|
||||
input_fn=test_utils.make_input_fn(
|
||||
predict_features, predict_labels, False
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue