Updating comments for run_attack() and making non-API functions private.
PiperOrigin-RevId: 329951618
parent 2f0a078dd9
commit f4fc9b2623

2 changed files with 23 additions and 10 deletions
@@ -43,7 +43,7 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
   return SingleSliceSpec()


-def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
   """Classification attack done by ML models."""
   attacker = None

@@ -79,7 +79,7 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
       roc_curve=roc_curve)


-def run_threshold_attack(attack_input: AttackInputData):
+def _run_threshold_attack(attack_input: AttackInputData):
   fpr, tpr, thresholds = metrics.roc_curve(
       np.concatenate((np.zeros(attack_input.get_train_size()),
                       np.ones(attack_input.get_test_size()))),
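(The threshold attack above reduces to a single ROC computation over the concatenated per-example losses: members tend to have lower loss, so the loss itself serves as the score. A minimal standalone sketch of the same idea, with threshold_attack_auc a hypothetical name:)

    import numpy as np
    from sklearn import metrics

    def threshold_attack_auc(loss_train, loss_test):
      # Label 0 = member (train set), 1 = non-member (test set);
      # a higher loss makes an example more likely to be a non-member.
      labels = np.concatenate((np.zeros(len(loss_train)),
                               np.ones(len(loss_test))))
      scores = np.concatenate((loss_train, loss_test))
      fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
      return metrics.auc(fpr, tpr)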
@@ -94,12 +94,12 @@ def run_threshold_attack(attack_input: AttackInputData):
       roc_curve=roc_curve)


-def run_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
   attack_input.validate()
   if attack_type.is_trained_attack:
-    return run_trained_attack(attack_input, attack_type)
+    return _run_trained_attack(attack_input, attack_type)

-  return run_threshold_attack(attack_input)
+  return _run_threshold_attack(attack_input)


 def run_attacks(
@@ -107,7 +107,20 @@ def run_attacks(
     slicing_spec: SlicingSpec = None,
     attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
     privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
-  """Run all attacks."""
+  """Runs membership inference attacks on a classification model.
+
+  It runs attacks specified by attack_types on each attack_input slice which is
+  specified by slicing_spec.
+
+  Args:
+    attack_input: input data for running an attack
+    slicing_spec: specifies attack_input slices to run attack on
+    attack_types: attacks to run
+    privacy_report_metadata: the metadata of the model under attack.
+
+  Returns:
+    the attack result.
+  """
   attack_input.validate()
   attack_results = []

@@ -118,7 +131,7 @@ def run_attacks(
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
-      attack_results.append(run_attack(attack_input_slice, attack_type))
+      attack_results.append(_run_attack(attack_input_slice, attack_type))

   privacy_report_metadata = _compute_missing_privacy_report_metadata(
       privacy_report_metadata, attack_input)
@@ -43,19 +43,19 @@ class RunAttacksTest(absltest.TestCase):
     self.assertLen(result.single_attack_results, 2)

   def test_run_attack_trained_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)

     self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)

   def test_run_attack_threshold_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.THRESHOLD_ATTACK)

     self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)

   def test_run_attack_threshold_calculates_correct_auc(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         AttackInputData(
             loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
             loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),