diff --git a/tensorflow_privacy/privacy/estimators/BUILD b/tensorflow_privacy/privacy/estimators/BUILD index 5b5c7bf..436c13b 100644 --- a/tensorflow_privacy/privacy/estimators/BUILD +++ b/tensorflow_privacy/privacy/estimators/BUILD @@ -27,6 +27,7 @@ py_library( "binary_class_head.py", ], srcs_version = "PY3", + deps = ["//third_party/py/tensorflow:tensorflow_estimator"], ) py_library( @@ -35,6 +36,7 @@ py_library( "multi_class_head.py", ], srcs_version = "PY3", + deps = ["//third_party/py/tensorflow:tensorflow_estimator"], ) py_library( @@ -43,6 +45,7 @@ py_library( "multi_label_head.py", ], srcs_version = "PY3", + deps = ["//third_party/py/tensorflow:tensorflow_estimator"], ) py_library( @@ -51,7 +54,10 @@ py_library( "dnn.py", ], srcs_version = "PY3", - deps = [":head_utils"], + deps = [ + ":head_utils", + "//third_party/py/tensorflow:tensorflow_estimator", + ], ) py_library( @@ -72,6 +78,7 @@ py_test( ":binary_class_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + "//third_party/py/tensorflow:tensorflow_estimator", ], ) @@ -85,6 +92,7 @@ py_test( ":multi_class_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + "//third_party/py/tensorflow:tensorflow_estimator", ], ) @@ -98,6 +106,7 @@ py_test( ":multi_label_head", ":test_utils", "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + "//third_party/py/tensorflow:tensorflow_estimator", ], ) diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head.py b/tensorflow_privacy/privacy/estimators/binary_class_head.py index e295502..40a4b6a 100644 --- a/tensorflow_privacy/privacy/estimators/binary_class_head.py +++ b/tensorflow_privacy/privacy/estimators/binary_class_head.py @@ -14,6 +14,7 @@ """Binary class head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys @@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPBinaryClassHead(tf.estimator.BinaryClassHead): +class DPBinaryClassHead(tf_estimator.BinaryClassHead): """Creates a TF Privacy-enabled version of BinaryClassHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py index 1cf6998..1030608 100644 --- a/tensorflow_privacy/privacy/estimators/binary_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/binary_class_head_test.py @@ -14,6 +14,7 @@ import numpy as np import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import binary_class_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer @@ -58,7 +59,7 @@ class DPBinaryClassHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf.estimator.Estimator(model_fn=model_fn) + classifier = tf_estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True), diff --git a/tensorflow_privacy/privacy/estimators/dnn.py b/tensorflow_privacy/privacy/estimators/dnn.py index 126a7b5..cda5e09 100644 --- a/tensorflow_privacy/privacy/estimators/dnn.py +++ b/tensorflow_privacy/privacy/estimators/dnn.py @@ -15,12 +15,13 @@ import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import head_utils from tensorflow_estimator.python.estimator import estimator from tensorflow_estimator.python.estimator.canned import dnn -class DNNClassifier(tf.estimator.Estimator): +class DNNClassifier(tf_estimator.Estimator): """DP version of `tf.estimator.DNNClassifier`.""" def __init__( diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head.py b/tensorflow_privacy/privacy/estimators/multi_class_head.py index 2c4782f..fd25f84 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head.py @@ -14,6 +14,7 @@ """Multiclass head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys @@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPMultiClassHead(tf.estimator.MultiClassHead): +class DPMultiClassHead(tf_estimator.MultiClassHead): """Creates a TF Privacy-enabled version of MultiClassHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py index b002561..3a70255 100644 --- a/tensorflow_privacy/privacy/estimators/multi_class_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_class_head_test.py @@ -14,6 +14,7 @@ import numpy as np import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import multi_class_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer @@ -62,7 +63,7 @@ class DPMultiClassHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf.estimator.Estimator(model_fn=model_fn) + classifier = tf_estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True), diff --git a/tensorflow_privacy/privacy/estimators/multi_label_head.py b/tensorflow_privacy/privacy/estimators/multi_label_head.py index 1dc0ccd..669d300 100644 --- a/tensorflow_privacy/privacy/estimators/multi_label_head.py +++ b/tensorflow_privacy/privacy/estimators/multi_label_head.py @@ -14,6 +14,7 @@ """Multiclass head for Estimator that allow integration with TF Privacy.""" import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import from tensorflow_estimator.python.estimator import model_fn from tensorflow_estimator.python.estimator.canned import prediction_keys @@ -22,7 +23,7 @@ from tensorflow_estimator.python.estimator.head import base_head from tensorflow_estimator.python.estimator.mode_keys import ModeKeys -class DPMultiLabelHead(tf.estimator.MultiLabelHead): +class DPMultiLabelHead(tf_estimator.MultiLabelHead): """Creates a TF Privacy-enabled version of MultiLabelHead.""" def __init__(self, diff --git a/tensorflow_privacy/privacy/estimators/multi_label_head_test.py b/tensorflow_privacy/privacy/estimators/multi_label_head_test.py index 659250b..9de263c 100644 --- a/tensorflow_privacy/privacy/estimators/multi_label_head_test.py +++ b/tensorflow_privacy/privacy/estimators/multi_label_head_test.py @@ -14,6 +14,7 @@ import numpy as np import tensorflow as tf +from tensorflow import estimator as tf_estimator from tensorflow_privacy.privacy.estimators import multi_label_head from tensorflow_privacy.privacy.estimators import test_utils from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer @@ -63,7 +64,7 @@ class DPMultiLabelHeadTest(tf.test.TestCase): noise_multiplier=0.0, num_microbatches=2) model_fn = test_utils.make_model_fn(head, optimizer, feature_columns) - classifier = tf.estimator.Estimator(model_fn=model_fn) + classifier = tf_estimator.Estimator(model_fn=model_fn) classifier.train( input_fn=test_utils.make_input_fn(train_features, train_labels, True),