forked from 626_privacy/tensorflow_privacy
Allows slicing by custom indices.
PiperOrigin-RevId: 486998645
This commit is contained in:
parent
ec747a8d75
commit
2040f08f0d
5 changed files with 258 additions and 16 deletions
|
@ -20,7 +20,7 @@ import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, MutableSequence, Optional, Union
|
from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -39,6 +39,7 @@ class SlicingFeature(enum.Enum):
|
||||||
CLASS = 'class'
|
CLASS = 'class'
|
||||||
PERCENTILE = 'percentile'
|
PERCENTILE = 'percentile'
|
||||||
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
||||||
|
CUSTOM = 'custom'
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -65,6 +66,11 @@ class SingleSliceSpec:
|
||||||
if self.feature == SlicingFeature.PERCENTILE:
|
if self.feature == SlicingFeature.PERCENTILE:
|
||||||
return 'Loss percentiles: %d-%d' % self.value
|
return 'Loss percentiles: %d-%d' % self.value
|
||||||
|
|
||||||
|
if self.feature == SlicingFeature.CUSTOM:
|
||||||
|
custom_train_indices, custom_test_indices, group_value = self.value
|
||||||
|
return (f'Custom indices: train = {custom_train_indices}, '
|
||||||
|
f'test = {custom_test_indices}, group_value = {group_value}')
|
||||||
|
|
||||||
return '%s=%s' % (self.feature.name, self.value)
|
return '%s=%s' % (self.feature.name, self.value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,6 +97,37 @@ class SlicingSpec:
|
||||||
# examples will be generated.
|
# examples will be generated.
|
||||||
by_classification_correctness: bool = False
|
by_classification_correctness: bool = False
|
||||||
|
|
||||||
|
# When both `all_custom_train_indices` and `all_custom_test_indices` are set,
|
||||||
|
# the data will be sliced by custom indices.
|
||||||
|
# `all_custom_train_indices` and `all_custom_test_indices` are sequences
|
||||||
|
# containing the same number of arrays. Each array assigns a group value to
|
||||||
|
# every training (or test) example, so its length must equal the number of
|
||||||
|
# training (or test) examples, respectively.
|
||||||
|
# For example, suppose we have 3 training examples (a1, a2, a3), and
|
||||||
|
# 2 test examples (b1, b2). Then,
|
||||||
|
# all_custom_train_indices = [np.array([2, 1, 2]), np.array([0, 0, 1])]
|
||||||
|
# all_custom_test_indices = [np.array([1, 2]), np.array([1, 0])]
|
||||||
|
# means we are going to consider two ways of slicing them:
|
||||||
|
# 1. two groups: (a2, b1) corresponding to value 1, (a1, a3, b2) corresponding
|
||||||
|
# to value 2.
|
||||||
|
# 2. two groups: (a1, a2, b2) corresponding to value 0, (a3, b1) corresponding
|
||||||
|
# to value 1.
|
||||||
|
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
|
||||||
|
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
  """Validates the custom-indices fields after dataclass initialization.

  Does nothing when both `all_custom_train_indices` and
  `all_custom_test_indices` are unset (None) or empty.

  Raises:
    ValueError: If only one of the two fields is provided (non-empty), or if
      the two sequences contain a different number of arrays.
  """
  # Count entries explicitly instead of relying on truthiness, so that any
  # sized sequence (list, tuple, stacked array) is handled the same way.
  num_train = (0 if self.all_custom_train_indices is None else len(
      self.all_custom_train_indices))
  num_test = (0 if self.all_custom_test_indices is None else len(
      self.all_custom_test_indices))
  if not num_train and not num_test:
    return
  if bool(num_train) != bool(num_test):
    raise ValueError('all_custom_train_indices and all_custom_test_indices '
                     'must be provided or set to None at the same time.')
  if num_train != num_test:
    # Note the trailing space after "but got": without it the first number
    # is glued to the preceding word.
    raise ValueError('all_custom_train_indices and all_custom_test_indices '
                     'should have the same length, but got '
                     f'{num_train} and {num_test}.')
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""Only keeps the True values."""
|
"""Only keeps the True values."""
|
||||||
result = ['SlicingSpec(']
|
result = ['SlicingSpec(']
|
||||||
|
@ -107,6 +144,8 @@ class SlicingSpec:
|
||||||
result.append(' By percentiles,')
|
result.append(' By percentiles,')
|
||||||
if self.by_classification_correctness:
|
if self.by_classification_correctness:
|
||||||
result.append(' By classification correctness,')
|
result.append(' By classification correctness,')
|
||||||
|
if self.all_custom_train_indices:
|
||||||
|
result.append(' By custom indices,')
|
||||||
result.append(')')
|
result.append(')')
|
||||||
return '\n'.join(result)
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
@ -123,8 +162,9 @@ class AttackType(enum.Enum):
|
||||||
@property
|
@property
|
||||||
def is_trained_attack(self):
|
def is_trained_attack(self):
|
||||||
"""Returns whether this type of attack requires training a model."""
|
"""Returns whether this type of attack requires training a model."""
|
||||||
return (self != AttackType.THRESHOLD_ATTACK) and (
|
# Compare by name instead of the variable itself to support module reload.
|
||||||
self != AttackType.THRESHOLD_ENTROPY_ATTACK)
|
return self.name not in (AttackType.THRESHOLD_ATTACK.name,
|
||||||
|
AttackType.THRESHOLD_ENTROPY_ATTACK.name)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
||||||
|
|
|
@ -44,6 +44,8 @@ class SingleSliceSpecTest(parameterized.TestCase):
|
||||||
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
(SlicingFeature.CLASS, 2, 'CLASS=2'),
|
||||||
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
|
||||||
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
|
||||||
|
(SlicingFeature.CUSTOM, (np.array([1]), np.array([2, 1]), 1),
|
||||||
|
'Custom indices: train = [1], test = [2 1], group_value = 1'),
|
||||||
)
|
)
|
||||||
def testStr(self, feature, value, expected_str):
|
def testStr(self, feature, value, expected_str):
|
||||||
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)
|
||||||
|
|
|
@ -26,13 +26,16 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_NUM_OF_SLICES = 1000
|
||||||
|
|
||||||
|
|
||||||
def _slice_if_not_none(a, idx):
|
def _slice_if_not_none(a, idx):
|
||||||
return None if a is None else a[idx]
|
return None if a is None else a[idx]
|
||||||
|
|
||||||
|
|
||||||
def _slice_data_by_indices(data: AttackInputData, idx_train,
|
def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
idx_test) -> AttackInputData:
|
idx_test) -> AttackInputData:
|
||||||
"""Slices train fields with with idx_train and test fields with and idx_test."""
|
"""Slices train fields with idx_train and test fields with idx_test."""
|
||||||
|
|
||||||
result = AttackInputData()
|
result = AttackInputData()
|
||||||
|
|
||||||
|
@ -128,10 +131,55 @@ def _slice_by_classification_correctness(data: AttackInputData,
|
||||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_by_custom_indices(data: AttackInputData,
                             custom_train_indices: np.ndarray,
                             custom_test_indices: np.ndarray,
                             group_value: int) -> AttackInputData:
  """Slices attack inputs by custom indices.

  Keeps the training and test examples whose custom group index equals
  `group_value`.

  Args:
    data: Data to be used as input to the attack models.
    custom_train_indices: The group indices of each training example.
    custom_test_indices: The group indices of each test example.
    group_value: The group value to pick.

  Returns:
    AttackInputData object containing the sliced data.

  Raises:
    ValueError: If the number of entries in `custom_train_indices` or
      `custom_test_indices` does not match the number of training or test
      examples reported by `data`.
  """
  # Coerce so that plain sequences (lists, tuples) are accepted as well as
  # arrays; for ndarray inputs this is a no-op.
  custom_train_indices = np.asarray(custom_train_indices)
  custom_test_indices = np.asarray(custom_test_indices)

  train_size, test_size = data.get_train_size(), data.get_test_size()
  # A 0-d array has no first dimension; treat it as a size mismatch rather
  # than letting `shape[0]` fail with an IndexError.
  if (custom_train_indices.ndim == 0 or
      custom_train_indices.shape[0] != train_size):
    raise ValueError(
        "custom_train_indices should have the same number of elements as "
        f"the training data, but got {custom_train_indices.shape} and "
        f"{train_size}")
  if (custom_test_indices.ndim == 0 or
      custom_test_indices.shape[0] != test_size):
    raise ValueError(
        "custom_test_indices should have the same number of elements as "
        f"the test data, but got {custom_test_indices.shape} and "
        f"{test_size}")

  # Boolean masks selecting the examples that belong to the requested group.
  idx_train = custom_train_indices == group_value
  idx_test = custom_test_indices == group_value
  return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
|
||||||
|
|
||||||
def get_single_slice_specs(
|
def get_single_slice_specs(
|
||||||
slicing_spec: SlicingSpec,
|
slicing_spec: SlicingSpec,
|
||||||
num_classes: Optional[int] = None) -> List[SingleSliceSpec]:
|
num_classes: Optional[int] = None) -> List[SingleSliceSpec]:
|
||||||
"""Returns slices of data according to slicing_spec."""
|
"""Returns slices of data according to slicing_spec.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slicing_spec: the slicing specification
|
||||||
|
num_classes: number of classes of the examples. Required when slicing by
|
||||||
|
class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Slices of data according to the slicing specification.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of slices is above `_MAX_NUM_OF_SLICES` when
|
||||||
|
slicing by class or slicing with custom indices. Or, if `num_classes` is
|
||||||
|
not provided when slicing by class.
|
||||||
|
"""
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
if slicing_spec.entire_dataset:
|
if slicing_spec.entire_dataset:
|
||||||
|
@ -141,10 +189,12 @@ def get_single_slice_specs(
|
||||||
by_class = slicing_spec.by_class
|
by_class = slicing_spec.by_class
|
||||||
if isinstance(by_class, bool):
|
if isinstance(by_class, bool):
|
||||||
if by_class:
|
if by_class:
|
||||||
assert num_classes, "When by_class == True, num_classes should be given."
|
if not num_classes:
|
||||||
assert 0 <= num_classes <= 1000, (
|
raise ValueError("When by_class == True, num_classes should be given.")
|
||||||
f"Too much classes for slicing by classes. "
|
if not 0 <= num_classes <= _MAX_NUM_OF_SLICES:
|
||||||
f"Found {num_classes}.")
|
raise ValueError(f"Too many classes for slicing by classes. "
|
||||||
|
f"Found {num_classes}."
|
||||||
|
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
||||||
for c in range(num_classes):
|
for c in range(num_classes):
|
||||||
result.append(SingleSliceSpec(SlicingFeature.CLASS, c))
|
result.append(SingleSliceSpec(SlicingFeature.CLASS, c))
|
||||||
elif isinstance(by_class, int):
|
elif isinstance(by_class, int):
|
||||||
|
@ -164,6 +214,23 @@ def get_single_slice_specs(
|
||||||
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True))
|
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True))
|
||||||
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False))
|
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False))
|
||||||
|
|
||||||
|
# Create slices by custom indices.
|
||||||
|
if slicing_spec.all_custom_train_indices:
|
||||||
|
for custom_train_indices, custom_test_indices in zip(
|
||||||
|
slicing_spec.all_custom_train_indices,
|
||||||
|
slicing_spec.all_custom_test_indices):
|
||||||
|
groups = np.intersect1d(
|
||||||
|
np.unique(custom_train_indices),
|
||||||
|
np.unique(custom_test_indices),
|
||||||
|
assume_unique=True)
|
||||||
|
if not 0 <= groups.size <= _MAX_NUM_OF_SLICES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too many groups ({groups.size}) for slicing by custom indices. "
|
||||||
|
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
|
||||||
|
for g in groups:
|
||||||
|
result.append(
|
||||||
|
SingleSliceSpec(SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g)))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -179,6 +246,10 @@ def get_slice(data: AttackInputData,
|
||||||
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
|
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
|
||||||
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
||||||
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
||||||
|
elif slice_spec.feature == SlicingFeature.CUSTOM:
|
||||||
|
custom_train_indices, custom_test_indices, group_value = slice_spec.value
|
||||||
|
data_slice = _slice_by_custom_indices(data, custom_train_indices,
|
||||||
|
custom_test_indices, group_value)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
|
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
from absl.testing.absltest import mock
|
from absl.testing.absltest import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingFeature
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingFeature
|
||||||
|
@ -38,7 +39,7 @@ def _are_lists_equal(lhs, rhs) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SingleSliceSpecsTest(absltest.TestCase):
|
class SingleSliceSpecsTest(parameterized.TestCase):
|
||||||
"""Tests for get_single_slice_specs."""
|
"""Tests for get_single_slice_specs."""
|
||||||
|
|
||||||
ENTIRE_DATASET_SLICE = SingleSliceSpec()
|
ENTIRE_DATASET_SLICE = SingleSliceSpec()
|
||||||
|
@ -95,8 +96,81 @@ class SingleSliceSpecsTest(absltest.TestCase):
|
||||||
output = get_single_slice_specs(input_data, n_classes)
|
output = get_single_slice_specs(input_data, n_classes)
|
||||||
self.assertLen(output, expected_slices)
|
self.assertLen(output, expected_slices)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(np.array([1, 2, 1, 2]), np.array([2, 2, 1, 2]), [1, 2]),
|
||||||
|
(np.array([0, -1, 2, -1, 2]), np.array([2, 2, -1, 2]), [-1, 2]),
|
||||||
|
(np.array([1, 2, 1, 2] + list(range(5000))), np.array([2, 2, 1]), [1, 2]),
|
||||||
|
(np.array([1, 2, 1, 2]), np.array([3, 4]), []),
|
||||||
|
)
|
||||||
|
def test_slicing_by_custom_indices_one_pair(self, custom_train_indices,
|
||||||
|
custom_test_indices,
|
||||||
|
expected_groups):
|
||||||
|
input_data = SlicingSpec(
|
||||||
|
all_custom_train_indices=[custom_train_indices],
|
||||||
|
all_custom_test_indices=[custom_test_indices])
|
||||||
|
expected = [self.ENTIRE_DATASET_SLICE] + [
|
||||||
|
SingleSliceSpec(SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g))
|
||||||
|
for g in expected_groups
|
||||||
|
]
|
||||||
|
output = get_single_slice_specs(input_data)
|
||||||
|
self.assertTrue(_are_lists_equal(output, expected))
|
||||||
|
|
||||||
class GetSliceTest(absltest.TestCase):
|
def test_slicing_by_custom_indices_multi_pairs(self):
|
||||||
|
all_custom_train_indices = [
|
||||||
|
np.array([1, 2, 1, 2]),
|
||||||
|
np.array([0, -1, 2, -1, 2]),
|
||||||
|
np.array([1, 2, 1, 2] + list(range(5000))),
|
||||||
|
np.array([1, 2, 1, 2])
|
||||||
|
]
|
||||||
|
all_custom_test_indices = [
|
||||||
|
np.array([2, 2, 1, 2]),
|
||||||
|
np.array([2, 2, -1, 2]),
|
||||||
|
np.array([2, 2, 1]),
|
||||||
|
np.array([3, 4])
|
||||||
|
]
|
||||||
|
expected_group_values = [[1, 2], [-1, 2], [1, 2], []]
|
||||||
|
|
||||||
|
input_data = SlicingSpec(
|
||||||
|
all_custom_train_indices=all_custom_train_indices,
|
||||||
|
all_custom_test_indices=all_custom_test_indices)
|
||||||
|
expected = [self.ENTIRE_DATASET_SLICE]
|
||||||
|
for custom_train_indices, custom_test_indices, eg in zip(
|
||||||
|
all_custom_train_indices, all_custom_test_indices,
|
||||||
|
expected_group_values):
|
||||||
|
expected.extend([
|
||||||
|
SingleSliceSpec(SlicingFeature.CUSTOM,
|
||||||
|
(custom_train_indices, custom_test_indices, g))
|
||||||
|
for g in eg
|
||||||
|
])
|
||||||
|
output = get_single_slice_specs(input_data)
|
||||||
|
self.assertTrue(_are_lists_equal(output, expected))
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
([np.array([1, 2])], None),
|
||||||
|
(None, [np.array([1, 2])]),
|
||||||
|
([], [np.array([1, 2])]),
|
||||||
|
([np.array([1, 2])], [np.array([1, 2]),
|
||||||
|
np.array([1, 2])]),
|
||||||
|
)
|
||||||
|
def test_slicing_by_custom_indices_wrong_indices(self,
|
||||||
|
all_custom_train_indices,
|
||||||
|
all_custom_test_indices):
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
SlicingSpec,
|
||||||
|
all_custom_train_indices=all_custom_train_indices,
|
||||||
|
all_custom_test_indices=all_custom_test_indices)
|
||||||
|
|
||||||
|
def test_slicing_by_custom_indices_too_many_groups(self):
|
||||||
|
input_data = SlicingSpec(
|
||||||
|
all_custom_train_indices=[np.arange(1001),
|
||||||
|
np.arange(3)],
|
||||||
|
all_custom_test_indices=[np.arange(1001), np.arange(3)])
|
||||||
|
self.assertRaises(ValueError, get_single_slice_specs, input_data)
|
||||||
|
|
||||||
|
|
||||||
|
class GetSliceTest(parameterized.TestCase):
|
||||||
|
|
||||||
def __init__(self, methodname):
|
def __init__(self, methodname):
|
||||||
"""Initialize the test class."""
|
"""Initialize the test class."""
|
||||||
|
@ -210,6 +284,40 @@ class GetSliceTest(absltest.TestCase):
|
||||||
self.assertTrue((output.labels_train == [0, 2]).all())
|
self.assertTrue((output.labels_train == [0, 2]).all())
|
||||||
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
self.assertTrue((output.labels_test == [1, 2, 0]).all())
|
||||||
|
|
||||||
|
def test_slice_by_custom_indices(self):
|
||||||
|
custom_train_indices = np.array([2, 2, 100, 4])
|
||||||
|
custom_test_indices = np.array([100, 2, 2, 2])
|
||||||
|
custom_slice = SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
||||||
|
output = get_slice(self.input_data, custom_slice)
|
||||||
|
np.testing.assert_array_equal(output.logits_train,
|
||||||
|
np.array([[0, 1, 0], [2, 0, 3]]))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]]))
|
||||||
|
np.testing.assert_array_equal(output.probs_train,
|
||||||
|
np.array([[0, 1, 0], [0.1, 0, 0.7]]))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
output.probs_test, np.array([[0.1, 0.9, 0], [0.15, 0.85, 0], [0, 0,
|
||||||
|
1]]))
|
||||||
|
np.testing.assert_array_equal(output.labels_train, np.array([1, 0]))
|
||||||
|
np.testing.assert_array_equal(output.labels_test, np.array([2, 0, 2]))
|
||||||
|
np.testing.assert_array_equal(output.loss_train, np.array([2, 0.25]))
|
||||||
|
np.testing.assert_array_equal(output.loss_test, np.array([3.5, 7, 4.5]))
|
||||||
|
np.testing.assert_array_equal(output.entropy_train, np.array([0.4, 8]))
|
||||||
|
np.testing.assert_array_equal(output.entropy_test,
|
||||||
|
np.array([10.5, 4.5, 0.3]))
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(np.array([2, 2, 100]), np.array([100, 2, 2])),
|
||||||
|
(np.array([2, 2, 100, 4]), np.array([100, 2, 2])),
|
||||||
|
(np.array([2, 100, 4]), np.array([100, 2, 2, 2])),
|
||||||
|
)
|
||||||
|
def test_slice_by_custom_indices_wrong_size(self, custom_train_indices,
|
||||||
|
custom_test_indices):
|
||||||
|
custom_slice = SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
||||||
|
self.assertRaises(ValueError, get_slice, self.input_data, custom_slice)
|
||||||
|
|
||||||
|
|
||||||
class GetSliceTestForMultilabelData(absltest.TestCase):
|
class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||||
|
|
||||||
|
@ -288,6 +396,26 @@ class GetSliceTestForMultilabelData(absltest.TestCase):
|
||||||
False)
|
False)
|
||||||
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
|
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
|
||||||
|
|
||||||
|
def test_slice_by_custom_indices(self):
|
||||||
|
custom_train_indices = np.array([2, 2, 100, 4])
|
||||||
|
custom_test_indices = np.array([100, 2, 2, 2])
|
||||||
|
custom_slice = SingleSliceSpec(
|
||||||
|
SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
|
||||||
|
output = get_slice(self.input_data, custom_slice)
|
||||||
|
# Check logits.
|
||||||
|
with self.subTest(msg='Check logits'):
|
||||||
|
np.testing.assert_array_equal(output.logits_train,
|
||||||
|
np.array([[0, 1, 0], [2, 0, 3]]))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]]))
|
||||||
|
|
||||||
|
# Check labels.
|
||||||
|
with self.subTest(msg='Check labels'):
|
||||||
|
np.testing.assert_array_equal(output.labels_train,
|
||||||
|
np.array([[0, 1, 1], [1, 0, 1]]))
|
||||||
|
np.testing.assert_array_equal(output.labels_test,
|
||||||
|
np.array([[0, 1, 0], [0, 1, 0], [0, 0, 1]]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -285,12 +285,13 @@ class KNearestNeighborsAttacker(TrainedAttacker):
|
||||||
def create_attacker(attack_type,
|
def create_attacker(attack_type,
|
||||||
backend: Optional[str] = None) -> TrainedAttacker:
|
backend: Optional[str] = None) -> TrainedAttacker:
|
||||||
"""Returns the corresponding attacker for the provided attack_type."""
|
"""Returns the corresponding attacker for the provided attack_type."""
|
||||||
if attack_type == data_structures.AttackType.LOGISTIC_REGRESSION:
|
# Compare by name instead of the variable itself to support module reload.
|
||||||
|
if attack_type.name == data_structures.AttackType.LOGISTIC_REGRESSION.name:
|
||||||
return LogisticRegressionAttacker(backend=backend)
|
return LogisticRegressionAttacker(backend=backend)
|
||||||
if attack_type == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON:
|
if attack_type.name == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON.name:
|
||||||
return MultilayerPerceptronAttacker(backend=backend)
|
return MultilayerPerceptronAttacker(backend=backend)
|
||||||
if attack_type == data_structures.AttackType.RANDOM_FOREST:
|
if attack_type.name == data_structures.AttackType.RANDOM_FOREST.name:
|
||||||
return RandomForestAttacker(backend=backend)
|
return RandomForestAttacker(backend=backend)
|
||||||
if attack_type == data_structures.AttackType.K_NEAREST_NEIGHBORS:
|
if attack_type.name == data_structures.AttackType.K_NEAREST_NEIGHBORS.name:
|
||||||
return KNearestNeighborsAttacker(backend=backend)
|
return KNearestNeighborsAttacker(backend=backend)
|
||||||
raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)
|
raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)
|
||||||
|
|
Loading…
Reference in a new issue