Fix SumAggregationDPQuery's initial_sample_state raising a ValueError when called on TensorSpec.

PiperOrigin-RevId: 480474975
This commit is contained in:
A. Unique TensorFlower 2022-10-11 16:01:19 -07:00
parent 0738d6f555
commit f8ed0fcd9c
3 changed files with 40 additions and 1 deletions

View file

@ -15,6 +15,14 @@ py_library(
srcs_version = "PY3", srcs_version = "PY3",
) )
py_test(
name = "dp_query_test",
srcs = ["dp_query_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":no_privacy_query"],
)
py_library( py_library(
name = "discrete_gaussian_utils", name = "discrete_gaussian_utils",
srcs = ["discrete_gaussian_utils.py"], srcs = ["discrete_gaussian_utils.py"],

View file

@ -270,7 +270,7 @@ def _zeros_like(arg):
"""A `zeros_like` function that also works for `tf.TensorSpec`s.""" """A `zeros_like` function that also works for `tf.TensorSpec`s."""
try: try:
arg = tf.convert_to_tensor(value=arg) arg = tf.convert_to_tensor(value=arg)
except TypeError: except (TypeError, ValueError):
pass pass
return tf.zeros(arg.shape, arg.dtype) return tf.zeros(arg.shape, arg.dtype)

View file

@ -0,0 +1,31 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import no_privacy_query
class SumAggregationQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_initial_sample_state_works_on_tensorspecs(self):
query = no_privacy_query.NoPrivacySumQuery()
template = tf.TensorSpec.from_tensor(tf.constant([1.0, 2.0]))
sample_state = query.initial_sample_state(template)
expected = [0.0, 0.0]
self.assertAllClose(sample_state, expected)
if __name__ == '__main__':
tf.test.main()