diff --git a/requirements.txt b/requirements.txt index cb596eb..2d99d08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ tensorflow>=1.13 mpmath scipy>=0.17 +dm-tree~=0.1.1 diff --git a/setup.py b/setup.py index 4b1cbed..4f52c47 100644 --- a/setup.py +++ b/setup.py @@ -15,18 +15,20 @@ from setuptools import find_packages from setuptools import setup -setup(name='tensorflow_privacy', - version='0.2.2', - url='https://github.com/tensorflow/privacy', - license='Apache-2.0', - install_requires=[ - 'scipy>=0.17', - 'mpmath', # used in tests only - ], - # Explicit dependence on TensorFlow is not supported. - # See https://github.com/tensorflow/tensorflow/issues/7166 - extras_require={ - 'tf': ['tensorflow>=1.0.0'], - 'tf_gpu': ['tensorflow-gpu>=1.0.0'], - }, - packages=find_packages()) +setup( + name='tensorflow_privacy', + version='0.2.2', + url='https://github.com/tensorflow/privacy', + license='Apache-2.0', + install_requires=[ + 'scipy>=0.17', + 'mpmath', # used in tests only + 'dm-tree~=0.1.1', # used in tests only + ], + # Explicit dependence on TensorFlow is not supported. + # See https://github.com/tensorflow/tensorflow/issues/7166 + extras_require={ + 'tf': ['tensorflow>=1.0.0'], + 'tf_gpu': ['tensorflow-gpu>=1.0.0'], + }, + packages=find_packages()) diff --git a/tensorflow_privacy/privacy/dp_query/BUILD b/tensorflow_privacy/privacy/dp_query/BUILD index 76ce76b..2931140 100644 --- a/tensorflow_privacy/privacy/dp_query/BUILD +++ b/tensorflow_privacy/privacy/dp_query/BUILD @@ -88,8 +88,8 @@ py_library( srcs = ["nested_query.py"], deps = [ ":dp_query", - "//third_party/py/distutils", "//third_party/py/tensorflow", + "//third_party/py/tree", ], ) diff --git a/tensorflow_privacy/privacy/dp_query/nested_query.py b/tensorflow_privacy/privacy/dp_query/nested_query.py index e76d4ea..5d0dbf8 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query.py @@ -20,9 +20,8 @@ from __future__ import division from __future__ import print_function import tensorflow as tf - from tensorflow_privacy.privacy.dp_query import dp_query -from tensorflow.contrib import framework as contrib_framework +import tree class NestedQuery(dp_query.DPQuery): @@ -55,8 +54,8 @@ class NestedQuery(dp_query.DPQuery): def caller(query, *args): return getattr(query, fn)(*args, **kwargs) - return contrib_framework.nest.map_structure_up_to( - self._queries, caller, self._queries, *inputs) + return tree.map_structure_up_to(self._queries, caller, self._queries, + *inputs) def set_ledger(self, ledger): self._map_to_queries('set_ledger', ledger=ledger) @@ -106,7 +105,6 @@ class NestedQuery(dp_query.DPQuery): 'get_noised_result', sample_state, global_state) flat_estimates, flat_new_global_states = zip( - *contrib_framework.nest.flatten_up_to( - self._queries, estimates_and_new_global_states)) + *tree.flatten_up_to(self._queries, estimates_and_new_global_states)) return (tf.nest.pack_sequence_as(self._queries, flat_estimates), tf.nest.pack_sequence_as(self._queries, flat_new_global_states))