Remove tf.contrib.framework.
PiperOrigin-RevId: 289487098
This commit is contained in:
parent
8d98c3433b
commit
c80a862ae2
1 changed files with 3 additions and 2 deletions
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||||
|
from tensorflow.contrib import framework as contrib_framework
|
||||||
|
|
||||||
|
|
||||||
class NestedQuery(dp_query.DPQuery):
|
class NestedQuery(dp_query.DPQuery):
|
||||||
|
@ -54,7 +55,7 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
def caller(query, *args):
|
def caller(query, *args):
|
||||||
return getattr(query, fn)(*args, **kwargs)
|
return getattr(query, fn)(*args, **kwargs)
|
||||||
|
|
||||||
return tf.contrib.framework.nest.map_structure_up_to(
|
return contrib_framework.nest.map_structure_up_to(
|
||||||
self._queries, caller, self._queries, *inputs)
|
self._queries, caller, self._queries, *inputs)
|
||||||
|
|
||||||
def set_ledger(self, ledger):
|
def set_ledger(self, ledger):
|
||||||
|
@ -105,7 +106,7 @@ class NestedQuery(dp_query.DPQuery):
|
||||||
'get_noised_result', sample_state, global_state)
|
'get_noised_result', sample_state, global_state)
|
||||||
|
|
||||||
flat_estimates, flat_new_global_states = zip(
|
flat_estimates, flat_new_global_states = zip(
|
||||||
*tf.contrib.framework.nest.flatten_up_to(
|
*contrib_framework.nest.flatten_up_to(
|
||||||
self._queries, estimates_and_new_global_states))
|
self._queries, estimates_and_new_global_states))
|
||||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||||
|
|
Loading…
Reference in a new issue