forked from 626_privacy/tensorflow_privacy
Fix bug in v1 estimators that was preventing use of microbatches.
PiperOrigin-RevId: 560765153
This commit is contained in:
parent
b4b47b1403
commit
372c934d14
3 changed files with 51 additions and 21 deletions
|
@ -20,16 +20,20 @@ from tensorflow_privacy.privacy.estimators import test_utils
|
||||||
from tensorflow_privacy.privacy.estimators.v1 import dnn
|
from tensorflow_privacy.privacy.estimators.v1 import dnn
|
||||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||||
|
|
||||||
|
# pylint: disable=g-deprecated-tf-checker
|
||||||
|
|
||||||
|
|
||||||
class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
"""Tests for DP-enabled DNNClassifier."""
|
"""Tests for DP-enabled DNNClassifier."""
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('BinaryClassDNN', 2),
|
('BinaryClassDNN', 2, 1),
|
||||||
('MultiClassDNN 3', 3),
|
('BinaryClassDNN 4', 2, 4),
|
||||||
('MultiClassDNN 4', 4),
|
('MultiClassDNN 3', 3, 1),
|
||||||
|
('MultiClassDNN 4', 4, 1),
|
||||||
|
('MultiClassDNN 4 4', 4, 4),
|
||||||
)
|
)
|
||||||
def testDNN(self, n_classes):
|
def testDNN(self, n_classes, num_microbatches):
|
||||||
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
||||||
feature_columns = []
|
feature_columns = []
|
||||||
for key in train_features:
|
for key in train_features:
|
||||||
|
@ -40,7 +44,8 @@ class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
learning_rate=0.5,
|
learning_rate=0.5,
|
||||||
l2_norm_clip=1.0,
|
l2_norm_clip=1.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=1)
|
num_microbatches=num_microbatches,
|
||||||
|
)
|
||||||
|
|
||||||
classifier = dnn.DNNClassifier(
|
classifier = dnn.DNNClassifier(
|
||||||
hidden_units=[10],
|
hidden_units=[10],
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.python.ops import lookup_ops # pylint: disable=g-direct-tensorflow-import
|
from tensorflow.python.ops import lookup_ops # pylint: disable=g-direct-tensorflow-import
|
||||||
|
# pylint: disable=g-deprecated-tf-checker
|
||||||
from tensorflow_estimator.python.estimator import model_fn
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
from tensorflow_estimator.python.estimator.canned import head as head_lib
|
from tensorflow_estimator.python.estimator.canned import head as head_lib
|
||||||
from tensorflow_estimator.python.estimator.canned import metric_keys
|
from tensorflow_estimator.python.estimator.canned import metric_keys
|
||||||
|
@ -23,6 +24,7 @@ from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
from tensorflow_estimator.python.estimator.export import export_output
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||||
|
|
||||||
|
|
||||||
# Collect together all protected access items needed from base head.
|
# Collect together all protected access items needed from base head.
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
_DEFAULT_SERVING_KEY = head_lib._DEFAULT_SERVING_KEY
|
_DEFAULT_SERVING_KEY = head_lib._DEFAULT_SERVING_KEY
|
||||||
|
@ -39,8 +41,12 @@ _create_eval_metrics_tuple = head_lib._create_eval_metrics_tuple
|
||||||
_summary_key = head_lib._summary_key
|
_summary_key = head_lib._summary_key
|
||||||
_validate_loss_fn_args = head_lib._validate_loss_fn_args
|
_validate_loss_fn_args = head_lib._validate_loss_fn_args
|
||||||
|
|
||||||
_BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss = head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss
|
_BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss = (
|
||||||
_BaseMultiClassHeadWithSoftmaxCrossEntropyLoss = head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss
|
head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss
|
||||||
|
)
|
||||||
|
_BaseMultiClassHeadWithSoftmaxCrossEntropyLoss = (
|
||||||
|
head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss
|
||||||
|
)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,25 +152,33 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(
|
||||||
classifier_output = _classification_output(
|
classifier_output = _classification_output(
|
||||||
scores=probabilities,
|
scores=probabilities,
|
||||||
n_classes=self._n_classes,
|
n_classes=self._n_classes,
|
||||||
label_vocabulary=self._label_vocabulary)
|
label_vocabulary=self._label_vocabulary,
|
||||||
|
)
|
||||||
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
|
||||||
mode=ModeKeys.PREDICT,
|
mode=ModeKeys.PREDICT,
|
||||||
predictions=predictions,
|
predictions=predictions,
|
||||||
export_outputs={
|
export_outputs={
|
||||||
_DEFAULT_SERVING_KEY: classifier_output,
|
_DEFAULT_SERVING_KEY: classifier_output,
|
||||||
_CLASSIFY_SERVING_KEY: classifier_output,
|
_CLASSIFY_SERVING_KEY: classifier_output,
|
||||||
_PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
|
_PREDICT_SERVING_KEY: export_output.PredictOutput(predictions),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
training_loss, unreduced_loss, weights, label_ids = self.create_loss(
|
training_loss, unreduced_loss, weights, label_ids = self.create_loss(
|
||||||
features=features, mode=mode, logits=logits, labels=labels)
|
features=features, mode=mode, logits=logits, labels=labels
|
||||||
|
)
|
||||||
if regularization_losses:
|
if regularization_losses:
|
||||||
regularization_loss = tf.math.add_n(regularization_losses)
|
regularization_loss = tf.math.add_n(regularization_losses)
|
||||||
regularized_training_loss = tf.math.add_n(
|
regularized_training_loss = tf.math.add_n(
|
||||||
[training_loss, regularization_loss])
|
[training_loss, regularization_loss]
|
||||||
|
)
|
||||||
|
unreduced_regularized_training_loss = tf.math.add(
|
||||||
|
unreduced_loss, regularization_loss
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
regularization_loss = None
|
regularization_loss = None
|
||||||
regularized_training_loss = training_loss
|
regularized_training_loss = training_loss
|
||||||
|
unreduced_regularized_training_loss = unreduced_loss
|
||||||
|
|
||||||
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
|
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
|
||||||
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
||||||
|
@ -191,8 +205,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(
|
||||||
if train_op_fn is not None:
|
if train_op_fn is not None:
|
||||||
raise ValueError('train_op_fn and optimizer cannot both be set.')
|
raise ValueError('train_op_fn and optimizer cannot both be set.')
|
||||||
train_op = optimizer.minimize(
|
train_op = optimizer.minimize(
|
||||||
regularized_training_loss,
|
# regularized_training_loss,
|
||||||
global_step=tf.compat.v1.train.get_global_step())
|
unreduced_regularized_training_loss,
|
||||||
|
global_step=tf.compat.v1.train.get_global_step(),
|
||||||
|
)
|
||||||
elif train_op_fn is not None:
|
elif train_op_fn is not None:
|
||||||
train_op = train_op_fn(regularized_training_loss)
|
train_op = train_op_fn(regularized_training_loss)
|
||||||
else:
|
else:
|
||||||
|
@ -352,9 +368,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
|
||||||
regularization_loss = tf.math.add_n(regularization_losses)
|
regularization_loss = tf.math.add_n(regularization_losses)
|
||||||
regularized_training_loss = tf.math.add_n(
|
regularized_training_loss = tf.math.add_n(
|
||||||
[training_loss, regularization_loss])
|
[training_loss, regularization_loss])
|
||||||
|
unreduced_regularized_training_loss = tf.math.add(
|
||||||
|
unreduced_loss, regularization_loss
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
regularization_loss = None
|
regularization_loss = None
|
||||||
regularized_training_loss = training_loss
|
regularized_training_loss = training_loss
|
||||||
|
unreduced_regularized_training_loss = unreduced_loss
|
||||||
|
|
||||||
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
|
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
|
||||||
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
scalar_loss = tf.reduce_mean(regularized_training_loss)
|
||||||
|
@ -382,8 +402,9 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
|
||||||
if train_op_fn is not None:
|
if train_op_fn is not None:
|
||||||
raise ValueError('train_op_fn and optimizer cannot both be set.')
|
raise ValueError('train_op_fn and optimizer cannot both be set.')
|
||||||
train_op = optimizer.minimize(
|
train_op = optimizer.minimize(
|
||||||
regularized_training_loss,
|
unreduced_regularized_training_loss,
|
||||||
global_step=tf.compat.v1.train.get_global_step())
|
global_step=tf.compat.v1.train.get_global_step(),
|
||||||
|
)
|
||||||
elif train_op_fn is not None:
|
elif train_op_fn is not None:
|
||||||
train_op = train_op_fn(regularized_training_loss)
|
train_op = train_op_fn(regularized_training_loss)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.estimators import test_utils
|
||||||
from tensorflow_privacy.privacy.estimators.v1 import linear
|
from tensorflow_privacy.privacy.estimators.v1 import linear
|
||||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||||
|
|
||||||
|
# pylint: disable=g-deprecated-tf-checker
|
||||||
|
|
||||||
|
|
||||||
class DPLinearClassifierClassifierTest(
|
class DPLinearClassifierClassifierTest(
|
||||||
tf.test.TestCase, parameterized.TestCase
|
tf.test.TestCase, parameterized.TestCase
|
||||||
|
@ -28,11 +30,13 @@ class DPLinearClassifierClassifierTest(
|
||||||
"""Tests for DP-enabled LinearClassifier."""
|
"""Tests for DP-enabled LinearClassifier."""
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('BinaryClassLinear', 2),
|
('BinaryClassLinear 1', 2, 1),
|
||||||
('MultiClassLinear 3', 3),
|
('BinaryClassLinear 4', 2, 4),
|
||||||
('MultiClassLinear 4', 4),
|
('MultiClassLinear 3', 3, 1),
|
||||||
|
('MultiClassLinear 4', 4, 1),
|
||||||
|
('MultiClassLinear 4 1', 4, 2),
|
||||||
)
|
)
|
||||||
def testLinearClassifier(self, n_classes):
|
def testRunsWithoutErrors(self, n_classes, num_microbatches):
|
||||||
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
train_features, train_labels = test_utils.make_input_data(256, n_classes)
|
||||||
feature_columns = []
|
feature_columns = []
|
||||||
for key in train_features:
|
for key in train_features:
|
||||||
|
@ -43,7 +47,7 @@ class DPLinearClassifierClassifierTest(
|
||||||
learning_rate=0.5,
|
learning_rate=0.5,
|
||||||
l2_norm_clip=1.0,
|
l2_norm_clip=1.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=1,
|
num_microbatches=num_microbatches,
|
||||||
)
|
)
|
||||||
|
|
||||||
classifier = linear.LinearClassifier(
|
classifier = linear.LinearClassifier(
|
||||||
|
|
Loading…
Reference in a new issue