Handle the case when the data comes from a multilabel classification problem but the provided samples happen to have just one positive label per sample.
PiperOrigin-RevId: 445468067
commit 930c4d13e8
parent e0ab480e3d
3 changed files with 54 additions and 0 deletions
@@ -235,6 +235,13 @@ class AttackInputData:
   # corresponding class is absent from the example, and 1s where the
   # corresponding class is present.
   multilabel_data: Optional[bool] = None
+  # In some corner cases, the provided data comes from a multi-label
+  # classification model, but the samples all happen to have just 1 label. In
+  # that case, the `is_multilabel_data()` test will return a `False` value. The
+  # attack models will expect 1D input, which will throw an exception. Handle
+  # this case by letting the user set a flag that forces the input data to be
+  # treated as multilabel data.
+  force_multilabel_data: bool = False
 
   @property
   def num_classes(self):
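For intuition on the corner case the new comment describes: when every row of a multi-hot label matrix contains exactly one positive entry, the labels are bitwise identical to one-hot single-label encoding, so any heuristic based on per-row label counts will conclude the data is single-label. The snippet below is an illustrative sketch of that ambiguity only; rows_look_multihot is a toy helper, not the library's actual is_multilabel_data() logic.

import numpy as np

# Genuinely multi-label batch: the first sample has two positive classes.
multi_label_batch = np.array([[0, 1, 1],
                              [1, 0, 0]])
# Multi-label problem, but this batch happens to carry exactly one positive
# label per sample, so it looks exactly like one-hot single-label data.
single_looking_batch = np.array([[0, 0, 1],
                                 [0, 1, 0]])

def rows_look_multihot(labels):
  """Toy heuristic: multi-hot only if some row has more than one positive label."""
  return bool(np.any(labels.sum(axis=1) > 1))

print(rows_look_multihot(multi_label_batch))     # True
print(rows_look_multihot(single_looking_batch))  # False -> would need the new flag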
@@ -450,6 +457,10 @@ class AttackInputData:
     Raises:
       ValueError if the dimensionality of the train and test data are not equal.
     """
+    # If 'force_multilabel_data' is set, then assume multilabel data going
+    # forward.
+    if self.force_multilabel_data:
+      self.multilabel_data = True
     # If the data has already been checked for multihot encoded labels, then
     # return the result of the evaluation.
     if self.multilabel_data is not None:
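Taken together, the two hunks above let a caller opt in explicitly when the data is known to be multi-label even though every sample carries a single positive label. A minimal usage sketch, mirroring the new tests below; the import path is assumed from the tensorflow_privacy membership-inference package layout and may differ between versions.

import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

# 2D multi-hot labels in which every sample happens to have exactly one
# positive class; without the flag, validation would reject them as
# malformed single-label labels ("should be a one dimensional numpy array").
attack_input = AttackInputData(
    probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
    probs_test=np.array([[0.8, 0.7, 0.9]]),
    labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
    labels_test=np.array([[1, 0, 0]]),
    force_multilabel_data=True)

attack_input.validate()              # expected to pass with the flag set
assert attack_input.multilabel_data  # True via the forced path added above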
@@ -358,6 +358,41 @@ class AttackInputDataTest(parameterized.TestCase):
     np.testing.assert_equal(attack_input.get_loss_test().tolist(),
                             np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
 
+  def test_validate_with_force_multilabel_false(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
+        probs_test=np.array([[0.8, 0.7, 0.9]]),
+        labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
+        labels_test=np.array([[1, 0, 0]]))
+    self.assertRaisesRegex(ValueError,
+                           r'should be a one dimensional numpy array.',
+                           attack_input.validate)
+
+  def test_validate_with_force_multilabel_true(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
+        probs_test=np.array([[0.8, 0.7, 0.9]]),
+        labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
+        labels_test=np.array([[1, 0, 0]]),
+        force_multilabel_data=True)
+    try:
+      attack_input.validate()
+    except ValueError:
+      # For a 'ValueError' exception the test should record a failure. All
+      # other exceptions are errors.
+      self.fail('validate() raised ValueError unexpectedly.')
+
+  def test_multilabel_data_true_with_force_multilabel_true(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
+        probs_test=np.array([[0.8, 0.7, 0.9]]),
+        labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
+        labels_test=np.array([[1, 0, 0]]),
+        force_multilabel_data=True)
+    self.assertTrue(
+        attack_input.multilabel_data,
+        '"force_multilabel_data" is True but "multilabel_data" is False.')
+
 
 class RocCurveTest(absltest.TestCase):
 
@@ -217,6 +217,7 @@ def run_attacks(attack_input: AttackInputData,
   """
   attack_input.validate()
   attack_results = []
+  attack_types = list(attack_types)
 
   if slicing_spec is None:
     slicing_spec = SlicingSpec(entire_dataset=True)
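The new attack_types = list(attack_types) line presumably exists because the parameter is accepted as a generic iterable: a one-shot iterator would be exhausted by the first slice's inner loop, and len() (used for the log message added in the next hunk) is not defined on arbitrary iterables. A small self-contained illustration of both pitfalls; run() and the attack-type strings are hypothetical stand-ins, not the library's API.

from typing import Iterable

def run(attack_types: Iterable[str]) -> None:
  attack_types = list(attack_types)  # materialize once, reuse safely
  print(f'Will run {len(attack_types)} attacks per slice.')
  for slice_name in ('slice_a', 'slice_b'):
    for attack_type in attack_types:  # a second pass over a bare generator would be empty
      print(slice_name, attack_type)

# A generator expression is a one-shot iterator; without list(), the
# 'slice_b' pass would see no attack types and len() would raise TypeError.
run(t for t in ('lr', 'threshold'))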
@@ -224,6 +225,10 @@ def run_attacks(attack_input: AttackInputData,
   if slicing_spec.by_class:
     num_classes = attack_input.num_classes
   input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
+  num_slice_specs = len(input_slice_specs)
+  num_attacks = len(attack_types)
+  logging.info('Will run %s attacks on each of %s slice specifications.',
+               num_attacks, num_slice_specs)
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
@@ -231,6 +236,9 @@ def run_attacks(attack_input: AttackInputData,
       attack_result = _run_attack(attack_input_slice, attack_type,
                                   balance_attacker_training, min_num_samples)
       if attack_result is not None:
+        logging.info('%s attack had an AUC=%s and attacker advantage=%s',
+                     attack_type.name, attack_result.get_auc(),
+                     attack_result.get_attacker_advantage())
         attack_results.append(attack_result)
 
   privacy_report_metadata = _compute_missing_privacy_report_metadata(
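A side note on the added log lines: the format string and its arguments are passed to logging.info separately, so the argument expressions are still evaluated eagerly, but the %-interpolation itself is deferred and skipped entirely when the log level is disabled. A small sketch of the pattern using the standard-library logger (the module appears to use absl's logging, which exposes the same call signature); the attack name and metric values here are made up for illustration.

import logging

logging.basicConfig(level=logging.INFO)
auc, advantage = 0.62, 0.17  # hypothetical metric values

# Lazy form: '%s' substitution only happens if INFO is enabled for the logger.
logging.info('%s attack had an AUC=%s and attacker advantage=%s',
             'LOGISTIC_REGRESSION', auc, advantage)

# Eager alternative: the f-string is built even when INFO is disabled.
logging.info(f'LOGISTIC_REGRESSION attack had an AUC={auc} and attacker advantage={advantage}')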