Merge pull request #131 from lwsong:master
PiperOrigin-RevId: 339012372
This commit is contained in:
commit
67f7f35383
5 changed files with 52 additions and 3 deletions
|
@ -114,11 +114,13 @@ class AttackType(enum.Enum):
|
||||||
RANDOM_FOREST = 'rf'
|
RANDOM_FOREST = 'rf'
|
||||||
K_NEAREST_NEIGHBORS = 'knn'
|
K_NEAREST_NEIGHBORS = 'knn'
|
||||||
THRESHOLD_ATTACK = 'threshold'
|
THRESHOLD_ATTACK = 'threshold'
|
||||||
|
THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_trained_attack(self):
|
def is_trained_attack(self):
|
||||||
"""Returns whether this type of attack requires training a model."""
|
"""Returns whether this type of attack requires training a model."""
|
||||||
return self != AttackType.THRESHOLD_ATTACK
|
return (self != AttackType.THRESHOLD_ATTACK) and (
|
||||||
|
self != AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
||||||
|
@ -278,12 +280,16 @@ class AttackInputData:
|
||||||
"""Returns size of the training set."""
|
"""Returns size of the training set."""
|
||||||
if self.loss_train is not None:
|
if self.loss_train is not None:
|
||||||
return self.loss_train.size
|
return self.loss_train.size
|
||||||
|
if self.entropy_train is not None:
|
||||||
|
return self.entropy_train.size
|
||||||
return self.logits_or_probs_train.shape[0]
|
return self.logits_or_probs_train.shape[0]
|
||||||
|
|
||||||
def get_test_size(self):
|
def get_test_size(self):
|
||||||
"""Returns size of the test set."""
|
"""Returns size of the test set."""
|
||||||
if self.loss_test is not None:
|
if self.loss_test is not None:
|
||||||
return self.loss_test.size
|
return self.loss_test.size
|
||||||
|
if self.entropy_test is not None:
|
||||||
|
return self.entropy_test.size
|
||||||
return self.logits_or_probs_test.shape[0]
|
return self.logits_or_probs_test.shape[0]
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
|
|
@ -41,12 +41,14 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
|
result.probs_train = _slice_if_not_none(data.probs_train, idx_train)
|
||||||
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
||||||
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
||||||
|
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
|
||||||
|
|
||||||
# Slice test data.
|
# Slice test data.
|
||||||
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
||||||
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
|
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
|
||||||
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
||||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||||
|
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase):
|
||||||
labels_test = np.array([1, 2, 0, 2])
|
labels_test = np.array([1, 2, 0, 2])
|
||||||
loss_train = np.array([2, 0.25, 4, 3])
|
loss_train = np.array([2, 0.25, 4, 3])
|
||||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||||
|
entropy_train = np.array([0.4, 8, 0.6, 10])
|
||||||
|
entropy_test = np.array([15, 10.5, 4.5, 0.3])
|
||||||
|
|
||||||
self.input_data = AttackInputData(
|
self.input_data = AttackInputData(
|
||||||
logits_train=logits_train,
|
logits_train=logits_train,
|
||||||
|
@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase):
|
||||||
labels_train=labels_train,
|
labels_train=labels_train,
|
||||||
labels_test=labels_test,
|
labels_test=labels_test,
|
||||||
loss_train=loss_train,
|
loss_train=loss_train,
|
||||||
loss_test=loss_test)
|
loss_test=loss_test,
|
||||||
|
entropy_train=entropy_train,
|
||||||
|
entropy_test=entropy_test)
|
||||||
|
|
||||||
def test_slice_entire_dataset(self):
|
def test_slice_entire_dataset(self):
|
||||||
entire_dataset_slice = SingleSliceSpec()
|
entire_dataset_slice = SingleSliceSpec()
|
||||||
|
@ -159,6 +163,12 @@ class GetSliceTest(absltest.TestCase):
|
||||||
self.assertTrue((output.loss_train == [2, 4]).all())
|
self.assertTrue((output.loss_train == [2, 4]).all())
|
||||||
self.assertTrue((output.loss_test == [0.5]).all())
|
self.assertTrue((output.loss_test == [0.5]).all())
|
||||||
|
|
||||||
|
# Check entropy
|
||||||
|
self.assertLen(output.entropy_train, 2)
|
||||||
|
self.assertLen(output.entropy_test, 1)
|
||||||
|
self.assertTrue((output.entropy_train == [0.4, 0.6]).all())
|
||||||
|
self.assertTrue((output.entropy_test == [15]).all())
|
||||||
|
|
||||||
def test_slice_by_percentile(self):
|
def test_slice_by_percentile(self):
|
||||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||||
output = get_slice(self.input_data, percentile_slice)
|
output = get_slice(self.input_data, percentile_slice)
|
||||||
|
|
|
@ -97,6 +97,21 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
||||||
roc_curve=roc_curve)
|
roc_curve=roc_curve)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_threshold_entropy_attack(attack_input: AttackInputData):
  """Runs the threshold attack that scores examples by their entropy.

  Training examples are labelled 0 and test examples 1. The entropy values
  obtained from `attack_input` are handed to `metrics.roc_curve` as scores.

  Args:
    attack_input: input data supplying the train/test sizes and entropies.

  Returns:
    A SingleAttackResult tagged AttackType.THRESHOLD_ENTROPY_ATTACK holding
    the resulting ROC curve.
  """
  # Ground truth: 0 for each training example, 1 for each test example.
  train_labels = np.zeros(attack_input.get_train_size())
  test_labels = np.ones(attack_input.get_test_size())
  membership_labels = np.concatenate((train_labels, test_labels))

  # Scores, in the same train-then-test order as the labels above.
  entropy_scores = np.concatenate(
      (attack_input.get_entropy_train(), attack_input.get_entropy_test()))

  fpr, tpr, thresholds = metrics.roc_curve(membership_labels, entropy_scores)
  curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)

  return SingleAttackResult(
      slice_spec=_get_slice_spec(attack_input),
      attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
      roc_curve=curve)
