Handle the case when the data comes from a multilabel classification problem but the provided samples happen to have just one positive label per sample.

PiperOrigin-RevId: 445468067
This commit is contained in:
A. Unique TensorFlower 2022-04-29 11:36:21 -07:00
parent e0ab480e3d
commit 930c4d13e8
3 changed files with 54 additions and 0 deletions

View file

@ -235,6 +235,13 @@ class AttackInputData:
# corresponding class is absent from the example, and 1s where the
# corresponding class is present.
multilabel_data: Optional[bool] = None
# In some corner cases, the provided data comes from a multi-label
# classification model, but the samples all happen to have just 1 label. In
# that case, the `is_multilabel_data()` test will return a `False` value. The
# attack models will expect 1D input, which will throw an exception. Handle
# this case by letting the user set a flag that forces the input data to be
# treated as multilabel data.
force_multilabel_data: bool = False
@property
def num_classes(self):
@ -450,6 +457,10 @@ class AttackInputData:
Raises:
ValueError if the dimensionality of the train and test data are not equal.
"""
# If 'force_multilabel_data' is set, then assume multilabel data going
# forward.
if self.force_multilabel_data:
self.multilabel_data = True
# If the data has already been checked for multihot encoded labels, then
# return the result of the evaluation.
if self.multilabel_data is not None:

View file

@ -358,6 +358,41 @@ class AttackInputDataTest(parameterized.TestCase):
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
def test_validate_with_force_multilabel_false(self):
attack_input = AttackInputData(
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
probs_test=np.array([[0.8, 0.7, 0.9]]),
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
labels_test=np.array([[1, 0, 0]]))
self.assertRaisesRegex(ValueError,
r'should be a one dimensional numpy array.',
attack_input.validate)
def test_validate_with_force_multilabel_true(self):
attack_input = AttackInputData(
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
probs_test=np.array([[0.8, 0.7, 0.9]]),
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
labels_test=np.array([[1, 0, 0]]),
force_multilabel_data=True)
try:
attack_input.validate()
except ValueError:
# For a 'ValueError' exception the test should record a failure. All
# other exceptions are errors.
self.fail('ValueError not raised by validate().')
def test_multilabel_data_true_with_force_multilabel_true(self):
attack_input = AttackInputData(
probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
probs_test=np.array([[0.8, 0.7, 0.9]]),
labels_train=np.array([[0, 0, 1], [0, 1, 0]]),
labels_test=np.array([[1, 0, 0]]),
force_multilabel_data=True)
self.assertTrue(
attack_input.multilabel_data,
'"force_multilabel_data" is True but "multilabel_data" is False.')
class RocCurveTest(absltest.TestCase):

View file

@ -217,6 +217,7 @@ def run_attacks(attack_input: AttackInputData,
"""
attack_input.validate()
attack_results = []
attack_types = list(attack_types)
if slicing_spec is None:
slicing_spec = SlicingSpec(entire_dataset=True)
@ -224,6 +225,10 @@ def run_attacks(attack_input: AttackInputData,
if slicing_spec.by_class:
num_classes = attack_input.num_classes
input_slice_specs = get_single_slice_specs(slicing_spec, num_classes)
num_slice_specs = len(input_slice_specs)
num_attacks = len(attack_types)
logging.info('Will run %s attacks on each of %s slice specifications.',
num_attacks, num_slice_specs)
for single_slice_spec in input_slice_specs:
attack_input_slice = get_slice(attack_input, single_slice_spec)
for attack_type in attack_types:
@ -231,6 +236,9 @@ def run_attacks(attack_input: AttackInputData,
attack_result = _run_attack(attack_input_slice, attack_type,
balance_attacker_training, min_num_samples)
if attack_result is not None:
logging.info('%s attack had an AUC=%s and attacker advantage=%s',
attack_type.name, attack_result.get_auc(),
attack_result.get_attacker_advantage())
attack_results.append(attack_result)
privacy_report_metadata = _compute_missing_privacy_report_metadata(