diff --git a/tensorflow_privacy/privacy/estimators/BUILD b/tensorflow_privacy/privacy/estimators/BUILD index 436c13b..e316fa2 100644 --- a/tensorflow_privacy/privacy/estimators/BUILD +++ b/tensorflow_privacy/privacy/estimators/BUILD @@ -27,7 +27,7 @@ py_library( "binary_class_head.py", ], srcs_version = "PY3", - deps = ["//third_party/py/tensorflow:tensorflow_estimator"], + deps = ["//third_party/tensorflow_estimator/python/estimator:binary_class_head"], ) py_library( @@ -36,7 +36,7 @@ py_library( "multi_class_head.py", ], srcs_version = "PY3", - deps = ["//third_party/py/tensorflow:tensorflow_estimator"], + deps = ["//third_party/tensorflow_estimator/python/estimator:multi_class_head"], ) py_library( @@ -45,7 +45,7 @@ py_library( "multi_label_head.py", ], srcs_version = "PY3", - deps = ["//third_party/py/tensorflow:tensorflow_estimator"], + deps = ["//third_party/tensorflow_estimator/python/estimator:multi_label_head"], ) py_library( @@ -54,10 +54,7 @@ py_library( "dnn.py", ], srcs_version = "PY3", - deps = [ - ":head_utils", - "//third_party/py/tensorflow:tensorflow_estimator", - ], + deps = [":head_utils"], ) py_library( @@ -78,7 +75,6 @@ py_test( ":binary_class_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", - "//third_party/py/tensorflow:tensorflow_estimator", ], ) @@ -92,7 +88,6 @@ py_test( ":multi_class_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", - "//third_party/py/tensorflow:tensorflow_estimator", ], ) @@ -106,7 +101,6 @@ py_test( ":multi_label_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", - "//third_party/py/tensorflow:tensorflow_estimator", ], ) diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head.py b/tensorflow_privacy/privacy/estimators/binary_class_head.py index 40a4b6a..341d716 100644 --- a/tensorflow_privacy/privacy/estimators/binary_class_head.py +++ b/tensorflow_privacy/privacy/estimators/binary_class_head.py @@ -14,16 +14,16 @@ """Binary class head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.export import export_output from tensorflow_estimator.python.estimator.head import base_head +from tensorflow_estimator.python.estimator.head import binary_class_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPBinaryClassHead(tf_estimator.BinaryClassHead): +class DPBinaryClassHead(binary_class_head.BinaryClassHead): """Creates a TF Privacy-enabled version of BinaryClassHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py index 1030608..09e616f 100644 --- a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py @@ -14,10 +14,10 @@ import numpy as np import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import binary_class_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer +from tensorflow_estimator.python.estimator import estimator class DPBinaryClassHeadTest(tf.test.TestCase): @@ -59,7 +59,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf_estimator.Estimator(model_fn=model_fn) + classifier = estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True), diff --git a/tensorflow_privacy/privacy/estimators/dnn.py b/tensorflow_privacy/privacy/estimators/dnn.py index cda5e09..36de154 100644 --- a/tensorflow_privacy/privacy/estimators/dnn.py +++ b/tensorflow_privacy/privacy/estimators/dnn.py @@ -15,13 +15,12 @@ import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import head_utils from tensorflow_estimator.python.estimator import estimator from tensorflow_estimator.python.estimator.canned import dnn -class DNNClassifier(tf_estimator.Estimator): +class DNNClassifier(estimator.Estimator): """DP version of `tf.estimator.DNNClassifier`.""" def __init__( diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head.py b/tensorflow_privacy/privacy/estimators/multi_class_head.py index fd25f84..4629217 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head.py @@ -14,16 +14,16 @@ """Multiclass head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.export import export_output from tensorflow_estimator.python.estimator.head import base_head +from tensorflow_estimator.python.estimator.head import multi_class_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPMultiClassHead(tf_estimator.MultiClassHead): +class DPMultiClassHead(multi_class_head.MultiClassHead): """Creates a TF Privacy-enabled version of MultiClassHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py index 3a70255..9d2e7cf 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py @@ -14,10 +14,10 @@ import numpy as np import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import multi_class_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer +from tensorflow_estimator.python.estimator import estimator class DPMultiClassHeadTest(tf.test.TestCase): @@ -63,7 +63,7 @@ class DPMultiClassHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf_estimator.Estimator(model_fn=model_fn) + classifier = estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True), diff --git a/tensorflow_privacy/privacy/estimators/multi_label_head.py b/tensorflow_privacy/privacy/estimators/multi_label_head.py index 669d300..622a854 100644 --- a/tensorflow_privacy/privacy/estimators/multi_label_head.py +++ b/tensorflow_privacy/privacy/estimators/multi_label_head.py @@ -14,16 +14,16 @@ """Multiclass head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys from tensorflow_estimator.python.estimator.export import export_output from tensorflow_estimator.python.estimator.head import base_head +from tensorflow_estimator.python.estimator.head import multi_label_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPMultiLabelHead(tf_estimator.MultiLabelHead): +class DPMultiLabelHead(multi_label_head.MultiLabelHead): """Creates a TF Privacy-enabled version of MultiLabelHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/multi_label_head_test.py b/tensorflow_privacy/privacy/estimators/multi_label_head_test.py index 9de263c..2a35c64 100644 --- a/tensorflow_privacy/privacy/estimators/multi_label_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_label_head_test.py @@ -14,10 +14,10 @@ import numpy as np import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import multi_label_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer +from tensorflow_estimator.python.estimator import estimator class DPMultiLabelHeadTest(tf.test.TestCase): @@ -64,7 +64,7 @@ class DPMultiLabelHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf_estimator.Estimator(model_fn=model_fn) + classifier = estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True), diff --git a/tensorflow_privacy/privacy/estimators/v1/BUILD b/tensorflow_privacy/privacy/estimators/v1/BUILD index 3195cd8..0522c75 100644 --- a/tensorflow_privacy/privacy/estimators/v1/BUILD +++ b/tensorflow_privacy/privacy/estimators/v1/BUILD @@ -23,10 +23,7 @@ py_library( "dnn.py", ], srcs_version = "PY3", - deps = [ - ":head", - "//third_party/py/tensorflow:tensorflow_estimator", - ], + deps = [":head"], ) py_test( @@ -39,7 +36,6 @@ py_test( ":head", "//tensorflow_privacy/privacy/estimators:test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer", - "//third_party/py/tensorflow:tensorflow_estimator", ], ) diff --git a/tensorflow_privacy/privacy/estimators/v1/dnn.py b/tensorflow_privacy/privacy/estimators/v1/dnn.py index 143d793..c819593 100644 --- a/tensorflow_privacy/privacy/estimators/v1/dnn.py +++ b/tensorflow_privacy/privacy/estimators/v1/dnn.py @@ -16,13 +16,12 @@ import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators.v1 import head as head_lib from tensorflow_estimator.python.estimator import estimator from tensorflow_estimator.python.estimator.canned import dnn -class DNNClassifier(tf_estimator.Estimator): +class DNNClassifier(estimator.Estimator): """DP version of `tf.compat.v1.estimator.DNNClassifier`.""" def __init__( diff --git a/tensorflow_privacy/privacy/estimators/v1/head_test.py b/tensorflow_privacy/privacy/estimators/v1/head_test.py index 3536e09..deaf208 100644 --- a/tensorflow_privacy/privacy/estimators/v1/head_test.py +++ b/tensorflow_privacy/privacy/estimators/v1/head_test.py @@ -15,10 +15,10 @@ from absl.testing import parameterized import tensorflow as tf -from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.estimators.v1 import head as head_lib from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer +from tensorflow_estimator.python.estimator import estimator def make_model_fn(head, optimizer, feature_columns): @@ -70,7 +70,7 @@ class DPHeadTest(tf.test.TestCase, parameterized.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = make_model_fn(head, optimizer, feature_columns) - classifier = tf_estimator.Estimator(model_fn=model_fn) + classifier = estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True),