forked from 626_privacy/tensorflow_privacy

update test code

This commit is contained in:
parent 3f40b8c465
commit 2312192573

1 changed file with 0 additions and 98 deletions
@@ -19,7 +19,6 @@ import numpy as np
 from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -35,67 +34,6 @@ def get_test_input(n_train, n_test):
       labels_test=np.array([i % 5 for i in range(n_test)]))
 
 
-def get_seq2seq_test_input(n_train,
-                           n_test,
-                           max_seq_in_batch,
-                           max_tokens_in_sequence,
-                           vocab_size,
-                           seed=None):
-  """Returns example inputs for attacks on seq2seq models."""
-  if seed is not None:
-    np.random.seed(seed=seed)
-
-  logits_train, labels_train = [], []
-  for _ in range(n_train):
-    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
-    batch_logits, batch_labels = _get_batch_logits_and_labels(
-        num_sequences, max_tokens_in_sequence, vocab_size)
-    logits_train.append(batch_logits)
-    labels_train.append(batch_labels)
-
-  logits_test, labels_test = [], []
-  for _ in range(n_test):
-    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
-    batch_logits, batch_labels = _get_batch_logits_and_labels(
-        num_sequences, max_tokens_in_sequence, vocab_size)
-    logits_test.append(batch_logits)
-    labels_test.append(batch_labels)
-
-  return Seq2SeqAttackInputData(
-      logits_train=iter(logits_train),
-      logits_test=iter(logits_test),
-      labels_train=iter(labels_train),
-      labels_test=iter(labels_test),
-      vocab_size=vocab_size,
-      train_size=n_train,
-      test_size=n_test)
-
-
-def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
-                                 vocab_size):
-  num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
-                                            num_sequences) + 1
-  batch_logits, batch_labels = [], []
-  for num_tokens in num_tokens_in_sequence:
-    logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
-    batch_logits.append(logits)
-    batch_labels.append(labels)
-  return np.array(
-      batch_logits, dtype=object), np.array(
-          batch_labels, dtype=object)
-
-
-def _get_sequence_logits_and_labels(num_tokens, vocab_size):
-  sequence_logits = []
-  for _ in range(num_tokens):
-    token_logits = np.random.random(vocab_size)
-    token_logits /= token_logits.sum()
-    sequence_logits.append(token_logits)
-  sequence_labels = np.random.choice(vocab_size, num_tokens)
-  return np.array(
-      sequence_logits, dtype=np.float32), np.array(
-          sequence_labels, dtype=np.float32)
-
-
 class RunAttacksTest(absltest.TestCase):
 
@@ -160,42 +98,6 @@ class RunAttacksTest(absltest.TestCase):
     # If accuracy is already present, simply return it.
     self.assertIsNone(mia._get_accuracy(None, labels))
 
-  def test_run_seq2seq_attack_size(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=10,
-            n_test=5,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=2))
-
-    self.assertLen(result.single_attack_results, 1)
-
-  def test_run_seq2seq_attack_trained_sets_attack_type(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=10,
-            n_test=5,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=2))
-    seq2seq_result = list(result.single_attack_results)[0]
-    self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
-
-  def test_run_seq2seq_attack_calculates_correct_auc(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=20,
-            n_test=10,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=3,
-            seed=12345),
-        balance_attacker_training=False)
-    seq2seq_result = list(result.single_attack_results)[0]
-    np.testing.assert_almost_equal(
-        seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
-
   def test_run_compute_privacy_risk_score_correct_score(self):
     result = mia._compute_privacy_risk_score(
         AttackInputData(
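Context for the removal: the deleted get_seq2seq_test_input helper built ragged inputs in which each batch holds a variable number of sequences, and each sequence is a (num_tokens, vocab_size) array of per-token probability vectors, wrapped by Seq2SeqAttackInputData as iterators over batches. A minimal standalone sketch of the per-sequence shape, using only numpy, is below; the variable names and seed are illustrative, not part of the tensorflow_privacy API.

import numpy as np

# Sketch of one sequence's logits and labels, mirroring the deleted
# _get_sequence_logits_and_labels helper (illustrative only).
np.random.seed(0)  # any seed; the removed AUC test used seed=12345
vocab_size = 2
max_tokens_in_sequence = 5

# Sequence length is drawn uniformly from 1..max_tokens_in_sequence.
num_tokens = np.random.choice(max_tokens_in_sequence) + 1

# Per-token "logits" are random vectors normalized to sum to 1.
sequence_logits = np.random.random((num_tokens, vocab_size))
sequence_logits /= sequence_logits.sum(axis=1, keepdims=True)

# Labels are token ids drawn from the vocabulary, stored as float32
# (matching the dtype in the deleted helper).
sequence_labels = np.random.choice(vocab_size, num_tokens).astype(np.float32)

print(sequence_logits.shape, sequence_labels.shape)  # (num_tokens, 2), (num_tokens,)

Batches of such sequences were collected into dtype=object arrays and handed to Seq2SeqAttackInputData via iter(...), as the deleted return statement above shows.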