Update several DPQuery tests to TF v2.
PiperOrigin-RevId: 468763153
This commit is contained in:
parent
7fe491f7a4
commit
fd64be5b5b
4 changed files with 169 additions and 210 deletions
|
@ -23,19 +23,16 @@ import tensorflow_probability as tfp
|
||||||
class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_skellam_sum_no_noise(self):
|
def test_skellam_sum_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([2, 0], dtype=tf.int32)
|
record1 = tf.constant([2, 0], dtype=tf.int32)
|
||||||
record2 = tf.constant([-1, 1], dtype=tf.int32)
|
record2 = tf.constant([-1, 1], dtype=tf.int32)
|
||||||
|
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1, 1]
|
expected = [1, 1]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_skellam_multiple_shapes(self):
|
def test_skellam_multiple_shapes(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
tensor1 = tf.constant([2, 0], dtype=tf.int32)
|
tensor1 = tf.constant([2, 0], dtype=tf.int32)
|
||||||
tensor2 = tf.constant([-1, 1, 3], dtype=tf.int32)
|
tensor2 = tf.constant([-1, 1, 3], dtype=tf.int32)
|
||||||
record = [tensor1, tensor2]
|
record = [tensor1, tensor2]
|
||||||
|
@ -43,57 +40,48 @@ class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record, record])
|
query_result, _ = test_utils.run_query(query, [record, record])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [2 * tensor1, 2 * tensor2]
|
expected = [2 * tensor1, 2 * tensor2]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_skellam_raise_type_exception(self):
|
def test_skellam_raise_type_exception(self):
|
||||||
with self.cached_session() as sess, self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
record1 = tf.constant([2, 0], dtype=tf.float32)
|
record1 = tf.constant([2, 0], dtype=tf.float32)
|
||||||
record2 = tf.constant([-1, 1], dtype=tf.float32)
|
record2 = tf.constant([-1, 1], dtype=tf.float32)
|
||||||
|
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
l1_norm_bound=10, l2_norm_bound=10, local_stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
test_utils.run_query(query, [record1, record2])
|
||||||
sess.run(query_result)
|
|
||||||
|
|
||||||
def test_skellam_raise_l1_norm_exception(self):
|
def test_skellam_raise_l1_norm_exception(self):
|
||||||
with self.cached_session() as sess, self.assertRaises(
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
tf.errors.InvalidArgumentError):
|
|
||||||
record1 = tf.constant([1, 2], dtype=tf.int32)
|
record1 = tf.constant([1, 2], dtype=tf.int32)
|
||||||
record2 = tf.constant([3, 4], dtype=tf.int32)
|
record2 = tf.constant([3, 4], dtype=tf.int32)
|
||||||
|
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=1, l2_norm_bound=100, local_stddev=0.0)
|
l1_norm_bound=1, l2_norm_bound=100, local_stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
test_utils.run_query(query, [record1, record2])
|
||||||
|
|
||||||
sess.run(query_result)
|
|
||||||
|
|
||||||
def test_skellam_raise_l2_norm_exception(self):
|
def test_skellam_raise_l2_norm_exception(self):
|
||||||
with self.cached_session() as sess, self.assertRaises(
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
tf.errors.InvalidArgumentError):
|
|
||||||
record1 = tf.constant([1, 2], dtype=tf.int32)
|
record1 = tf.constant([1, 2], dtype=tf.int32)
|
||||||
record2 = tf.constant([3, 4], dtype=tf.int32)
|
record2 = tf.constant([3, 4], dtype=tf.int32)
|
||||||
|
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10, l2_norm_bound=4, local_stddev=0.0)
|
l1_norm_bound=10, l2_norm_bound=4, local_stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
test_utils.run_query(query, [record1, record2])
|
||||||
|
|
||||||
sess.run(query_result)
|
|
||||||
|
|
||||||
def test_skellam_sum_with_noise(self):
|
def test_skellam_sum_with_noise(self):
|
||||||
"""Use only one record to test std."""
|
"""Use only one record to test std."""
|
||||||
with self.cached_session() as sess:
|
|
||||||
record = tf.constant([1], dtype=tf.int32)
|
record = tf.constant([1], dtype=tf.int32)
|
||||||
local_stddev = 1.0
|
local_stddev = 1.0
|
||||||
|
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
||||||
query_result, _ = test_utils.run_query(query, [record])
|
|
||||||
|
|
||||||
noised_sums = []
|
noised_sums = []
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
noised_sums.append(sess.run(query_result))
|
query_result, _ = test_utils.run_query(query, [record])
|
||||||
|
noised_sums.append(query_result)
|
||||||
|
|
||||||
result_stddev = np.std(noised_sums)
|
result_stddev = np.std(noised_sums)
|
||||||
self.assertNear(result_stddev, local_stddev, 0.1)
|
self.assertNear(result_stddev, local_stddev, 0.1)
|
||||||
|
@ -108,17 +96,16 @@ class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
Both results are evaluated to match percentiles (25, 50, 75).
|
Both results are evaluated to match percentiles (25, 50, 75).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
num_trials = 10000
|
num_trials = 10000
|
||||||
num_users = 100
|
num_users = 100
|
||||||
record = tf.zeros([num_trials], dtype=tf.int32)
|
record = tf.zeros([num_trials], dtype=tf.int32)
|
||||||
local_stddev = 1.0
|
local_stddev = 1.0
|
||||||
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
query = distributed_skellam_query.DistributedSkellamSumQuery(
|
||||||
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
l1_norm_bound=10.0, l2_norm_bound=10, local_stddev=local_stddev)
|
||||||
query_result, _ = test_utils.run_query(query, [record])
|
|
||||||
distributed_noised = tf.zeros([num_trials], dtype=tf.int32)
|
distributed_noised = tf.zeros([num_trials], dtype=tf.int32)
|
||||||
for _ in range(num_users):
|
for _ in range(num_users):
|
||||||
distributed_noised += sess.run(query_result)
|
query_result, _ = test_utils.run_query(query, [record])
|
||||||
|
distributed_noised += query_result
|
||||||
|
|
||||||
def add_noise(v, stddev):
|
def add_noise(v, stddev):
|
||||||
lam = stddev**2 / 2
|
lam = stddev**2 / 2
|
||||||
|
@ -131,8 +118,8 @@ class DistributedSkellamQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
record_centralized = tf.zeros([num_trials], dtype=tf.int32)
|
record_centralized = tf.zeros([num_trials], dtype=tf.int32)
|
||||||
centralized_noised = sess.run(
|
centralized_noised = add_noise(record_centralized,
|
||||||
add_noise(record_centralized, local_stddev * np.sqrt(num_users)))
|
local_stddev * np.sqrt(num_users))
|
||||||
|
|
||||||
tolerance = 5
|
tolerance = 5
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
|
|
|
@ -22,61 +22,50 @@ from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_gaussian_sum_no_clip_no_noise(self):
|
def test_gaussian_sum_no_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([2.0, 0.0])
|
record1 = tf.constant([2.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
record2 = tf.constant([-1.0, 1.0])
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_gaussian_sum_with_clip_no_noise(self):
|
def test_gaussian_sum_with_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_gaussian_sum_with_changing_clip_no_noise(self):
|
def test_gaussian_sum_with_changing_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||||
|
|
||||||
l2_norm_clip = tf.Variable(5.0)
|
l2_norm_clip = tf.Variable(5.0)
|
||||||
l2_norm_clip_placeholder = tf.compat.v1.placeholder(tf.float32)
|
|
||||||
assign_l2_norm_clip = tf.compat.v1.assign(l2_norm_clip,
|
|
||||||
l2_norm_clip_placeholder)
|
|
||||||
query = gaussian_query.GaussianSumQuery(
|
query = gaussian_query.GaussianSumQuery(
|
||||||
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
|
|
||||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
sess.run(assign_l2_norm_clip, {l2_norm_clip_placeholder: 0.0})
|
l2_norm_clip.assign(0.0)
|
||||||
result = sess.run(query_result)
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
expected = [0.0, 0.0]
|
expected = [0.0, 0.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_gaussian_sum_with_noise(self):
|
def test_gaussian_sum_with_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1, record2 = 2.71828, 3.14159
|
record1, record2 = 2.71828, 3.14159
|
||||||
stddev = 1.0
|
stddev = 1.0
|
||||||
|
|
||||||
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=stddev)
|
query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=stddev)
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
|
||||||
|
|
||||||
noised_sums = []
|
noised_sums = []
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
noised_sums.append(sess.run(query_result))
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
|
noised_sums.append(query_result)
|
||||||
|
|
||||||
result_stddev = np.std(noised_sums)
|
result_stddev = np.std(noised_sums)
|
||||||
self.assertNear(result_stddev, stddev, 0.1)
|
self.assertNear(result_stddev, stddev, 0.1)
|
||||||
|
@ -100,11 +89,8 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
|
||||||
sample_state_1, sample_state_2)
|
sample_state_1, sample_state_2)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
result = sess.run(merged)
|
|
||||||
|
|
||||||
expected = [3.0, 10.0]
|
expected = [3.0, 10.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(merged, expected)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('type_mismatch', [1.0], (1.0,), TypeError),
|
('type_mismatch', [1.0], (1.0,), TypeError),
|
||||||
|
|
|
@ -29,7 +29,6 @@ _basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||||
class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
|
||||||
|
|
||||||
|
@ -39,12 +38,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record2 = [4.0, [3.0, 2.0]]
|
record2 = [4.0, [3.0, 2.0]]
|
||||||
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [5.0, [5.0, 5.0]]
|
expected = [5.0, [5.0, 5.0]]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_nested_gaussian_average_with_clip_no_noise(self):
|
def test_nested_gaussian_average_with_clip_no_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
query1 = normalized_query.NormalizedQuery(
|
query1 = normalized_query.NormalizedQuery(
|
||||||
gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=0.0),
|
gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=0.0),
|
||||||
denominator=5.0)
|
denominator=5.0)
|
||||||
|
@ -58,12 +55,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
||||||
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, [1.0, 1.0]]
|
expected = [1.0, [1.0, 1.0]]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_complex_nested_query(self):
|
def test_complex_nested_query(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
query_ab = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
|
query_ab = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
|
||||||
query_c = normalized_query.NormalizedQuery(
|
query_c = normalized_query.NormalizedQuery(
|
||||||
gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0),
|
gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0),
|
||||||
|
@ -80,12 +75,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
||||||
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
|
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_nested_query_with_noise(self):
|
def test_nested_query_with_noise(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
stddev = 2.71828
|
stddev = 2.71828
|
||||||
denominator = 3.14159
|
denominator = 3.14159
|
||||||
|
|
||||||
|
@ -98,11 +91,10 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
record1 = (3.0, [2.0, 1.5])
|
record1 = (3.0, [2.0, 1.5])
|
||||||
record2 = (0.0, [-1.0, -3.5])
|
record2 = (0.0, [-1.0, -3.5])
|
||||||
|
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
|
||||||
|
|
||||||
noised_averages = []
|
noised_averages = []
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
noised_averages.append(tf.nest.flatten(sess.run(query_result)))
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
|
noised_averages.append(tf.nest.flatten(query_result))
|
||||||
|
|
||||||
result_stddev = np.std(noised_averages, 0)
|
result_stddev = np.std(noised_averages, 0)
|
||||||
avg_stddev = stddev / denominator
|
avg_stddev = stddev / denominator
|
||||||
|
|
|
@ -21,29 +21,24 @@ from tensorflow_privacy.privacy.dp_query import test_utils
|
||||||
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_sum(self):
|
def test_sum(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([2.0, 0.0])
|
record1 = tf.constant([2.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
record2 = tf.constant([-1.0, 1.0])
|
||||||
|
|
||||||
query = no_privacy_query.NoPrivacySumQuery()
|
query = no_privacy_query.NoPrivacySumQuery()
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [1.0, 1.0]
|
expected = [1.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_no_privacy_average(self):
|
def test_no_privacy_average(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([5.0, 0.0])
|
record1 = tf.constant([5.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 2.0])
|
record2 = tf.constant([-1.0, 2.0])
|
||||||
|
|
||||||
query = no_privacy_query.NoPrivacyAverageQuery()
|
query = no_privacy_query.NoPrivacyAverageQuery()
|
||||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [2.0, 1.0]
|
expected = [2.0, 1.0]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
def test_no_privacy_weighted_average(self):
|
def test_no_privacy_weighted_average(self):
|
||||||
with self.cached_session() as sess:
|
|
||||||
record1 = tf.constant([4.0, 0.0])
|
record1 = tf.constant([4.0, 0.0])
|
||||||
record2 = tf.constant([-1.0, 1.0])
|
record2 = tf.constant([-1.0, 1.0])
|
||||||
|
|
||||||
|
@ -52,9 +47,8 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
query = no_privacy_query.NoPrivacyAverageQuery()
|
query = no_privacy_query.NoPrivacyAverageQuery()
|
||||||
query_result, _ = test_utils.run_query(
|
query_result, _ = test_utils.run_query(
|
||||||
query, [record1, record2], weights=weights)
|
query, [record1, record2], weights=weights)
|
||||||
result = sess.run(query_result)
|
|
||||||
expected = [0.25, 0.75]
|
expected = [0.25, 0.75]
|
||||||
self.assertAllClose(result, expected)
|
self.assertAllClose(query_result, expected)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('type_mismatch', [1.0], (1.0,), TypeError),
|
('type_mismatch', [1.0], (1.0,), TypeError),
|
||||||
|
|
Loading…
Reference in a new issue