|
||||||
|
|
||||||
|
|
||||||
def _run_attack(attack_input: AttackInputData,
|
def _run_attack(attack_input: AttackInputData,
|
||||||
attack_type: AttackType,
|
attack_type: AttackType,
|
||||||
balance_attacker_training: bool = True):
|
balance_attacker_training: bool = True):
|
||||||
|
@ -104,7 +119,8 @@ def _run_attack(attack_input: AttackInputData,
|
||||||
if attack_type.is_trained_attack:
|
if attack_type.is_trained_attack:
|
||||||
return _run_trained_attack(attack_input, attack_type,
|
return _run_trained_attack(attack_input, attack_type,
|
||||||
balance_attacker_training)
|
balance_attacker_training)
|
||||||
|
if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
|
||||||
|
return _run_threshold_entropy_attack(attack_input)
|
||||||
return _run_threshold_attack(attack_input)
|
return _run_threshold_attack(attack_input)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,12 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_sets_attack_type(self):
  """Checks that the entropy threshold attack tags its result correctly."""
  attack_input = get_test_input(100, 100)
  expected_type = AttackType.THRESHOLD_ENTROPY_ATTACK

  attack_result = mia._run_attack(attack_input, expected_type)

  self.assertEqual(attack_result.attack_type, expected_type)
|
||||||
|
|
||||||
def test_run_attack_threshold_calculates_correct_auc(self):
|
def test_run_attack_threshold_calculates_correct_auc(self):
|
||||||
result = mia._run_attack(
|
result = mia._run_attack(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
|
@ -64,6 +70,15 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_calculates_correct_auc(self):
  """Checks the AUC of the entropy threshold attack on a small fixed input."""
  entropy_train = np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6])
  entropy_test = np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])
  attack_input = AttackInputData(
      entropy_train=entropy_train, entropy_test=entropy_test)

  attack_result = mia._run_attack(attack_input,
                                  AttackType.THRESHOLD_ENTROPY_ATTACK)

  # Expected value: counting ties as half, 30 of the 36 (test, train) pairs
  # have the larger test entropy, so the area under the ROC curve is
  # 30/36, i.e. about 0.83.
  auc = attack_result.roc_curve.get_auc()
  np.testing.assert_almost_equal(auc, 0.83, decimal=2)
|
||||||
|
|
||||||
def test_run_attack_by_slice(self):
|
def test_run_attack_by_slice(self):
|
||||||
result = mia.run_attacks(
|
result = mia.run_attacks(
|
||||||
get_test_input(100, 100), SlicingSpec(by_class=True),
|
get_test_input(100, 100), SlicingSpec(by_class=True),
|
||||||
|
|
Loading…
Reference in a new issue