diff --git a/tensorflow_privacy/privacy/dp_query/normalized_query_test.py b/tensorflow_privacy/privacy/dp_query/normalized_query_test.py index 0e12185..089378e 100644 --- a/tensorflow_privacy/privacy/dp_query/normalized_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/normalized_query_test.py @@ -21,18 +21,16 @@ from tensorflow_privacy.privacy.dp_query import test_utils class NormalizedQueryTest(tf.test.TestCase): def test_normalization(self): - with self.cached_session() as sess: - record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. - record2 = tf.constant([4.0, -3.0]) # Not clipped. + record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. + record2 = tf.constant([4.0, -3.0]) # Not clipped. - sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0) - query = normalized_query.NormalizedQuery( - numerator_query=sum_query, denominator=2.0) + sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0) + query = normalized_query.NormalizedQuery( + numerator_query=sum_query, denominator=2.0) - query_result, _ = test_utils.run_query(query, [record1, record2]) - result = sess.run(query_result) - expected = [0.5, 0.5] - self.assertAllClose(result, expected) + query_result, _ = test_utils.run_query(query, [record1, record2]) + expected = [0.5, 0.5] + self.assertAllClose(query_result, expected) if __name__ == '__main__':