From 510dd207d57c2fe18d6bd06c1289f21930b6eaa4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Jul 2020 02:43:29 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 321742857 --- .../data_structures_test.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py new file mode 100644 index 0000000..721f9e1 --- /dev/null +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -0,0 +1,39 @@ +# Copyright 2020, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for tensorflow_privacy.privacy.membership_inference_attack.data_structures.""" +from absl.testing import absltest +import numpy as np +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData + + +class AttackInputDataTest(absltest.TestCase): + + def test_get_loss(self): + attack_input = AttackInputData( + logits_train=np.array([[0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]), + logits_test=np.array([[0.2, 0.3, 0.5], [0.3, 0.5, 0.2]]), + labels_train=np.array([1, 0]), + labels_test=np.array([0, 1]) + ) + + np.testing.assert_almost_equal( + attack_input.get_loss_train().tolist(), [0.5, 0.2]) + np.testing.assert_almost_equal( + attack_input.get_loss_test().tolist(), [0.2, 0.5]) + + +if __name__ == '__main__': + absltest.main()