diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
index d6b9867..5609eab 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
@@ -18,6 +18,7 @@ from absl.testing import absltest
 import numpy as np
 from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@@ -34,6 +35,109 @@ def get_test_input(n_train, n_test):
       labels_test=np.array([i % 5 for i in range(n_test)]))
 
 
+def get_seq2seq_test_input(n_train, n_test, max_seq_in_batch,
+                           max_tokens_in_sequence, vocab_size, seed=None):
+  """Get example inputs for attacks on seq2seq models.
+
+  Args:
+    n_train: Number of training batches to generate.
+    n_test: Number of test batches to generate.
+    max_seq_in_batch: Largest number of sequences a batch may contain; each
+      batch gets between 1 and this many sequences.
+    max_tokens_in_sequence: Largest number of tokens a sequence may contain;
+      each sequence gets between 1 and this many tokens.
+    vocab_size: Length of the score vector generated for every token.
+    seed: Optional seed for numpy's global random state. If None, the global
+      state is used as is.
+
+  Returns:
+    A Seq2SeqAttackInputData whose logits and labels are one-shot iterators
+    over the generated batches.
+  """
+  if seed is not None:
+    # The helpers below draw from numpy's global generator, so seeding it here
+    # makes the generated fixture reproducible.
+    np.random.seed(seed=seed)
+
+  # Train batches are drawn before test batches; a seeded fixture depends on
+  # this order, so do not swap the two calls.
+  logits_train, labels_train = _get_split_logits_and_labels(
+      n_train, max_seq_in_batch, max_tokens_in_sequence, vocab_size)
+  logits_test, labels_test = _get_split_logits_and_labels(
+      n_test, max_seq_in_batch, max_tokens_in_sequence, vocab_size)
+
+  return Seq2SeqAttackInputData(
+      logits_train=iter(logits_train),
+      logits_test=iter(logits_test),
+      labels_train=iter(labels_train),
+      labels_test=iter(labels_test),
+      vocab_size=vocab_size,
+      train_size=n_train,
+      test_size=n_test)
+
+
+def _get_split_logits_and_labels(num_batches, max_seq_in_batch,
+                                 max_tokens_in_sequence, vocab_size):
+  """Generate the per-batch logits and labels for one data split."""
+  split_logits, split_labels = [], []
+  for _ in range(num_batches):
+    # choice(k, 1) draws from range(k); the +1 yields 1..k sequences.
+    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
+    batch_logits, batch_labels = _get_batch_logits_and_labels(
+        num_sequences, max_tokens_in_sequence, vocab_size)
+    split_logits.append(batch_logits)
+    split_labels.append(batch_labels)
+  return split_logits, split_labels
+
+
+def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
+                                 vocab_size):
+  """Generate logits and labels for one batch of sequences.
+
+  Args:
+    num_sequences: Number of sequences in the batch.
+    max_tokens_in_sequence: Largest number of tokens a sequence may contain.
+    vocab_size: Length of the score vector generated for every token.
+
+  Returns:
+    A (logits, labels) tuple of numpy arrays created with dtype=object from
+    the per-sequence logits and labels.
+  """
+  # One length per sequence, each between 1 and max_tokens_in_sequence.
+  num_tokens_in_sequence = (
+      np.random.choice(max_tokens_in_sequence, num_sequences) + 1)
+  batch_logits, batch_labels = [], []
+  for num_tokens in num_tokens_in_sequence:
+    logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
+    batch_logits.append(logits)
+    batch_labels.append(labels)
+  # dtype=object: sequences in a batch can have different lengths.
+  return (np.array(batch_logits, dtype=object),
+          np.array(batch_labels, dtype=object))
+
+
+def _get_sequence_logits_and_labels(num_tokens, vocab_size):
+  """Generate per-token scores and labels for a single sequence.
+
+  Args:
+    num_tokens: Number of tokens in the sequence.
+    vocab_size: Length of the score vector generated for every token.
+
+  Returns:
+    A (logits, labels) tuple of float32 numpy arrays. logits has one row per
+    token, holding random values scaled so the row sums to 1; labels has one
+    entry per token, drawn from range(vocab_size).
+  """
+  sequence_logits = []
+  for _ in range(num_tokens):
+    token_logits = np.random.random(vocab_size)
+    token_logits /= token_logits.sum()
+    sequence_logits.append(token_logits)
+  sequence_labels = np.random.choice(vocab_size, num_tokens)
+  return (np.array(sequence_logits, dtype=np.float32),
+          np.array(sequence_labels, dtype=np.float32))
+
+
 class RunAttacksTest(absltest.TestCase):
 
   def test_run_attacks_size(self):
@@ -97,6 +201,46 @@ class RunAttacksTest(absltest.TestCase):
     # If accuracy is already present, simply return it.
     self.assertIsNone(mia._get_accuracy(None, labels))
 
+  def test_run_seq2seq_attack_size(self):
+    attack_input = get_seq2seq_test_input(
+        n_train=10,
+        n_test=5,
+        max_seq_in_batch=3,
+        max_tokens_in_sequence=5,
+        vocab_size=2)
+    result = mia.run_seq2seq_attack(attack_input)
+
+    self.assertLen(result.single_attack_results, 1)
+
+  def test_run_seq2seq_attack_trained_sets_attack_type(self):
+    attack_input = get_seq2seq_test_input(
+        n_train=10,
+        n_test=5,
+        max_seq_in_batch=3,
+        max_tokens_in_sequence=5,
+        vocab_size=2)
+    result = mia.run_seq2seq_attack(attack_input)
+
+    seq2seq_result = list(result.single_attack_results)[0]
+    self.assertEqual(seq2seq_result.attack_type,
+                     AttackType.LOGISTIC_REGRESSION)
+
+  def test_run_seq2seq_attack_calculates_correct_auc(self):
+    # The seed pins the generated fixture; the expected AUC is tied to it.
+    attack_input = get_seq2seq_test_input(
+        n_train=20,
+        n_test=10,
+        max_seq_in_batch=3,
+        max_tokens_in_sequence=5,
+        vocab_size=3,
+        seed=12345)
+    result = mia.run_seq2seq_attack(
+        attack_input, balance_attacker_training=False)
+
+    seq2seq_result = list(result.single_attack_results)[0]
+    np.testing.assert_almost_equal(
+        seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
+
 
 if __name__ == '__main__':
   absltest.main()