Add DP-SGD version of v1 LinearClassifier.
PiperOrigin-RevId: 551350685
This commit is contained in:
parent
225355258c
commit
134c898ded
4 changed files with 171 additions and 1 deletions
|
@ -26,6 +26,15 @@ py_library(
|
||||||
deps = [":head"],
|
deps = [":head"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "linear",
|
||||||
|
srcs = [
|
||||||
|
"linear.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":head"],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "head_test",
|
name = "head_test",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
|
@ -51,3 +60,16 @@ py_test(
|
||||||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
|
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "linear_test",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["linear_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":linear",
|
||||||
|
"//tensorflow_privacy/privacy/estimators:test_utils",
|
||||||
|
"//tensorflow_privacy/privacy/optimizers:dp_optimizer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -38,7 +38,7 @@ class DNNClassifier(estimator.Estimator):
|
||||||
input_layer_partitioner=None,
|
input_layer_partitioner=None,
|
||||||
config=None,
|
config=None,
|
||||||
warm_start_from=None,
|
warm_start_from=None,
|
||||||
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
|
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
|
||||||
batch_norm=False,
|
batch_norm=False,
|
||||||
):
|
):
|
||||||
"""See `tf.compat.v1.estimator.DNNClassifier`."""
|
"""See `tf.compat.v1.estimator.DNNClassifier`."""
|
||||||
|
|
72
tensorflow_privacy/privacy/estimators/v1/linear.py
Normal file
72
tensorflow_privacy/privacy/estimators/v1/linear.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""DP version of LinearClassifier v1."""
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.estimators.v1 import head as head_lib
|
||||||
|
from tensorflow_estimator.python.estimator import estimator # pylint: disable=g-deprecated-tf-checker
|
||||||
|
from tensorflow_estimator.python.estimator.canned import linear # pylint: disable=g-deprecated-tf-checker
|
||||||
|
|
||||||
|
|
||||||
|
class LinearClassifier(estimator.Estimator):
|
||||||
|
"""DP version of `tf.compat.v1.estimator.LinearClassifier`."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_columns,
|
||||||
|
model_dir=None,
|
||||||
|
n_classes=2,
|
||||||
|
weight_column=None,
|
||||||
|
label_vocabulary=None,
|
||||||
|
optimizer='Ftrl',
|
||||||
|
config=None,
|
||||||
|
partitioner=None,
|
||||||
|
warm_start_from=None,
|
||||||
|
loss_reduction=tf.compat.v1.losses.Reduction.SUM, # For scalar summary.
|
||||||
|
sparse_combiner='sum',
|
||||||
|
):
|
||||||
|
"""See `tf.compat.v1.estimator.LinearClassifier`."""
|
||||||
|
linear._validate_linear_sdca_optimizer_for_linear_classifier( # pylint: disable=protected-access
|
||||||
|
feature_columns=feature_columns,
|
||||||
|
n_classes=n_classes,
|
||||||
|
optimizer=optimizer,
|
||||||
|
sparse_combiner=sparse_combiner,
|
||||||
|
)
|
||||||
|
estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear') # pylint: disable=protected-access
|
||||||
|
|
||||||
|
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
|
||||||
|
n_classes, weight_column, label_vocabulary, loss_reduction
|
||||||
|
)
|
||||||
|
|
||||||
|
def _model_fn(features, labels, mode, config):
|
||||||
|
"""Call the defined shared _linear_model_fn."""
|
||||||
|
return linear._linear_model_fn( # pylint: disable=protected-access
|
||||||
|
features=features,
|
||||||
|
labels=labels,
|
||||||
|
mode=mode,
|
||||||
|
head=head,
|
||||||
|
feature_columns=tuple(feature_columns or []),
|
||||||
|
optimizer=optimizer,
|
||||||
|
partitioner=partitioner,
|
||||||
|
config=config,
|
||||||
|
sparse_combiner=sparse_combiner,
|
||||||
|
)
|
||||||
|
|
||||||
|
super(LinearClassifier, self).__init__(
|
||||||
|
model_fn=_model_fn,
|
||||||
|
model_dir=model_dir,
|
||||||
|
config=config,
|
||||||
|
warm_start_from=warm_start_from,
|
||||||
|
)
|
76
tensorflow_privacy/privacy/estimators/v1/linear_test.py
Normal file
76
tensorflow_privacy/privacy/estimators/v1/linear_test.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2023, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for LinearClassifier."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow_privacy.privacy.estimators import test_utils
|
||||||
|
from tensorflow_privacy.privacy.estimators.v1 import linear
|
||||||
|
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||||
|
|
||||||
|
|
||||||
|
class DPLinearClassifierClassifierTest(
|
||||||
|
tf.test.TestCase, parameterized.TestCase
|
||||||
|
):
|
||||||
|
"""Tests for DP-enabled LinearClassifier."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('BinaryClassLinear', 2),
|
||||||
|
('MultiClassLinear 3', 3),
|
||||||
|
('MultiClassLinear 4', 4),
|
||||||
|
)
|
||||||
|
def testLinearClassifier(self, n_classes):
|
||||||
|
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
||||||
|
feature_columns = []
|
||||||
|
for key in train_features:
|
||||||
|
feature_columns.append(tf.feature_column.numeric_column(key=key)) # pylint: disable=g-deprecated-tf-checker
|
||||||
|
|
||||||
|
optimizer = functools.partial(
|
||||||
|
DPGradientDescentGaussianOptimizer,
|
||||||
|
learning_rate=0.5,
|
||||||
|
l2_norm_clip=1.0,
|
||||||
|
noise_multiplier=0.0,
|
||||||
|
num_microbatches=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
classifier = linear.LinearClassifier(
|
||||||
|
feature_columns=feature_columns,
|
||||||
|
n_classes=n_classes,
|
||||||
|
optimizer=optimizer,
|
||||||
|
loss_reduction=tf.compat.v1.losses.Reduction.SUM,
|
||||||
|
)
|
||||||
|
|
||||||
|
classifier.train(
|
||||||
|
input_fn=test_utils.make_input_fn(
|
||||||
|
train_features, train_labels, True, 16
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_features, test_labels = test_utils.make_input_data(64, n_classes)
|
||||||
|
classifier.evaluate(
|
||||||
|
input_fn=test_utils.make_input_fn(test_features, test_labels, False, 16)
|
||||||
|
)
|
||||||
|
|
||||||
|
predict_features, predict_labels = test_utils.make_input_data(64, n_classes)
|
||||||
|
classifier.predict(
|
||||||
|
input_fn=test_utils.make_input_fn(
|
||||||
|
predict_features, predict_labels, False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in a new issue