add threshold-entropy attack

This commit is contained in:
Liwei Song 2020-10-21 16:41:20 -04:00
parent 1981ebe2f2
commit a41d6aace7
3 changed files with 19 additions and 2 deletions

View file

@ -114,11 +114,12 @@ class AttackType(enum.Enum):
RANDOM_FOREST = 'rf'
K_NEAREST_NEIGHBORS = 'knn'
THRESHOLD_ATTACK = 'threshold'
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
@property
def is_trained_attack(self):
"""Returns whether this type of attack requires training a model."""
return self != AttackType.THRESHOLD_ATTACK
return (self != AttackType.THRESHOLD_ATTACK) & (self != AttackType.THRESHOLD_ENTROPY_ATTACK)
def __str__(self):
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""

View file

@ -41,12 +41,14 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
# Slice test data.
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
return result

View file

@ -96,6 +96,19 @@ def _run_threshold_attack(attack_input: AttackInputData):
attack_type=AttackType.THRESHOLD_ATTACK,
roc_curve=roc_curve)
def _run_threshold_entropy_attack(attack_input: AttackInputData):
fpr, tpr, thresholds = metrics.roc_curve(
np.concatenate((np.zeros(attack_input.get_train_size()),
np.ones(attack_input.get_test_size()))),
np.concatenate(
(attack_input.get_entropy_train(), attack_input.get_entropy_test())))
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input),
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
roc_curve=roc_curve)
def _run_attack(attack_input: AttackInputData,
attack_type: AttackType,
@ -104,7 +117,8 @@ def _run_attack(attack_input: AttackInputData,
if attack_type.is_trained_attack:
return _run_trained_attack(attack_input, attack_type,
balance_attacker_training)
if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
return _run_threshold_entropy_attack(attack_input)
return _run_threshold_attack(attack_input)