forked from 626_privacy/tensorflow_privacy
add threshold-entropy attack
This commit is contained in:
parent
1981ebe2f2
commit
a41d6aace7
3 changed files with 19 additions and 2 deletions
|
@ -114,11 +114,12 @@ class AttackType(enum.Enum):
|
|||
RANDOM_FOREST = 'rf'
|
||||
K_NEAREST_NEIGHBORS = 'knn'
|
||||
THRESHOLD_ATTACK = 'threshold'
|
||||
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
|
||||
|
||||
@property
|
||||
def is_trained_attack(self):
|
||||
"""Returns whether this type of attack requires training a model."""
|
||||
return self != AttackType.THRESHOLD_ATTACK
|
||||
return (self != AttackType.THRESHOLD_ATTACK) & (self != AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||
|
||||
def __str__(self):
|
||||
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
||||
|
|
|
@ -41,12 +41,14 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
|||
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
|
||||
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
||||
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
||||
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
|
||||
|
||||
# Slice test data.
|
||||
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
||||
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
|
||||
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -96,6 +96,19 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
roc_curve=roc_curve)
|
||||
|
||||
def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(attack_input.get_train_size()),
|
||||
np.ones(attack_input.get_test_size()))),
|
||||
np.concatenate(
|
||||
(attack_input.get_entropy_train(), attack_input.get_entropy_test())))
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
|
||||
roc_curve=roc_curve)
|
||||
|
||||
def _run_attack(attack_input: AttackInputData,
|
||||
attack_type: AttackType,
|
||||
|
@ -104,7 +117,8 @@ def _run_attack(attack_input: AttackInputData,
|
|||
if attack_type.is_trained_attack:
|
||||
return _run_trained_attack(attack_input, attack_type,
|
||||
balance_attacker_training)
|
||||
|
||||
if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
|
||||
return _run_threshold_entropy_attack(attack_input)
|
||||
return _run_threshold_attack(attack_input)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue