Correct imports of keras loss utils

PiperOrigin-RevId: 486795765
This commit is contained in:
A. Unique TensorFlower 2022-11-07 16:33:28 -08:00
parent e334633466
commit ec747a8d75
3 changed files with 5 additions and 8 deletions

View file

@ -14,7 +14,6 @@
"""Binary class head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys
from tensorflow_estimator.python.estimator.export import export_output
@ -55,7 +54,7 @@ class DPBinaryClassHead(binary_class_head.BinaryClassHead):
labels = self._processed_labels(logits, labels)
unweighted_loss, weights = self._unweighted_loss_and_weights(
logits, labels, features)
vector_training_loss = losses_utils.compute_weighted_loss(
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
unweighted_loss,
sample_weight=weights,
reduction=tf.keras.losses.Reduction.NONE)

View file

@ -14,7 +14,6 @@
"""Multiclass head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys
from tensorflow_estimator.python.estimator.export import export_output
@ -30,14 +29,14 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
n_classes,
weight_column=None,
label_vocabulary=None,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
super().__init__(
n_classes=n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=loss_fn,
name=name)
@ -55,7 +54,7 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
labels = self._processed_labels(logits, labels)
unweighted_loss, weights = self._unweighted_loss_and_weights(
logits, labels, features)
vector_training_loss = losses_utils.compute_weighted_loss(
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
unweighted_loss,
sample_weight=weights,
reduction=tf.keras.losses.Reduction.NONE)

View file

@ -14,7 +14,6 @@
"""Multiclass head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys
from tensorflow_estimator.python.estimator.export import export_output
@ -61,7 +60,7 @@ class DPMultiLabelHead(multi_label_head.MultiLabelHead):
labels = self._processed_labels(logits, labels)
unweighted_loss, weights = self._unweighted_loss_and_weights(
logits, labels, features)
vector_training_loss = losses_utils.compute_weighted_loss(
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
unweighted_loss,
sample_weight=weights,
reduction=tf.keras.losses.Reduction.NONE)