forked from 626_privacy/tensorflow_privacy
Correct imports of keras loss utils
PiperOrigin-RevId: 486795765
This commit is contained in:
parent
e334633466
commit
ec747a8d75
3 changed files with 5 additions and 8 deletions
|
@ -14,7 +14,6 @@
|
|||
"""Binary class head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
from tensorflow_estimator.python.estimator.export import export_output
|
||||
|
@ -55,7 +54,7 @@ class DPBinaryClassHead(binary_class_head.BinaryClassHead):
|
|||
labels = self._processed_labels(logits, labels)
|
||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||
logits, labels, features)
|
||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||
unweighted_loss,
|
||||
sample_weight=weights,
|
||||
reduction=tf.keras.losses.Reduction.NONE)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
from tensorflow_estimator.python.estimator.export import export_output
|
||||
|
@ -30,14 +29,14 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
|
|||
n_classes,
|
||||
weight_column=None,
|
||||
label_vocabulary=None,
|
||||
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
loss_fn=None,
|
||||
name=None):
|
||||
super().__init__(
|
||||
n_classes=n_classes,
|
||||
weight_column=weight_column,
|
||||
label_vocabulary=label_vocabulary,
|
||||
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
|
||||
loss_fn=loss_fn,
|
||||
name=name)
|
||||
|
||||
|
@ -55,7 +54,7 @@ class DPMultiClassHead(multi_class_head.MultiClassHead):
|
|||
labels = self._processed_labels(logits, labels)
|
||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||
logits, labels, features)
|
||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||
unweighted_loss,
|
||||
sample_weight=weights,
|
||||
reduction=tf.keras.losses.Reduction.NONE)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
from tensorflow_estimator.python.estimator.export import export_output
|
||||
|
@ -61,7 +60,7 @@ class DPMultiLabelHead(multi_label_head.MultiLabelHead):
|
|||
labels = self._processed_labels(logits, labels)
|
||||
unweighted_loss, weights = self._unweighted_loss_and_weights(
|
||||
logits, labels, features)
|
||||
vector_training_loss = losses_utils.compute_weighted_loss(
|
||||
vector_training_loss = tf.keras.__internal__.losses.compute_weighted_loss(
|
||||
unweighted_loss,
|
||||
sample_weight=weights,
|
||||
reduction=tf.keras.losses.Reduction.NONE)
|
||||
|
|
Loading…
Reference in a new issue