forked from 626_privacy/tensorflow_privacy
Handle the case when the data comes from a multilabel classification problem but the provided samples happen to have just one positive label per sample.
PiperOrigin-RevId: 445468067
This commit is contained in:
parent
e0ab480e3d
commit
930c4d13e8
3 changed files with 54 additions and 0 deletions
|
@ -235,6 +235,13 @@ class AttackInputData:
|
||||||
# corresponding class is absent from the example, and 1s where the
|
# corresponding class is absent from the example, and 1s where the
|
||||||
# corresponding class is present.
|
# corresponding class is present.
|
||||||
multilabel_data: Optional[bool] = None
|
multilabel_data: Optional[bool] = None
|
||||||
|
# In some corner cases, the provided data comes from a multi-label
|
||||||
|
# classification model, but the samples all happen to have just 1 label. In
|
||||||
|
# that case, the `is_multilabel_data()` test will return a `False` value. The
|
||||||
|
# attack models will expect 1D input, which will throw an exception. Handle
|
||||||
|
# this case by letting the user set a flag that forces the input data to be
|
||||||
|
# treated as multilabel data.
|
||||||
|
force_multilabel_data: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
|
@ -450,6 +457,10 @@ class AttackInputData:
|
||||||
Raises:
|
Raises:
|
||||||
ValueError if the dimensionality of the train and test data are not equal.
|
ValueError if the dimensionality of the train and test data are not equal.
|
||||||
"""
|
"""
|
||||||
|
# If 'force_multilabel_data' is set, then assume multilabel data going
|
||||||
|
# forward.
|
||||||
|
if self.force_multilabel_data:
|
||||||
|
self.multilabel_data = True
|
||||||
# If the data has already been checked for multihot encoded labels, then
|
# If the data has already been checked for multihot encoded labels, then
|
||||||
# return the result of the evaluation.
|
# return the result of the evaluation.
|
||||||
if self.multilabel_data is not None:
|
if self.multilabel_data is not None:
|
||||||
|
|
|
@ -358,6 +358,41 @@ class AttackInputDataTest(parameterized.TestCase):
|
||||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||||
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
|
||||||
|
|
||||||
|
def test_validate_with_force_multilabel_false(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
probs_test=np.array([[0.8, 0.7, 0.9]]),
|
||||||
|
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
|
||||||
|
labels_test=np.array([[1, 0, 0]]))
|
||||||
|
self.assertRaisesRegex(ValueError,
|
||||||
|
r'should be a one dimensional numpy array.',
|
||||||
|
attack_input.validate)
|
||||||
|
|
||||||
|
def test_validate_with_force_multilabel_true(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
probs_test=np.array([[0.8, 0.7, 0.9]]),
|
||||||
|
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
|
||||||
|
labels_test=np.array([[1, 0, 0]]),
|
||||||
|
force_multilabel_data=True)
|
||||||
|
try:
|
||||||
|
attack_input.validate()
|
||||||
|
except ValueError:
|
||||||
|
# For a 'ValueError' exception the test should record a failure. All
|
||||||
|
# other exceptions are errors.
|
||||||
|
self.fail('ValueError not raised by validate().')
|
||||||
|
|
||||||
|
def test_multilabel_data_true_with_force_multilabel_true(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
|
||||||
|
probs_test=np.array([[0.8, 0.7, 0.9]]),
|
||||||
|
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
|
||||||
|
labels_test=np.array([[1, 0, 0]]),
|
||||||
|
force_multilabel_data=True)
|
||||||
|
self.assertTrue(
|
||||||
|
attack_input.multilabel_data,
|
||||||
|
'"force_multilabel_data" is True but "multilabel_data" is False.')
|
||||||
|
|
||||||
|
|
||||||
class RocCurveTest(absltest.TestCase):
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -217,6 +217,7 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
"""
|
"""
|
||||||
attack_input.validate()
|
attack_input.validate()
|
||||||
attack_results = []
|
attack_results = []
|
||||||
|
attack_types = list(attack_types)
|
||||||
|
|
||||||
if slicing_spec is None:
|
if slicing_spec is None:
|
||||||
slicing_spec = SlicingSpec(entire_dataset=True)
|
slicing_spec = SlicingSpec(entire_dataset=True)
|
||||||
|
@ -224,6 +225,10 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
if slicing_spec.by_class:
|
if slicing_spec.by_class:
|
||||||
num_classes = attack_input.num_classes
|
num_classes = attack_input.num_classes
|
||||||
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
|
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
|
||||||
|
num_slice_specs = len(input_slice_specs)
|
||||||
|
num_attacks = len(attack_types)
|
||||||
|
logging.info('Will run %s attacks on each of %s slice specifications.',
|
||||||
|
num_attacks, num_slice_specs)
|
||||||
for single_slice_spec in input_slice_specs:
|
for single_slice_spec in input_slice_specs:
|
||||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||||
for attack_type in attack_types:
|
for attack_type in attack_types:
|
||||||
|
@ -231,6 +236,9 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
attack_result = _run_attack(attack_input_slice, attack_type,
|
attack_result = _run_attack(attack_input_slice, attack_type,
|
||||||
balance_attacker_training, min_num_samples)
|
balance_attacker_training, min_num_samples)
|
||||||
if attack_result is not None:
|
if attack_result is not None:
|
||||||
|
logging.info('%s attack had an AUC=%s and attacker advantage=%s',
|
||||||
|
attack_type.name, attack_result.get_auc(),
|
||||||
|
attack_result.get_attacker_advantage())
|
||||||
attack_results.append(attack_result)
|
attack_results.append(attack_result)
|
||||||
|
|
||||||
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
||||||
|
|
Loading…
Reference in a new issue