Updating comments for run_attack() and making non-API functions private.
PiperOrigin-RevId: 329951618
parent 2f0a078dd9
commit f4fc9b2623
2 changed files with 23 additions and 10 deletions
@@ -43,7 +43,7 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
   return SingleSliceSpec()


-def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
   """Classification attack done by ML models."""
   attacker = None

@@ -79,7 +79,7 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
       roc_curve=roc_curve)


-def run_threshold_attack(attack_input: AttackInputData):
+def _run_threshold_attack(attack_input: AttackInputData):
   fpr, tpr, thresholds = metrics.roc_curve(
       np.concatenate((np.zeros(attack_input.get_train_size()),
                       np.ones(attack_input.get_test_size()))),
@@ -94,12 +94,12 @@ def run_threshold_attack(attack_input: AttackInputData):
       roc_curve=roc_curve)


-def run_attack(attack_input: AttackInputData, attack_type: AttackType):
+def _run_attack(attack_input: AttackInputData, attack_type: AttackType):
   attack_input.validate()
   if attack_type.is_trained_attack:
-    return run_trained_attack(attack_input, attack_type)
+    return _run_trained_attack(attack_input, attack_type)

-  return run_threshold_attack(attack_input)
+  return _run_threshold_attack(attack_input)


 def run_attacks(
@@ -107,7 +107,20 @@ def run_attacks(
                 slicing_spec: SlicingSpec = None,
                 attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,),
                 privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults:
-  """Run all attacks."""
+  """Runs membership inference attacks on a classification model.
+
+  It runs attacks specified by attack_types on each attack_input slice which is
+  specified by slicing_spec.
+
+  Args:
+    attack_input: input data for running an attack
+    slicing_spec: specifies attack_input slices to run attack on
+    attack_types: attacks to run
+    privacy_report_metadata: the metadata of the model under attack.
+
+  Returns:
+    the attack result.
+  """
   attack_input.validate()
   attack_results = []

@@ -118,7 +131,7 @@ def run_attacks(
   for single_slice_spec in input_slice_specs:
     attack_input_slice = get_slice(attack_input, single_slice_spec)
     for attack_type in attack_types:
-      attack_results.append(run_attack(attack_input_slice, attack_type))
+      attack_results.append(_run_attack(attack_input_slice, attack_type))

   privacy_report_metadata = _compute_missing_privacy_report_metadata(
       privacy_report_metadata, attack_input)
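With the helpers renamed to _run_attack, _run_trained_attack, and _run_threshold_attack, run_attacks() is the public entry point described by the new docstring above. A minimal usage sketch follows; the import paths are assumptions based on the TensorFlow Privacy package layout, since the diff itself only shows the module aliased as mia and the AttackInputData / AttackType names.

# Minimal usage sketch of the public API after this change. Import paths are
# assumptions; only the `mia` alias and the data-structure names appear in the diff.
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType

# Per-example losses of the model on its training and test data.
attack_input = AttackInputData(
    loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
    loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6]))

# run_attacks() validates the input, slices it according to slicing_spec
# (left at its default here), and runs every requested attack type on each slice.
results = mia.run_attacks(
    attack_input,
    attack_types=(AttackType.THRESHOLD_ATTACK,))

print(len(results.single_attack_results))  # one result per (slice, attack type) pair

The second changed file updates the tests to call the private _run_attack directly: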
@@ -43,19 +43,19 @@ class RunAttacksTest(absltest.TestCase):
     self.assertLen(result.single_attack_results, 2)

   def test_run_attack_trained_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)

     self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION)

   def test_run_attack_threshold_sets_attack_type(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         get_test_input(100, 100), AttackType.THRESHOLD_ATTACK)

     self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)

   def test_run_attack_threshold_calculates_correct_auc(self):
-    result = mia.run_attack(
+    result = mia._run_attack(
         AttackInputData(
             loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
             loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
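The last test above exercises the threshold attack on raw loss arrays. Its ROC curve comes from the metrics.roc_curve call shown in the _run_threshold_attack hunk: membership labels are zeros for training examples and ones for test examples. A small sketch of that computation, assuming `metrics` is sklearn.metrics and that the score argument (cut off in the hunk) is the concatenated per-example losses:

# Sketch of the threshold attack's ROC/AUC computation, under the assumptions
# stated above (sklearn metrics; scores = concatenated per-example losses).
import numpy as np
from sklearn import metrics

loss_train = np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6])
loss_test = np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])

# 0 = member (training example), 1 = non-member (test example).
labels = np.concatenate((np.zeros(loss_train.size), np.ones(loss_test.size)))
scores = np.concatenate((loss_train, loss_test))

fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
print(metrics.auc(fpr, tpr))  # AUC of the threshold attack on this data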