Explicitly import estimator from tensorflow as a separate import instead of

accessing it via tf.estimator and depend on the tensorflow estimator target.

PiperOrigin-RevId: 437818180
This commit is contained in:
Fabien Hertschuh 2022-03-28 12:00:20 -07:00 committed by A. Unique TensorFlower
parent 70ab071e23
commit fc2c15ab21
8 changed files with 24 additions and 8 deletions

View file

@ -27,6 +27,7 @@ py_library(
"binary_class_head.py", "binary_class_head.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
) )
py_library( py_library(
@ -35,6 +36,7 @@ py_library(
"multi_class_head.py", "multi_class_head.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
) )
py_library( py_library(
@ -43,6 +45,7 @@ py_library(
"multi_label_head.py", "multi_label_head.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
) )
py_library( py_library(
@ -51,7 +54,10 @@ py_library(
"dnn.py", "dnn.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = [":head_utils"], deps = [
":head_utils",
"//third_party/py/tensorflow:tensorflow_estimator",
],
) )
py_library( py_library(
@ -72,6 +78,7 @@ py_test(
":binary_class_head", ":binary_class_head",
":test_utils", ":test_utils",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"//third_party/py/tensorflow:tensorflow_estimator",
], ],
) )
@ -85,6 +92,7 @@ py_test(
":multi_class_head", ":multi_class_head",
":test_utils", ":test_utils",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"//third_party/py/tensorflow:tensorflow_estimator",
], ],
) )
@ -98,6 +106,7 @@ py_test(
":multi_label_head", ":multi_label_head",
":test_utils", ":test_utils",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"//third_party/py/tensorflow:tensorflow_estimator",
], ],
) )

View file

@ -14,6 +14,7 @@
"""Binary class head for Estimator that allow integration with TF Privacy.""" """Binary class head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.canned import prediction_keys
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
class DPBinaryClassHead(tf.estimator.BinaryClassHead): class DPBinaryClassHead(tf_estimator.BinaryClassHead):
"""Creates a TF Privacy-enabled version of BinaryClassHead.""" """Creates a TF Privacy-enabled version of BinaryClassHead."""
def __init__(self, def __init__(self,

View file

@ -14,6 +14,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.estimators import binary_class_head from tensorflow_privacy.privacy.estimators import binary_class_head
from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
@ -58,7 +59,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=2) num_microbatches=2)
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
classifier = tf.estimator.Estimator(model_fn=model_fn) classifier = tf_estimator.Estimator(model_fn=model_fn)
classifier.train( classifier.train(
input_fn=test_utils.make_input_fn(train_features, train_labels, True), input_fn=test_utils.make_input_fn(train_features, train_labels, True),

View file

@ -15,12 +15,13 @@
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.estimators import head_utils from tensorflow_privacy.privacy.estimators import head_utils
from tensorflow_estimator.python.estimator import estimator from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator.canned import dnn from tensorflow_estimator.python.estimator.canned import dnn
class DNNClassifier(tf.estimator.Estimator): class DNNClassifier(tf_estimator.Estimator):
"""DP version of `tf.estimator.DNNClassifier`.""" """DP version of `tf.estimator.DNNClassifier`."""
def __init__( def __init__(

View file

@ -14,6 +14,7 @@
"""Multiclass head for Estimator that allow integration with TF Privacy.""" """Multiclass head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.canned import prediction_keys
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
class DPMultiClassHead(tf.estimator.MultiClassHead): class DPMultiClassHead(tf_estimator.MultiClassHead):
"""Creates a TF Privacy-enabled version of MultiClassHead.""" """Creates a TF Privacy-enabled version of MultiClassHead."""
def __init__(self, def __init__(self,

View file

@ -14,6 +14,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.estimators import multi_class_head from tensorflow_privacy.privacy.estimators import multi_class_head
from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
@ -62,7 +63,7 @@ class DPMultiClassHeadTest(tf.test.TestCase):
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=2) num_microbatches=2)
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
classifier = tf.estimator.Estimator(model_fn=model_fn) classifier = tf_estimator.Estimator(model_fn=model_fn)
classifier.train( classifier.train(
input_fn=test_utils.make_input_fn(train_features, train_labels, True), input_fn=test_utils.make_input_fn(train_features, train_labels, True),

View file

@ -14,6 +14,7 @@
"""Multiclass head for Estimator that allow integration with TF Privacy.""" """Multiclass head for Estimator that allow integration with TF Privacy."""
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.canned import prediction_keys
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
class DPMultiLabelHead(tf.estimator.MultiLabelHead): class DPMultiLabelHead(tf_estimator.MultiLabelHead):
"""Creates a TF Privacy-enabled version of MultiLabelHead.""" """Creates a TF Privacy-enabled version of MultiLabelHead."""
def __init__(self, def __init__(self,

View file

@ -14,6 +14,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import estimator as tf_estimator
from tensorflow_privacy.privacy.estimators import multi_label_head from tensorflow_privacy.privacy.estimators import multi_label_head
from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.estimators import test_utils
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
@ -63,7 +64,7 @@ class DPMultiLabelHeadTest(tf.test.TestCase):
noise_multiplier=0.0, noise_multiplier=0.0,
num_microbatches=2) num_microbatches=2)
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
classifier = tf.estimator.Estimator(model_fn=model_fn) classifier = tf_estimator.Estimator(model_fn=model_fn)
classifier.train( classifier.train(
input_fn=test_utils.make_input_fn(train_features, train_labels, True), input_fn=test_utils.make_input_fn(train_features, train_labels, True),