Correct imports of keras loss utils
PiperOrigin-RevId: 486795765
This commit is contained in:
parent
e334633466
commit
ec747a8d75
3 changed files with 5 additions and 8 deletions
|
@ -14,7 +14,6 @@
|
||||||
"""Binary class head for Estimator that allow integration with TF Privacy."""
|
"""Binary class head for Estimator that allow integration with TF Privacy."""
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
|
||||||
from tensorflow_estimator.python.estimator import model_fn
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
from tensorflow_estimator.python.estimator.export import export_output
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
|
@ -55,7 +54,7 @@ class DPBinaryClassHead(binary_class_head.BinaryClassHead):
|
||||||
labels = self._processed_labels(logits, labels)
|
labels = self._processed_labels(logits, labels)
|
||||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||||
logits, labels, features)
|
logits, labels, features)
|
||||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||||
unweighted_loss,
|
unweighted_loss,
|
||||||
sample_weight=weights,
|
sample_weight=weights,
|
||||||
reduction=tf.keras.losses.Reduction.NONE)
|
reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
|
||||||
from tensorflow_estimator.python.estimator import model_fn
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
from tensorflow_estimator.python.estimator.export import export_output
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
|
@ -30,14 +29,14 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
|
||||||
n_classes,
|
n_classes,
|
||||||
weight_column=None,
|
weight_column=None,
|
||||||
label_vocabulary=None,
|
label_vocabulary=None,
|
||||||
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||||
loss_fn=None,
|
loss_fn=None,
|
||||||
name=None):
|
name=None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
n_classes=n_classes,
|
n_classes=n_classes,
|
||||||
weight_column=weight_column,
|
weight_column=weight_column,
|
||||||
label_vocabulary=label_vocabulary,
|
label_vocabulary=label_vocabulary,
|
||||||
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
@ -55,7 +54,7 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
|
||||||
labels = self._processed_labels(logits, labels)
|
labels = self._processed_labels(logits, labels)
|
||||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||||
logits, labels, features)
|
logits, labels, features)
|
||||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||||
unweighted_loss,
|
unweighted_loss,
|
||||||
sample_weight=weights,
|
sample_weight=weights,
|
||||||
reduction=tf.keras.losses.Reduction.NONE)
|
reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
|
||||||
from tensorflow_estimator.python.estimator import model_fn
|
from tensorflow_estimator.python.estimator import model_fn
|
||||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||||
from tensorflow_estimator.python.estimator.export import export_output
|
from tensorflow_estimator.python.estimator.export import export_output
|
||||||
|
@ -61,7 +60,7 @@ class DPMultiLabelHead(multi_label_head.MultiLabelHead):
|
||||||
labels = self._processed_labels(logits, labels)
|
labels = self._processed_labels(logits, labels)
|
||||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||||
logits, labels, features)
|
logits, labels, features)
|
||||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||||
unweighted_loss,
|
unweighted_loss,
|
||||||
sample_weight=weights,
|
sample_weight=weights,
|
||||||
reduction=tf.keras.losses.Reduction.NONE)
|
reduction=tf.keras.losses.Reduction.NONE)
|
||||||
|
|
Loading…
Reference in a new issue