diff --git a/tensorflow_privacy/privacy/dp_query/nested_query.py b/tensorflow_privacy/privacy/dp_query/nested_query.py index f6cb57c..e76d4ea 100644 --- a/tensorflow_privacy/privacy/dp_query/nested_query.py +++ b/tensorflow_privacy/privacy/dp_query/nested_query.py @@ -22,6 +22,7 @@ from __future__ import print_function import tensorflow as tf from tensorflow_privacy.privacy.dp_query import dp_query +from tensorflow.contrib import framework as contrib_framework class NestedQuery(dp_query.DPQuery): @@ -54,7 +55,7 @@ class NestedQuery(dp_query.DPQuery): def caller(query, *args): return getattr(query, fn)(*args, **kwargs) - return tf.contrib.framework.nest.map_structure_up_to( + return contrib_framework.nest.map_structure_up_to( self._queries, caller, self._queries, *inputs) def set_ledger(self, ledger): @@ -105,7 +106,7 @@ class NestedQuery(dp_query.DPQuery): 'get_noised_result', sample_state, global_state) flat_estimates, flat_new_global_states = zip( - *tf.contrib.framework.nest.flatten_up_to( + *contrib_framework.nest.flatten_up_to( self._queries, estimates_and_new_global_states)) return (tf.nest.pack_sequence_as(self._queries, flat_estimates), tf.nest.pack_sequence_as(self._queries, flat_new_global_states))