Adds an option to balance train and test AttackInputData and stratifies the train-test split.

PiperOrigin-RevId: 336609893
David Marn authored on 2020-10-12 00:42:55 -07:00; committed by A. Unique TensorFlower
parent d703168de2
commit 1281d0c63e
3 changed files with 68 additions and 21 deletions
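
Background for the change (an illustrative note, not part of the commit): if the model under attack has many more training than test examples, the attacker's dataset is dominated by "member" samples, and even a trivial attacker looks strong. A minimal sketch with made-up sizes:

import numpy as np

# Hypothetical sizes: the model under attack used 10,000 training and
# 1,000 test examples, so the membership attacker sees 10,000 "member"
# samples but only 1,000 "non-member" samples.
n_train, n_test = 10_000, 1_000
labels = np.concatenate((np.zeros(n_train), np.ones(n_test)))

# An attacker that always answers "member" (label 0) is right on every
# member sample, i.e. ~91% of the time, without learning anything.
print((labels == 0).mean())  # 0.9090...

Balancing both sides to min(n_train, n_test) samples pushes this trivial baseline back to 50%, which is what the option added below implements.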


@@ -43,7 +43,9 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
   return SingleSliceSpec()


-def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_trained_attack(attack_input: AttackInputData,
+                        attack_type: AttackType,
+                        balance_attacker_training: bool = True):
   """Classification attack done by ML models."""
   attacker = None
@@ -59,7 +61,8 @@ def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
     raise NotImplementedError('Attack type %s not implemented yet.' %
                               attack_type)

-  prepared_attacker_data = models.create_attacker_data(attack_input)
+  prepared_attacker_data = models.create_attacker_data(
+      attack_input, balance=balance_attacker_training)

   attacker.train_model(prepared_attacker_data.features_train,
                        prepared_attacker_data.is_training_labels_train)
@@ -94,19 +97,23 @@ def _run_threshold_attack(attack_input: AttackInputData):
       roc_curve=roc_curve)


-def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_attack(attack_input: AttackInputData,
+                attack_type: AttackType,
+                balance_attacker_training: bool = True):
   attack_input.validate()
   if attack_type.is_trained_attack:
-    return _run_trained_attack(attack_input, attack_type)
+    return _run_trained_attack(attack_input, attack_type,
+                               balance_attacker_training)

   return _run_threshold_attack(attack_input)


-def run_attacks(
-    attack_input: AttackInputData,
-    slicing_spec: SlicingSpec = None,
-    attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
-    privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
+def run_attacks(attack_input: AttackInputData,
+                slicing_spec: SlicingSpec = None,
+                attack_types: Iterable[AttackType] = (
+                    AttackType.THRESHOLD_ATTACK,),
+                privacy_report_metadata: PrivacyReportMetadata = None,
+                balance_attacker_training: bool = True) -> AttackResults:
   """Runs membership inference attacks on a classification model.

   It runs attacks specified by attack_types on each attack_input slice which is
@@ -117,6 +124,10 @@ def run_attacks(
     slicing_spec: specifies attack_input slices to run attack on
     attack_types: attacks to run
     privacy_report_metadata: the metadata of the model under attack.
+    balance_attacker_training: Whether the training and test sets for the
+      membership inference attacker should have a balanced (roughly equal)
+      number of samples from the training and test sets used to develop
+      the model under attack.

   Returns:
     the attack result.
@@ -131,7 +142,9 @@ def run_attacks(
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
-      attack_results.append(_run_attack(attack_input_slice, attack_type))
+      attack_results.append(
+          _run_attack(attack_input_slice, attack_type,
+                      balance_attacker_training))

   privacy_report_metadata = _compute_missing_privacy_report_metadata(
       privacy_report_metadata, attack_input)
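
A usage sketch for the new flag (not from the commit; the import paths are assumed from the tensorflow/privacy layout around this revision and may differ in your checkout):

import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData, AttackType

# Toy logits/losses standing in for the outputs of the model under attack.
attack_input = AttackInputData(
    logits_train=np.random.randn(1000, 10),
    logits_test=np.random.randn(100, 10),
    loss_train=np.random.rand(1000),
    loss_test=np.random.rand(100))

# balance_attacker_training defaults to True; pass False to restore the
# previous behavior (attacker trained on 1000 members vs. 100 non-members).
results = mia.run_attacks(
    attack_input,
    attack_types=(AttackType.LOGISTIC_REGRESSION,),
    balance_attacker_training=True)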


@@ -43,7 +43,8 @@ class AttackerData:


 def create_attacker_data(attack_input_data: AttackInputData,
-                         test_fraction: float = 0.25) -> AttackerData:
+                         test_fraction: float = 0.25,
+                         balance: bool = True) -> AttackerData:
   """Prepare AttackInputData to train ML attackers.

   Combines logits and losses and performs a random train-test split.
@@ -51,6 +52,10 @@ def create_attacker_data(attack_input_data: AttackInputData,
   Args:
     attack_input_data: Original AttackInputData
     test_fraction: Fraction of the dataset to include in the test split.
+    balance: Whether the training and test sets for the membership inference
+      attacker should have a balanced (roughly equal) number of samples
+      from the training and test sets used to develop the model
+      under attack.

   Returns:
     AttackerData.
@@ -60,20 +65,33 @@ def create_attacker_data(attack_input_data: AttackInputData,
   attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
                                     attack_input_data.get_loss_test())

+  if balance:
+    min_size = min(attack_input_data.get_train_size(),
+                   attack_input_data.get_test_size())
+    attack_input_train = _sample_multidimensional_array(attack_input_train,
+                                                        min_size)
+    attack_input_test = _sample_multidimensional_array(attack_input_test,
+                                                       min_size)
+
   features_all = np.concatenate((attack_input_train, attack_input_test))

-  labels_all = np.concatenate(((np.zeros(attack_input_data.get_train_size())),
-                               (np.ones(attack_input_data.get_test_size()))))
+  labels_all = np.concatenate(
+      ((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))

   # Perform a train-test split
   features_train, features_test, \
   is_training_labels_train, is_training_labels_test = \
     model_selection.train_test_split(
-        features_all, labels_all, test_size=test_fraction)
+        features_all, labels_all, test_size=test_fraction, stratify=labels_all)

   return AttackerData(features_train, is_training_labels_train, features_test,
                       is_training_labels_test)


+def _sample_multidimensional_array(array, size):
+  indices = np.random.choice(len(array), size, replace=False)
+  return array[indices]
+
+
 def _column_stack(logits, loss):
   """Stacks logits and losses.


@@ -34,19 +34,20 @@ class TrainedAttackerTest(absltest.TestCase):

   def test_create_attacker_data_loss_only(self):
     attack_input = AttackInputData(
-        loss_train=np.array([1]), loss_test=np.array([2]))
+        loss_train=np.array([1, 3]), loss_test=np.array([2, 4]))
     attacker_data = models.create_attacker_data(attack_input, 0.5)
-    self.assertLen(attacker_data.features_test, 1)
-    self.assertLen(attacker_data.features_train, 1)
+    self.assertLen(attacker_data.features_test, 2)
+    self.assertLen(attacker_data.features_train, 2)

   def test_create_attacker_data_loss_and_logits(self):
     attack_input = AttackInputData(
-        logits_train=np.array([[1, 2], [5, 6]]),
+        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
         logits_test=np.array([[10, 11], [14, 15]]),
-        loss_train=np.array([3, 7]),
+        loss_train=np.array([3, 7, 10]),
         loss_test=np.array([12, 16]))
-    attacker_data = models.create_attacker_data(attack_input, 0.25)
-    self.assertLen(attacker_data.features_test, 1)
+    attacker_data = models.create_attacker_data(
+        attack_input, 0.25, balance=False)
+    self.assertLen(attacker_data.features_test, 2)
     self.assertLen(attacker_data.features_train, 3)

     for i, feature in enumerate(attacker_data.features_train):
@@ -54,6 +55,21 @@ class TrainedAttackerTest(absltest.TestCase):
       expected = feature[:2] not in attack_input.logits_train
       self.assertEqual(attacker_data.is_training_labels_train[i], expected)

+  def test_balanced_create_attacker_data_loss_and_logits(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
+        logits_test=np.array([[10, 11], [14, 15], [17, 18]]),
+        loss_train=np.array([3, 7, 10]),
+        loss_test=np.array([12, 16, 19]))
+    attacker_data = models.create_attacker_data(attack_input, 0.33)
+    self.assertLen(attacker_data.features_test, 2)
+    self.assertLen(attacker_data.features_train, 4)
+
+    for i, feature in enumerate(attacker_data.features_train):
+      self.assertLen(feature, 3)  # each feature has two logits and one loss
+      expected = feature[:2] not in attack_input.logits_train
+      self.assertEqual(attacker_data.is_training_labels_train[i], expected)
+


 if __name__ == '__main__':
   absltest.main()
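
A note on the asserted lengths in these tests (mine, not part of the commit): sklearn's train_test_split rounds the test count up, so n samples with test fraction f yield ceil(n * f) test samples:

import math

for n, frac in ((4, 0.5), (5, 0.25), (6, 0.33)):
    n_test = math.ceil(n * frac)
    print(f'{n} samples, test_size={frac} -> {n_test} test / {n - n_test} train')
# 4 samples, test_size=0.5 -> 2 test / 2 train    (loss-only test, balanced by default)
# 5 samples, test_size=0.25 -> 2 test / 3 train   (balance=False test)
# 6 samples, test_size=0.33 -> 2 test / 4 train   (balanced test: 3 train + 3 test inputs)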