diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py index ca8a494..e785dd9 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py @@ -42,9 +42,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train) result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train) - # Copy over sample weights if provided. - result.sample_weight_train = data.sample_weight_train - result.sample_weight_test = data.sample_weight_test + # Slice sample weights if provided. + result.sample_weight_train = _slice_if_not_none(data.sample_weight_train, + idx_train) # Slice test data. result.logits_test = _slice_if_not_none(data.logits_test, idx_test) @@ -52,6 +52,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.labels_test = _slice_if_not_none(data.labels_test, idx_test) result.loss_test = _slice_if_not_none(data.loss_test, idx_test) result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test) + # Slice sample weights if provided. + result.sample_weight_test = _slice_if_not_none(data.sample_weight_test, + idx_test) # A slice has the same multilabel status as the original data. This is because # of the way multilabel status is computed. A dataset is multilabel if at diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py index c0cdd51..0324b9a 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py @@ -115,8 +115,8 @@ class GetSliceTest(absltest.TestCase): loss_test = np.array([0.5, 3.5, 7, 4.5]) entropy_train = np.array([0.4, 8, 0.6, 10]) entropy_test = np.array([15, 10.5, 4.5, 0.3]) - sample_weight_train = np.array([1.0, 0.5]) - sample_weight_test = np.array([0.5, 1.0]) + sample_weight_train = np.array([1.0, 0.2, 0.5, 0.8]) + sample_weight_test = np.array([0.5, 1.0, 0.1, 0.8]) self.input_data = AttackInputData( logits_train=logits_train, @@ -175,8 +175,8 @@ class GetSliceTest(absltest.TestCase): # Check sample weights self.assertLen(output.sample_weight_train, 2) np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5]) - self.assertLen(output.sample_weight_test, 2) - np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0]) + self.assertLen(output.sample_weight_test, 1) + np.testing.assert_array_equal(output.sample_weight_test, [0.5]) def test_slice_by_percentile(self): percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))