Move TF privacy off contrib.
PiperOrigin-RevId: 289953826
This commit is contained in:
parent
1a448b4272
commit
a8a2d91795
4 changed files with 23 additions and 22 deletions
|
@ -1,3 +1,4 @@
|
||||||
tensorflow>=1.13
|
tensorflow>=1.13
|
||||||
mpmath
|
mpmath
|
||||||
scipy>=0.17
|
scipy>=0.17
|
||||||
|
dm-tree~=0.1.1
|
||||||
|
|
4
setup.py
4
setup.py
|
@ -15,13 +15,15 @@
|
||||||
from setuptools import find_packages
|
from setuptools import find_packages
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
setup(name='tensorflow_privacy',
|
setup(
|
||||||
|
name='tensorflow_privacy',
|
||||||
version='0.2.2',
|
version='0.2.2',
|
||||||
url='https://github.com/tensorflow/privacy',
|
url='https://github.com/tensorflow/privacy',
|
||||||
license='Apache-2.0',
|
license='Apache-2.0',
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'scipy>=0.17',
|
'scipy>=0.17',
|
||||||
'mpmath', # used in tests only
|
'mpmath', # used in tests only
|
||||||
|
'dm-tree~=0.1.1', # used in tests only
|
||||||
],
|
],
|
||||||
# Explicit dependence on TensorFlow is not supported.
|
# Explicit dependence on TensorFlow is not supported.
|
||||||
# See https://github.com/tensorflow/tensorflow/issues/7166
|
# See https://github.com/tensorflow/tensorflow/issues/7166
|
||||||
|
|
|
@ -88,8 +88,8 @@ py_library(
|
||||||
srcs = ["nested_query.py"],
|
srcs = ["nested_query.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":dp_query",
|
":dp_query",
|
||||||
"//third_party/py/distutils",
|
|
||||||
"//third_party/py/tensorflow",
|
"//third_party/py/tensorflow",
|
||||||
|
"//third_party/py/tree",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,8 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
from tensorflow.contrib import framework as contrib_framework
|
import tree
|
||||||
|
|
||||||
|
|
||||||
class NestedQuery(dp_query.DPQuery):
|
class NestedQuery(dp_query.DPQuery):
|
||||||
|
@ -55,8 +54,8 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
def caller(query, *args):
|
def caller(query, *args):
|
||||||
return getattr(query, fn)(*args, **kwargs)
|
return getattr(query, fn)(*args, **kwargs)
|
||||||
|
|
||||||
return contrib_framework.nest.map_structure_up_to(
|
return tree.map_structure_up_to(self._queries, caller, self._queries,
|
||||||
self._queries, caller, self._queries, *inputs)
|
*inputs)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
self._map_to_queries('set_ledger', ledger=ledger)
|
self._map_to_queries('set_ledger', ledger=ledger)
|
||||||
|
@ -106,7 +105,6 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
'get_noised_result', sample_state, global_state)
|
'get_noised_result', sample_state, global_state)
|
||||||
|
|
||||||
flat_estimates, flat_new_global_states = zip(
|
flat_estimates, flat_new_global_states = zip(
|
||||||
*contrib_framework.nest.flatten_up_to(
|
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||||
self._queries, estimates_and_new_global_states))
|
|
||||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||||
|
|
Loading…
Reference in a new issue