Updating comments for run_attack() and making non-API functions private.

PiperOrigin-RevId: 329951618
Vadym Doroshenko 2020-09-03 10:56:06 -07:00 committed by A. Unique TensorFlower
parent 2f0a078dd9
commit f4fc9b2623
2 changed files with 23 additions and 10 deletions


@@ -43,7 +43,7 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
   return SingleSliceSpec()
-def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
   """Classification attack done by ML models."""
   attacker = None
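
For context, a "trained" attack fits a binary classifier that separates members (training examples) from non-members. The snippet below is not the library's internals, just a minimal sketch assuming per-example losses are the attacker's only feature:

import numpy as np
from sklearn.linear_model import LogisticRegression

def toy_trained_attack(loss_train, loss_test):
  # Members (training examples) are labeled 0, non-members (test examples) 1.
  features = np.concatenate((loss_train, loss_test)).reshape(-1, 1)
  labels = np.concatenate((np.zeros(loss_train.size), np.ones(loss_test.size)))
  # Fit an attacker model and return a membership score per example.
  attacker = LogisticRegression().fit(features, labels)
  return attacker.predict_proba(features)[:, 1]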
@@ -79,7 +79,7 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
       roc_curve=roc_curve)
-def run_threshold_attack(attack_input: AttackInputData):
+def _run_threshold_attack(attack_input: AttackInputData):
   fpr, tpr, thresholds = metrics.roc_curve(
       np.concatenate((np.zeros(attack_input.get_train_size()),
                       np.ones(attack_input.get_test_size()))),
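
The threshold attack needs no attacker model: it labels training examples 0 and test examples 1, ranks every example by its loss, and the resulting ROC/AUC measures how well loss alone separates members from non-members. A self-contained equivalent of the call above, with made-up losses:

import numpy as np
from sklearn import metrics

rng = np.random.default_rng(0)
loss_train = rng.normal(0.3, 0.1, 1000)  # members: lower loss on average
loss_test = rng.normal(0.6, 0.2, 1000)   # non-members: higher loss on average
labels = np.concatenate((np.zeros(loss_train.size), np.ones(loss_test.size)))
scores = np.concatenate((loss_train, loss_test))
fpr, tpr, _ = metrics.roc_curve(labels, scores)
print(metrics.auc(fpr, tpr))  # well above 0.5, so loss leaks membership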
@@ -94,12 +94,12 @@ def run_threshold_attack(attack_input: AttackInputData):
       roc_curve=roc_curve)
-def run_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
   attack_input.validate()
   if attack_type.is_trained_attack:
-    return run_trained_attack(attack_input, attack_type)
+    return _run_trained_attack(attack_input, attack_type)
-  return run_threshold_attack(attack_input)
+  return _run_threshold_attack(attack_input)
 def run_attacks(
@@ -107,7 +107,20 @@ def run_attacks(
     slicing_spec: SlicingSpec = None,
     attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
     privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
-  """Run all attacks."""
+  """Runs membership inference attacks on a classification model.
+
+  Runs the attacks specified by attack_types on each slice of attack_input
+  defined by slicing_spec.
+
+  Args:
+    attack_input: input data for running the attacks.
+    slicing_spec: specifies which slices of attack_input to run the attacks on.
+    attack_types: attacks to run.
+    privacy_report_metadata: the metadata of the model under attack.
+
+  Returns:
+    the attack results.
+  """
   attack_input.validate()
   attack_results = []
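
With the helpers above now private, run_attacks is the intended entry point. A typical call would look roughly like this; the import paths are assumptions about where these symbols live in tensorflow_privacy and may need adjusting:

import numpy as np
# Assumed module paths; adjust to the actual package layout.
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec

attack_input = AttackInputData(
    loss_train=np.random.normal(0.3, 0.1, 1000),
    loss_test=np.random.normal(0.6, 0.2, 1000))
results = mia.run_attacks(
    attack_input,
    slicing_spec=SlicingSpec(entire_dataset=True),
    attack_types=(AttackType.THRESHOLD_ATTACK,))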
@@ -118,7 +131,7 @@ def run_attacks(
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
-      attack_results.append(run_attack(attack_input_slice, attack_type))
+      attack_results.append(_run_attack(attack_input_slice, attack_type))
   privacy_report_metadata = _compute_missing_privacy_report_metadata(
       privacy_report_metadata, attack_input)
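
The nested loop makes the result count multiplicative: one result per (slice, attack type) pair. A back-of-the-envelope check, assuming per-class slicing on a hypothetical 3-class model (one whole-dataset slice plus one slice per class) and two attack types:

num_slices = 1 + 3    # entire dataset + one slice per class (assumed slicing)
num_attack_types = 2  # e.g. THRESHOLD_ATTACK and LOGISTIC_REGRESSION
assert num_slices * num_attack_types == 8  # entries appended to attack_results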


@@ -43,19 +43,19 @@ class RunAttacksTest(absltest.TestCase):
     self.assertLen(result.single_attack_results, 2)
   def test_run_attack_trained_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
     self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)
   def test_run_attack_threshold_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.THRESHOLD_ATTACK)
     self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
   def test_run_attack_threshold_calculates_correct_auc(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         AttackInputData(
             loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
             loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
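
The expected AUC for this test can be checked by hand: AUC equals the probability that a randomly drawn test loss outranks a randomly drawn train loss, with ties counted as one half. Over the 6 x 6 pairs this gives (5 + 5 + 5.5 + 2.5 + 6 + 6) / 36 = 30/36 ≈ 0.83, which the following standalone computation reproduces:

import numpy as np
from sklearn import metrics

labels = np.concatenate((np.zeros(6), np.ones(6)))
scores = np.concatenate(([0.1, 0.2, 1.3, 0.4, 0.5, 0.6],
                         [1.1, 1.2, 1.3, 0.4, 1.5, 1.6]))
fpr, tpr, _ = metrics.roc_curve(labels, scores)
print(metrics.auc(fpr, tpr))  # 0.8333...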