Adds an option to balance train and test AttackInputData and stratifies the train-test split.
PiperOrigin-RevId: 336609893
This commit is contained in:
parent
d703168de2
commit
1281d0c63e
3 changed files with 68 additions and 21 deletions
|
@ -43,7 +43,9 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
|||
return SingleSliceSpec()
|
||||
|
||||
|
||||
def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
|
||||
def _run_trained_attack(attack_input: AttackInputData,
|
||||
attack_type: AttackType,
|
||||
balance_attacker_training: bool = True):
|
||||
"""Classification attack done by ML models."""
|
||||
attacker = None
|
||||
|
||||
|
@ -59,7 +61,8 @@ def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
|
|||
raise NotImplementedError('Attack type %s not implemented yet.' %
|
||||
attack_type)
|
||||
|
||||
prepared_attacker_data = models.create_attacker_data(attack_input)
|
||||
prepared_attacker_data = models.create_attacker_data(
|
||||
attack_input, balance=balance_attacker_training)
|
||||
|
||||
attacker.train_model(prepared_attacker_data.features_train,
|
||||
prepared_attacker_data.is_training_labels_train)
|
||||
|
@ -94,19 +97,23 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
|
||||
def _run_attack(attack_input: AttackInputData,
|
||||
attack_type: AttackType,
|
||||
balance_attacker_training: bool = True):
|
||||
attack_input.validate()
|
||||
if attack_type.is_trained_attack:
|
||||
return _run_trained_attack(attack_input, attack_type)
|
||||
return _run_trained_attack(attack_input, attack_type,
|
||||
balance_attacker_training)
|
||||
|
||||
return _run_threshold_attack(attack_input)
|
||||
|
||||
|
||||
def run_attacks(
|
||||
attack_input: AttackInputData,
|
||||
def run_attacks(attack_input: AttackInputData,
|
||||
slicing_spec: SlicingSpec = None,
|
||||
attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
|
||||
attack_types: Iterable[AttackType] = (
|
||||
AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True) -> AttackResults:
|
||||
"""Runs membership inference attacks on a classification model.
|
||||
|
||||
It runs attacks specified by attack_types on each attack_input slice which is
|
||||
|
@ -117,6 +124,10 @@ def run_attacks(
|
|||
slicing_spec: specifies attack_input slices to run attack on
|
||||
attack_types: attacks to run
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
balance_attacker_training: Whether the training and test sets for the
|
||||
membership inference attacker should have a balanced (roughly equal)
|
||||
number of samples from the training and test sets used to develop
|
||||
the model under attack.
|
||||
|
||||
Returns:
|
||||
the attack result.
|
||||
|
@ -131,7 +142,9 @@ def run_attacks(
|
|||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
for attack_type in attack_types:
|
||||
attack_results.append(_run_attack(attack_input_slice, attack_type))
|
||||
attack_results.append(
|
||||
_run_attack(attack_input_slice, attack_type,
|
||||
balance_attacker_training))
|
||||
|
||||
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
||||
privacy_report_metadata, attack_input)
|
||||
|
|
|
@ -43,7 +43,8 @@ class AttackerData:
|
|||
|
||||
|
||||
def create_attacker_data(attack_input_data: AttackInputData,
|
||||
test_fraction: float = 0.25) -> AttackerData:
|
||||
test_fraction: float = 0.25,
|
||||
balance: bool = True) -> AttackerData:
|
||||
"""Prepare AttackInputData to train ML attackers.
|
||||
|
||||
Combines logits and losses and performs a random train-test split.
|
||||
|
@ -51,6 +52,10 @@ def create_attacker_data(attack_input_data: AttackInputData,
|
|||
Args:
|
||||
attack_input_data: Original AttackInputData
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
balance: Whether the training and test sets for the membership inference
|
||||
attacker should have a balanced (roughly equal) number of samples
|
||||
from the training and test sets used to develop the model
|
||||
under attack.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
|
@ -60,20 +65,33 @@ def create_attacker_data(attack_input_data: AttackInputData,
|
|||
attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
|
||||
attack_input_data.get_loss_test())
|
||||
|
||||
if balance:
|
||||
min_size = min(attack_input_data.get_train_size(),
|
||||
attack_input_data.get_test_size())
|
||||
attack_input_train = _sample_multidimensional_array(attack_input_train,
|
||||
min_size)
|
||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
||||
min_size)
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
|
||||
labels_all = np.concatenate(((np.zeros(attack_input_data.get_train_size())),
|
||||
(np.ones(attack_input_data.get_test_size()))))
|
||||
labels_all = np.concatenate(
|
||||
((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
is_training_labels_train, is_training_labels_test = \
|
||||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction)
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
|
||||
|
||||
def _sample_multidimensional_array(array, size):
|
||||
indices = np.random.choice(len(array), size, replace=False)
|
||||
return array[indices]
|
||||
|
||||
|
||||
def _column_stack(logits, loss):
|
||||
"""Stacks logits and losses.
|
||||
|
||||
|
|
|
@ -34,19 +34,20 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
|
||||
def test_create_attacker_data_loss_only(self):
|
||||
attack_input = AttackInputData(
|
||||
loss_train=np.array([1]), loss_test=np.array([2]))
|
||||
loss_train=np.array([1, 3]), loss_test=np.array([2, 4]))
|
||||
attacker_data = models.create_attacker_data(attack_input, 0.5)
|
||||
self.assertLen(attacker_data.features_test, 1)
|
||||
self.assertLen(attacker_data.features_train, 1)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
self.assertLen(attacker_data.features_train, 2)
|
||||
|
||||
def test_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6]]),
|
||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||
logits_test=np.array([[10, 11], [14, 15]]),
|
||||
loss_train=np.array([3, 7]),
|
||||
loss_train=np.array([3, 7, 10]),
|
||||
loss_test=np.array([12, 16]))
|
||||
attacker_data = models.create_attacker_data(attack_input, 0.25)
|
||||
self.assertLen(attacker_data.features_test, 1)
|
||||
attacker_data = models.create_attacker_data(
|
||||
attack_input, 0.25, balance=False)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
self.assertLen(attacker_data.features_train, 3)
|
||||
|
||||
for i, feature in enumerate(attacker_data.features_train):
|
||||
|
@ -54,6 +55,21 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
expected = feature[:2] not in attack_input.logits_train
|
||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||
|
||||
def test_balanced_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||
logits_test=np.array([[10, 11], [14, 15], [17, 18]]),
|
||||
loss_train=np.array([3, 7, 10]),
|
||||
loss_test=np.array([12, 16, 19]))
|
||||
attacker_data = models.create_attacker_data(attack_input, 0.33)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
self.assertLen(attacker_data.features_train, 4)
|
||||
|
||||
for i, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 3) # each feature has two logits and one loss
|
||||
expected = feature[:2] not in attack_input.logits_train
|
||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue