Explicitly import estimator from tensorflow as a separate import instead of
accessing it via tf.estimator and depend on the tensorflow estimator target. PiperOrigin-RevId: 437818180
This commit is contained in:
parent
70ab071e23
commit
fc2c15ab21
8 changed files with 24 additions and 8 deletions
|
@ -27,6 +27,7 @@ py_library(
|
|||
"binary_class_head.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -35,6 +36,7 @@ py_library(
|
|||
"multi_class_head.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -43,6 +45,7 @@ py_library(
|
|||
"multi_label_head.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//third_party/py/tensorflow:tensorflow_estimator"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -51,7 +54,10 @@ py_library(
|
|||
"dnn.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [":head_utils"],
|
||||
deps = [
|
||||
":head_utils",
|
||||
"//third_party/py/tensorflow:tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -72,6 +78,7 @@ py_test(
|
|||
":binary_class_head",
|
||||
":test_utils",
|
||||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
"//third_party/py/tensorflow:tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -85,6 +92,7 @@ py_test(
|
|||
":multi_class_head",
|
||||
":test_utils",
|
||||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
"//third_party/py/tensorflow:tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -98,6 +106,7 @@ py_test(
|
|||
":multi_label_head",
|
||||
":test_utils",
|
||||
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
|
||||
"//third_party/py/tensorflow:tensorflow_estimator",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Binary class head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
|
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
|
|||
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||
|
||||
|
||||
class DPBinaryClassHead(tf.estimator.BinaryClassHead):
|
||||
class DPBinaryClassHead(tf_estimator.BinaryClassHead):
|
||||
"""Creates a TF Privacy-enabled version of BinaryClassHead."""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow_privacy.privacy.estimators import binary_class_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
@ -58,7 +59,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase):
|
|||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
classifier = tf_estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow_privacy.privacy.estimators import head_utils
|
||||
from tensorflow_estimator.python.estimator import estimator
|
||||
from tensorflow_estimator.python.estimator.canned import dnn
|
||||
|
||||
|
||||
class DNNClassifier(tf.estimator.Estimator):
|
||||
class DNNClassifier(tf_estimator.Estimator):
|
||||
"""DP version of `tf.estimator.DNNClassifier`."""
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
|
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
|
|||
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||
|
||||
|
||||
class DPMultiClassHead(tf.estimator.MultiClassHead):
|
||||
class DPMultiClassHead(tf_estimator.MultiClassHead):
|
||||
"""Creates a TF Privacy-enabled version of MultiClassHead."""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow_privacy.privacy.estimators import multi_class_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
@ -62,7 +63,7 @@ class DPMultiClassHeadTest(tf.test.TestCase):
|
|||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
classifier = tf_estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Multiclass head for Estimator that allow integration with TF Privacy."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
|
@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head
|
|||
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys
|
||||
|
||||
|
||||
class DPMultiLabelHead(tf.estimator.MultiLabelHead):
|
||||
class DPMultiLabelHead(tf_estimator.MultiLabelHead):
|
||||
"""Creates a TF Privacy-enabled version of MultiLabelHead."""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import estimator as tf_estimator
|
||||
from tensorflow_privacy.privacy.estimators import multi_label_head
|
||||
from tensorflow_privacy.privacy.estimators import test_utils
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
|
||||
|
@ -63,7 +64,7 @@ class DPMultiLabelHeadTest(tf.test.TestCase):
|
|||
noise_multiplier=0.0,
|
||||
num_microbatches=2)
|
||||
model_fn = test_utils.make_model_fn(head, optimizer, feature_columns)
|
||||
classifier = tf.estimator.Estimator(model_fn=model_fn)
|
||||
classifier = tf_estimator.Estimator(model_fn=model_fn)
|
||||
|
||||
classifier.train(
|
||||
input_fn=test_utils.make_input_fn(train_features, train_labels, True),
|
||||
|
|
Loading…
Reference in a new issue