diff --git a/tensorflow_privacy/privacy/dp_query/BUILD b/tensorflow_privacy/privacy/dp_query/BUILD index 787de79..9fed57d 100644 --- a/tensorflow_privacy/privacy/dp_query/BUILD +++ b/tensorflow_privacy/privacy/dp_query/BUILD @@ -15,6 +15,14 @@ py_library( srcs_version = "PY3", ) +py_test( + name = "dp_query_test", + srcs = ["dp_query_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":no_privacy_query"], +) + py_library( name = "discrete_gaussian_utils", srcs = ["discrete_gaussian_utils.py"], diff --git a/tensorflow_privacy/privacy/dp_query/dp_query.py b/tensorflow_privacy/privacy/dp_query/dp_query.py index 008ff9b..7e7220c 100644 --- a/tensorflow_privacy/privacy/dp_query/dp_query.py +++ b/tensorflow_privacy/privacy/dp_query/dp_query.py @@ -270,7 +270,7 @@ def _zeros_like(arg): """A `zeros_like` function that also works for `tf.TensorSpec`s.""" try: arg = tf.convert_to_tensor(value=arg) - except TypeError: + except (TypeError, ValueError): pass return tf.zeros(arg.shape, arg.dtype) diff --git a/tensorflow_privacy/privacy/dp_query/dp_query_test.py b/tensorflow_privacy/privacy/dp_query/dp_query_test.py new file mode 100644 index 0000000..c76d86c --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/dp_query_test.py @@ -0,0 +1,31 @@ +# Copyright 2022, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow_privacy.privacy.dp_query import no_privacy_query + + +class SumAggregationQueryTest(tf.test.TestCase, parameterized.TestCase): + + def test_initial_sample_state_works_on_tensorspecs(self): + query = no_privacy_query.NoPrivacySumQuery() + template = tf.TensorSpec.from_tensor(tf.constant([1.0, 2.0])) + sample_state = query.initial_sample_state(template) + expected = [0.0, 0.0] + self.assertAllClose(sample_state, expected) + + +if __name__ == '__main__': + tf.test.main()