Move TF privacy off contrib.

PiperOrigin-RevId: 289953826
This commit is contained in:
Keith Rush 2020-01-15 15:46:55 -08:00 committed by A. Unique TensorFlower
parent 1a448b4272
commit a8a2d91795
4 changed files with 23 additions and 22 deletions

View file

@ -1,3 +1,4 @@
tensorflow>=1.13 tensorflow>=1.13
mpmath mpmath
scipy>=0.17 scipy>=0.17
dm-tree~=0.1.1

View file

@ -15,18 +15,20 @@
from setuptools import find_packages from setuptools import find_packages
from setuptools import setup from setuptools import setup
setup(name='tensorflow_privacy', setup(
version='0.2.2', name='tensorflow_privacy',
url='https://github.com/tensorflow/privacy', version='0.2.2',
license='Apache-2.0', url='https://github.com/tensorflow/privacy',
install_requires=[ license='Apache-2.0',
'scipy>=0.17', install_requires=[
'mpmath', # used in tests only 'scipy>=0.17',
], 'mpmath', # used in tests only
# Explicit dependence on TensorFlow is not supported. 'dm-tree~=0.1.1', # used in tests only
# See https://github.com/tensorflow/tensorflow/issues/7166 ],
extras_require={ # Explicit dependence on TensorFlow is not supported.
'tf': ['tensorflow>=1.0.0'], # See https://github.com/tensorflow/tensorflow/issues/7166
'tf_gpu': ['tensorflow-gpu>=1.0.0'], extras_require={
}, 'tf': ['tensorflow>=1.0.0'],
packages=find_packages()) 'tf_gpu': ['tensorflow-gpu>=1.0.0'],
},
packages=find_packages())

View file

@ -88,8 +88,8 @@ py_library(
srcs = ["nested_query.py"], srcs = ["nested_query.py"],
deps = [ deps = [
":dp_query", ":dp_query",
"//third_party/py/distutils",
"//third_party/py/tensorflow", "//third_party/py/tensorflow",
"//third_party/py/tree",
], ],
) )

View file

@ -20,9 +20,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow.contrib import framework as contrib_framework import tree
class NestedQuery(dp_query.DPQuery): class NestedQuery(dp_query.DPQuery):
@ -55,8 +54,8 @@ class NestedQuery(dp_query.DPQuery):
def caller(query, *args): def caller(query, *args):
return getattr(query, fn)(*args, **kwargs) return getattr(query, fn)(*args, **kwargs)
return contrib_framework.nest.map_structure_up_to( return tree.map_structure_up_to(self._queries, caller, self._queries,
self._queries, caller, self._queries, *inputs) *inputs)
def set_ledger(self, ledger): def set_ledger(self, ledger):
self._map_to_queries('set_ledger', ledger=ledger) self._map_to_queries('set_ledger', ledger=ledger)
@ -106,7 +105,6 @@ class NestedQuery(dp_query.DPQuery):
'get_noised_result', sample_state, global_state) 'get_noised_result', sample_state, global_state)
flat_estimates, flat_new_global_states = zip( flat_estimates, flat_new_global_states = zip(
*contrib_framework.nest.flatten_up_to( *tree.flatten_up_to(self._queries, estimates_and_new_global_states))
self._queries, estimates_and_new_global_states))
return (tf.nest.pack_sequence_as(self._queries, flat_estimates), return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states)) tf.nest.pack_sequence_as(self._queries, flat_new_global_states))