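Fix bug in v1 estimators that prevented the use of microbatches: the classification heads now pass the unreduced (per-example) regularized loss to `optimizer.minimize` instead of the reduced loss.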

PiperOrigin-RevId: 560765153
This commit is contained in:
Steve Chien 2023-08-28 11:13:55 -07:00 committed by A. Unique TensorFlower
parent b4b47b1403
commit 372c934d14
3 changed files with 51 additions and 21 deletions

View file

@ -20,16 +20,20 @@ from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.estimators.v1 import dnn from tensorflow_privacy.privacy.estimators.v1 import dnn
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
# pylint: disable=g-deprecated-tf-checker
class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase): class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for DP-enabled DNNClassifier.""" """Tests for DP-enabled DNNClassifier."""
@parameterized.named_parameters( @parameterized.named_parameters(
('BinaryClassDNN', 2), ('BinaryClassDNN', 2, 1),
('MultiClassDNN 3', 3), ('BinaryClassDNN 4', 2, 4),
('MultiClassDNN 4', 4), ('MultiClassDNN 3', 3, 1),
('MultiClassDNN 4', 4, 1),
('MultiClassDNN 4 4', 4, 4),
) )
def testDNN(self, n_classes): def testDNN(self, n_classes, num_microbatches):
train_features, train_labels = test_utils.make_input_data(256, n_classes) train_features, train_labels = test_utils.make_input_data(256, n_classes)
feature_columns = [] feature_columns = []
for key in train_features: for key in train_features:
@ -40,7 +44,8 @@ class DPDNNClassifierTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=0.5, learning_rate=0.5,
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=1) num_microbatches=num_microbatches,
)
classifier = dnn.DNNClassifier( classifier = dnn.DNNClassifier(
hidden_units=[10], hidden_units=[10],

View file

@ -16,6 +16,7 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops import lookup_ops # pylint: disable=g-direct-tensorflow-import from tensorflow.python.ops import lookup_ops # pylint: disable=g-direct-tensorflow-import
# pylint: disable=g-deprecated-tf-checker
from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import head as head_lib from tensorflow_estimator.python.estimator.canned import head as head_lib
from tensorflow_estimator.python.estimator.canned import metric_keys from tensorflow_estimator.python.estimator.canned import metric_keys
@ -23,6 +24,7 @@ from tensorflow_estimator.python.estimator.canned import prediction_keys
from tensorflow_estimator.python.estimator.export import export_output from tensorflow_estimator.python.estimator.export import export_output
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
# Collect together all protected access items needed from base head. # Collect together all protected access items needed from base head.
# pylint: disable=protected-access # pylint: disable=protected-access
_DEFAULT_SERVING_KEY = head_lib._DEFAULT_SERVING_KEY _DEFAULT_SERVING_KEY = head_lib._DEFAULT_SERVING_KEY
@ -39,8 +41,12 @@ _create_eval_metrics_tuple = head_lib._create_eval_metrics_tuple
_summary_key = head_lib._summary_key _summary_key = head_lib._summary_key
_validate_loss_fn_args = head_lib._validate_loss_fn_args _validate_loss_fn_args = head_lib._validate_loss_fn_args
_BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss = head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss _BaseBinaryLogisticHeadWithSigmoidCrossEntropyLoss = (
_BaseMultiClassHeadWithSoftmaxCrossEntropyLoss = head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss
)
_BaseMultiClassHeadWithSoftmaxCrossEntropyLoss = (
head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss
)
# pylint: enable=protected-access # pylint: enable=protected-access
@ -146,25 +152,33 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(
classifier_output = _classification_output( classifier_output = _classification_output(
scores=probabilities, scores=probabilities,
n_classes=self._n_classes, n_classes=self._n_classes,
label_vocabulary=self._label_vocabulary) label_vocabulary=self._label_vocabulary,
)
return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=ModeKeys.PREDICT, mode=ModeKeys.PREDICT,
predictions=predictions, predictions=predictions,
export_outputs={ export_outputs={
_DEFAULT_SERVING_KEY: classifier_output, _DEFAULT_SERVING_KEY: classifier_output,
_CLASSIFY_SERVING_KEY: classifier_output, _CLASSIFY_SERVING_KEY: classifier_output,
_PREDICT_SERVING_KEY: export_output.PredictOutput(predictions) _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions),
}) },
)
training_loss, unreduced_loss, weights, label_ids = self.create_loss( training_loss, unreduced_loss, weights, label_ids = self.create_loss(
features=features, mode=mode, logits=logits, labels=labels) features=features, mode=mode, logits=logits, labels=labels
)
if regularization_losses: if regularization_losses:
regularization_loss = tf.math.add_n(regularization_losses) regularization_loss = tf.math.add_n(regularization_losses)
regularized_training_loss = tf.math.add_n( regularized_training_loss = tf.math.add_n(
[training_loss, regularization_loss]) [training_loss, regularization_loss]
)
unreduced_regularized_training_loss = tf.math.add(
unreduced_loss, regularization_loss
)
else: else:
regularization_loss = None regularization_loss = None
regularized_training_loss = training_loss regularized_training_loss = training_loss
unreduced_regularized_training_loss = unreduced_loss
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE: if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
scalar_loss = tf.reduce_mean(regularized_training_loss) scalar_loss = tf.reduce_mean(regularized_training_loss)
@ -191,8 +205,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(
if train_op_fn is not None: if train_op_fn is not None:
raise ValueError('train_op_fn and optimizer cannot both be set.') raise ValueError('train_op_fn and optimizer cannot both be set.')
train_op = optimizer.minimize( train_op = optimizer.minimize(
regularized_training_loss, # regularized_training_loss,
global_step=tf.compat.v1.train.get_global_step()) unreduced_regularized_training_loss,
global_step=tf.compat.v1.train.get_global_step(),
)
elif train_op_fn is not None: elif train_op_fn is not None:
train_op = train_op_fn(regularized_training_loss) train_op = train_op_fn(regularized_training_loss)
else: else:
@ -352,9 +368,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
regularization_loss = tf.math.add_n(regularization_losses) regularization_loss = tf.math.add_n(regularization_losses)
regularized_training_loss = tf.math.add_n( regularized_training_loss = tf.math.add_n(
[training_loss, regularization_loss]) [training_loss, regularization_loss])
unreduced_regularized_training_loss = tf.math.add(
unreduced_loss, regularization_loss
)
else: else:
regularization_loss = None regularization_loss = None
regularized_training_loss = training_loss regularized_training_loss = training_loss
unreduced_regularized_training_loss = unreduced_loss
if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE: if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
scalar_loss = tf.reduce_mean(regularized_training_loss) scalar_loss = tf.reduce_mean(regularized_training_loss)
@ -382,8 +402,9 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
if train_op_fn is not None: if train_op_fn is not None:
raise ValueError('train_op_fn and optimizer cannot both be set.') raise ValueError('train_op_fn and optimizer cannot both be set.')
train_op = optimizer.minimize( train_op = optimizer.minimize(
regularized_training_loss, unreduced_regularized_training_loss,
global_step=tf.compat.v1.train.get_global_step()) global_step=tf.compat.v1.train.get_global_step(),
)
elif train_op_fn is not None: elif train_op_fn is not None:
train_op = train_op_fn(regularized_training_loss) train_op = train_op_fn(regularized_training_loss)
else: else:

View file

@ -21,6 +21,8 @@ from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.estimators.v1 import linear from tensorflow_privacy.privacy.estimators.v1 import linear
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
# pylint: disable=g-deprecated-tf-checker
class DPLinearClassifierClassifierTest( class DPLinearClassifierClassifierTest(
tf.test.TestCase, parameterized.TestCase tf.test.TestCase, parameterized.TestCase
@ -28,11 +30,13 @@ class DPLinearClassifierClassifierTest(
"""Tests for DP-enabled LinearClassifier.""" """Tests for DP-enabled LinearClassifier."""
@parameterized.named_parameters( @parameterized.named_parameters(
('BinaryClassLinear', 2), ('BinaryClassLinear 1', 2, 1),
('MultiClassLinear 3', 3), ('BinaryClassLinear 4', 2, 4),
('MultiClassLinear 4', 4), ('MultiClassLinear 3', 3, 1),
('MultiClassLinear 4', 4, 1),
('MultiClassLinear 4 1', 4, 2),
) )
def testLinearClassifier(self, n_classes): def testRunsWithoutErrors(self, n_classes, num_microbatches):
train_features, train_labels = test_utils.make_input_data(256, n_classes) train_features, train_labels = test_utils.make_input_data(256, n_classes)
feature_columns = [] feature_columns = []
for key in train_features: for key in train_features:
@ -43,7 +47,7 @@ class DPLinearClassifierClassifierTest(
learning_rate=0.5, learning_rate=0.5,
l2_norm_clip=1.0, l2_norm_clip=1.0,
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=1, num_microbatches=num_microbatches,
) )
classifier = linear.LinearClassifier( classifier = linear.LinearClassifier(