Finish implementation of custom indices names.
PiperOrigin-RevId: 545440374
This commit is contained in:
parent
93f5a5249c
commit
a147a480a5
5 changed files with 111 additions and 23 deletions
|
@ -20,6 +20,7 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -31,6 +32,15 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
|
||||||
|
|
||||||
|
_CUSTOM_SLICES = flags.DEFINE_boolean(
|
||||||
|
"custom_slices",
|
||||||
|
default=False,
|
||||||
|
help="If true, custom slices are used.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TRAIN_SET_SIZE = TEST_SET_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
def generate_random_cluster(center, scale, num_points):
|
def generate_random_cluster(center, scale, num_points):
|
||||||
return np.random.normal(size=(num_points, len(center))) * scale + center
|
return np.random.normal(size=(num_points, len(center))) * scale + center
|
||||||
|
@ -104,9 +114,11 @@ def main(unused_argv):
|
||||||
# the generated clusters. More noise makes the classification harder.
|
# the generated clusters. More noise makes the classification harder.
|
||||||
noise_scale = 2
|
noise_scale = 2
|
||||||
training_features, training_labels = generate_features_and_labels(
|
training_features, training_labels = generate_features_and_labels(
|
||||||
samples_per_cluster=250, scale=noise_scale)
|
samples_per_cluster=TRAIN_SET_SIZE // 4, scale=noise_scale
|
||||||
|
)
|
||||||
test_features, test_labels = generate_features_and_labels(
|
test_features, test_labels = generate_features_and_labels(
|
||||||
samples_per_cluster=250, scale=noise_scale)
|
samples_per_cluster=TEST_SET_SIZE // 4, scale=noise_scale
|
||||||
|
)
|
||||||
|
|
||||||
num_clusters = int(round(np.max(training_labels))) + 1
|
num_clusters = int(round(np.max(training_labels))) + 1
|
||||||
|
|
||||||
|
@ -143,6 +155,21 @@ def main(unused_argv):
|
||||||
epoch_num=num_epochs_per_round * (i + 1),
|
epoch_num=num_epochs_per_round * (i + 1),
|
||||||
model_variant_label=model_name)
|
model_variant_label=model_name)
|
||||||
|
|
||||||
|
if _CUSTOM_SLICES.value:
|
||||||
|
custom_train_indices = np.array([i % 2 for i in range(TRAIN_SET_SIZE)])
|
||||||
|
custom_test_indices = np.array(
|
||||||
|
[(i + 1) % 2 for i in range(TEST_SET_SIZE)]
|
||||||
|
)
|
||||||
|
slicing_spec = data_structures.SlicingSpec(
|
||||||
|
all_custom_train_indices=[custom_train_indices],
|
||||||
|
all_custom_test_indices=[custom_test_indices],
|
||||||
|
custom_slices_names={0: "name0", 1: "name1"},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
slicing_spec = data_structures.SlicingSpec(
|
||||||
|
entire_dataset=True, by_class=True
|
||||||
|
)
|
||||||
|
|
||||||
attack_results = mia.run_attacks(
|
attack_results = mia.run_attacks(
|
||||||
data_structures.AttackInputData(
|
data_structures.AttackInputData(
|
||||||
labels_train=training_labels,
|
labels_train=training_labels,
|
||||||
|
@ -150,7 +177,7 @@ def main(unused_argv):
|
||||||
probs_train=training_pred,
|
probs_train=training_pred,
|
||||||
probs_test=test_pred,
|
probs_test=test_pred,
|
||||||
),
|
),
|
||||||
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
|
slicing_spec,
|
||||||
attack_types=(
|
attack_types=(
|
||||||
data_structures.AttackType.THRESHOLD_ATTACK,
|
data_structures.AttackType.THRESHOLD_ATTACK,
|
||||||
data_structures.AttackType.LOGISTIC_REGRESSION,
|
data_structures.AttackType.LOGISTIC_REGRESSION,
|
||||||
|
|
|
@ -67,9 +67,15 @@ class SingleSliceSpec:
|
||||||
return 'Loss percentiles: %d-%d' % self.value
|
return 'Loss percentiles: %d-%d' % self.value
|
||||||
|
|
||||||
if self.feature == SlicingFeature.CUSTOM:
|
if self.feature == SlicingFeature.CUSTOM:
|
||||||
custom_train_indices, custom_test_indices, group_value = self.value
|
custom_train_indices, custom_test_indices, slice_value, slice_name = (
|
||||||
return (f'Custom indices: train = {custom_train_indices}, '
|
self.value
|
||||||
f'test = {custom_test_indices}, group_value = {group_value}')
|
)
|
||||||
|
if slice_name is not None:
|
||||||
|
return f'Custom indices: slice_name = {slice_name}'
|
||||||
|
return (
|
||||||
|
f'Custom indices: train = {custom_train_indices}, '
|
||||||
|
f'test = {custom_test_indices}, group_value = {slice_value}'
|
||||||
|
)
|
||||||
|
|
||||||
return '%s=%s' % (self.feature.name, self.value)
|
return '%s=%s' % (self.feature.name, self.value)
|
||||||
|
|
||||||
|
|
|
@ -44,8 +44,11 @@ class SingleSliceSpecTest(parameterized.TestCase):
|
||||||
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
||||||
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
||||||
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
||||||
(SlicingFeature.CUSTOM, (np.array([1]), np.array([2, 1]), 1),
|
(
|
||||||
'Custom indices: train = [1], test = [2 1], group_value = 1'),
|
SlicingFeature.CUSTOM,
|
||||||
|
(np.array([1]), np.array([2, 1]), 1, None),
|
||||||
|
'Custom indices: train = [1], test = [2 1], group_value = 1',
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def testStr(self, feature, value, expected_str):
|
def testStr(self, feature, value, expected_str):
|
||||||
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
||||||
|
|
|
@ -264,13 +264,18 @@ def get_single_slice_specs(
|
||||||
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
||||||
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
||||||
for g in groups:
|
for g in groups:
|
||||||
|
group_name = None
|
||||||
if slicing_spec.custom_slices_names is not None:
|
if slicing_spec.custom_slices_names is not None:
|
||||||
if g not in slicing_spec.custom_slices_names:
|
if g not in slicing_spec.custom_slices_names:
|
||||||
raise ValueError(f"Custom slice={g} is not in custom_slices_names")
|
raise ValueError(f"Custom slice={g} is not in custom_slices_names")
|
||||||
group_id = slicing_spec.custom_slices_names[g]
|
group_name = slicing_spec.custom_slices_names[g]
|
||||||
else:
|
|
||||||
group_id = (custom_train_indices, custom_test_indices, g)
|
result.append(
|
||||||
result.append(SingleSliceSpec(SlicingFeature.CUSTOM, group_id))
|
SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g, group_name),
|
||||||
|
)
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -294,7 +299,8 @@ def get_slice(
|
||||||
data, slice_spec.value, return_slice_indices
|
data, slice_spec.value, return_slice_indices
|
||||||
)
|
)
|
||||||
elif slice_spec.feature == SlicingFeature.CUSTOM:
|
elif slice_spec.feature == SlicingFeature.CUSTOM:
|
||||||
custom_train_indices, custom_test_indices, group_value = slice_spec.value
|
custom_train_indices, custom_test_indices, group_value, _ = slice_spec.value
|
||||||
|
|
||||||
data_slice = _slice_by_custom_indices(
|
data_slice = _slice_by_custom_indices(
|
||||||
data,
|
data,
|
||||||
custom_train_indices,
|
custom_train_indices,
|
||||||
|
|
|
@ -109,8 +109,10 @@ class SingleSliceSpecsTest(parameterized.TestCase):
|
||||||
all_custom_train_indices=[custom_train_indices],
|
all_custom_train_indices=[custom_train_indices],
|
||||||
all_custom_test_indices=[custom_test_indices])
|
all_custom_test_indices=[custom_test_indices])
|
||||||
expected = [self.ENTIRE_DATASET_SLICE] + [
|
expected = [self.ENTIRE_DATASET_SLICE] + [
|
||||||
SingleSliceSpec(SlicingFeature.CUSTOM,
|
SingleSliceSpec(
|
||||||
(custom_train_indices, custom_test_indices, g))
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g, None),
|
||||||
|
)
|
||||||
for g in expected_groups
|
for g in expected_groups
|
||||||
]
|
]
|
||||||
output = get_single_slice_specs(input_data)
|
output = get_single_slice_specs(input_data)
|
||||||
|
@ -138,11 +140,49 @@ class SingleSliceSpecsTest(parameterized.TestCase):
|
||||||
for custom_train_indices, custom_test_indices, eg in zip(
|
for custom_train_indices, custom_test_indices, eg in zip(
|
||||||
all_custom_train_indices, all_custom_test_indices,
|
all_custom_train_indices, all_custom_test_indices,
|
||||||
expected_group_values):
|
expected_group_values):
|
||||||
expected.extend([
|
expected.extend(
|
||||||
SingleSliceSpec(SlicingFeature.CUSTOM,
|
[
|
||||||
(custom_train_indices, custom_test_indices, g))
|
SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g, None),
|
||||||
|
)
|
||||||
for g in eg
|
for g in eg
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
output = get_single_slice_specs(input_data)
|
||||||
|
self.assertTrue(_are_lists_equal(output, expected))
|
||||||
|
|
||||||
|
def test_slicing_by_custom_indices_slice_name(self):
|
||||||
|
all_custom_train_indices = [
|
||||||
|
np.array([1, 2, 1, 2]),
|
||||||
|
]
|
||||||
|
all_custom_test_indices = [
|
||||||
|
np.array([2, 2, 1, 2]),
|
||||||
|
]
|
||||||
|
custom_slices_names = {1: 'slice1', 2: 'slice2'}
|
||||||
|
expected_group_values = [[1, 2]]
|
||||||
|
|
||||||
|
input_data = SlicingSpec(
|
||||||
|
all_custom_train_indices=all_custom_train_indices,
|
||||||
|
all_custom_test_indices=all_custom_test_indices,
|
||||||
|
custom_slices_names=custom_slices_names,
|
||||||
|
)
|
||||||
|
expected = [self.ENTIRE_DATASET_SLICE]
|
||||||
|
for custom_train_indices, custom_test_indices, eg in zip(
|
||||||
|
all_custom_train_indices, all_custom_test_indices, expected_group_values
|
||||||
|
):
|
||||||
|
for g in eg:
|
||||||
|
expected.append(
|
||||||
|
SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM,
|
||||||
|
(
|
||||||
|
custom_train_indices,
|
||||||
|
custom_test_indices,
|
||||||
|
g,
|
||||||
|
custom_slices_names[g],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
output = get_single_slice_specs(input_data)
|
output = get_single_slice_specs(input_data)
|
||||||
self.assertTrue(_are_lists_equal(output, expected))
|
self.assertTrue(_are_lists_equal(output, expected))
|
||||||
|
|
||||||
|
@ -298,7 +338,9 @@ class GetSliceTest(parameterized.TestCase):
|
||||||
custom_train_indices = np.array([2, 2, 100, 4])
|
custom_train_indices = np.array([2, 2, 100, 4])
|
||||||
custom_test_indices = np.array([100, 2, 2, 2])
|
custom_test_indices = np.array([100, 2, 2, 2])
|
||||||
custom_slice = SingleSliceSpec(
|
custom_slice = SingleSliceSpec(
|
||||||
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, 2, None),
|
||||||
|
)
|
||||||
output = get_slice(self.input_data, custom_slice)
|
output = get_slice(self.input_data, custom_slice)
|
||||||
np.testing.assert_array_equal(output.logits_train,
|
np.testing.assert_array_equal(output.logits_train,
|
||||||
np.array([[0, 1, 0], [2, 0, 3]]))
|
np.array([[0, 1, 0], [2, 0, 3]]))
|
||||||
|
@ -325,7 +367,9 @@ class GetSliceTest(parameterized.TestCase):
|
||||||
def test_slice_by_custom_indices_wrong_size(self, custom_train_indices,
|
def test_slice_by_custom_indices_wrong_size(self, custom_train_indices,
|
||||||
custom_test_indices):
|
custom_test_indices):
|
||||||
custom_slice = SingleSliceSpec(
|
custom_slice = SingleSliceSpec(
|
||||||
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, 2, None),
|
||||||
|
)
|
||||||
self.assertRaises(ValueError, get_slice, self.input_data, custom_slice)
|
self.assertRaises(ValueError, get_slice, self.input_data, custom_slice)
|
||||||
|
|
||||||
|
|
||||||
|
@ -420,7 +464,9 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||||
custom_train_indices = np.array([2, 2, 100, 4])
|
custom_train_indices = np.array([2, 2, 100, 4])
|
||||||
custom_test_indices = np.array([100, 2, 2, 2])
|
custom_test_indices = np.array([100, 2, 2, 2])
|
||||||
custom_slice = SingleSliceSpec(
|
custom_slice = SingleSliceSpec(
|
||||||
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, 2, 'slice_name'),
|
||||||
|
)
|
||||||
output = get_slice(self.input_data, custom_slice)
|
output = get_slice(self.input_data, custom_slice)
|
||||||
# Check logits.
|
# Check logits.
|
||||||
with self.subTest(msg='Check logits'):
|
with self.subTest(msg='Check logits'):
|
||||||
|
|
Loading…
Reference in a new issue