# Patch summary (free text ahead of the first `diff --git` header; patch tools
# are expected to skip leading text like this -- TODO confirm for your tooling).
#
# Files touched, all under tensorflow_privacy/privacy/membership_inference_attack/:
#   * data_structures.py
#       - adds enum member AttackType.THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
#       - is_trained_attack now compares against both THRESHOLD_ATTACK and
#         THRESHOLD_ENTROPY_ATTACK, so it is False for either of them.
#   * dataset_slicing.py
#       - _slice_data_by_indices additionally fills result.entropy_train and
#         result.entropy_test via _slice_if_not_none.
#   * membership_inference_attack.py
#       - adds _run_threshold_entropy_attack: calls metrics.roc_curve with labels
#         0 for each train example and 1 for each test example, and scores taken
#         from get_entropy_train() followed by get_entropy_test(); wraps the
#         result in RocCurve / SingleAttackResult tagged THRESHOLD_ENTROPY_ATTACK.
#       - _run_attack returns _run_threshold_entropy_attack(...) for
#         THRESHOLD_ENTROPY_ATTACK before falling back to _run_threshold_attack.
#
# NOTE(review): is_trained_attack joins the two comparisons with the bitwise
#   `&` operator; `and`, or `self not in (...)`, is the conventional boolean form.
# NOTE(review): _run_threshold_entropy_attack is added without a docstring.
# NOTE(review): entropy_train / entropy_test and get_entropy_train() /
#   get_entropy_test() are used here but not defined in this patch -- presumably
#   they live on AttackInputData; verify they exist before applying.
# NOTE(review): hunk line breaks appear collapsed in this copy of the patch, so
#   it may not apply as-is -- confirm against the original commit.
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index fbafcd6..4e1acf3 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -114,11 +114,12 @@ class AttackType(enum.Enum): RANDOM_FOREST = 'rf' K_NEAREST_NEIGHBORS = 'knn' THRESHOLD_ATTACK = 'threshold' + THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy' @property def is_trained_attack(self): """Returns whether this type of attack requires training a model.""" - return self != AttackType.THRESHOLD_ATTACK + return (self != AttackType.THRESHOLD_ATTACK) & (self != AttackType.THRESHOLD_ENTROPY_ATTACK) def __str__(self): """Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION.""" diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py index 8e5e4b0..fe108a6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py @@ -41,12 +41,14 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.probs_train = _slice_if_not_none(data.probs_train, idx_train) result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train) + result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train) # Slice test data. 
result.logits_test = _slice_if_not_none(data.logits_test, idx_test) result.probs_test = _slice_if_not_none(data.probs_test, idx_test) result.labels_test = _slice_if_not_none(data.labels_test, idx_test) result.loss_test = _slice_if_not_none(data.loss_test, idx_test) + result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test) return result diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 8f877a7..ae88464 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -96,6 +96,19 @@ def _run_threshold_attack(attack_input: AttackInputData): attack_type=AttackType.THRESHOLD_ATTACK, roc_curve=roc_curve) +def _run_threshold_entropy_attack(attack_input: AttackInputData): + fpr, tpr, thresholds = metrics.roc_curve( + np.concatenate((np.zeros(attack_input.get_train_size()), + np.ones(attack_input.get_test_size()))), + np.concatenate( + (attack_input.get_entropy_train(), attack_input.get_entropy_test()))) + + roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) + + return SingleAttackResult( + slice_spec=_get_slice_spec(attack_input), + attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK, + roc_curve=roc_curve) def _run_attack(attack_input: AttackInputData, attack_type: AttackType, @@ -104,7 +117,8 @@ def _run_attack(attack_input: AttackInputData, if attack_type.is_trained_attack: return _run_trained_attack(attack_input, attack_type, balance_attacker_training) - + if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK: + return _run_threshold_entropy_attack(attack_input) return _run_threshold_attack(attack_input)