From 2312192573fcaea997da085638675bef34eef30b Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Mon, 14 Dec 2020 15:02:56 -0500 Subject: [PATCH] update test code --- .../membership_inference_attack_test.py | 98 ------------------- 1 file changed, 98 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index fd4db2b..be3092a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -19,7 +19,6 @@ import numpy as np from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec @@ -35,67 +34,6 @@ def get_test_input(n_train, n_test): labels_test=np.array([i % 5 for i in range(n_test)])) -def get_seq2seq_test_input(n_train, - n_test, - max_seq_in_batch, - max_tokens_in_sequence, - vocab_size, - seed=None): - """Returns example inputs for attacks on seq2seq models.""" - if seed is not None: - np.random.seed(seed=seed) - - logits_train, labels_train = [], [] - for _ in range(n_train): - num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1 - batch_logits, batch_labels = _get_batch_logits_and_labels( - num_sequences, max_tokens_in_sequence, vocab_size) - logits_train.append(batch_logits) - labels_train.append(batch_labels) - - logits_test, labels_test = [], [] - for _ in range(n_test): - num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1 - batch_logits, batch_labels = _get_batch_logits_and_labels( - num_sequences, max_tokens_in_sequence, vocab_size) - logits_test.append(batch_logits) - labels_test.append(batch_labels) - - return Seq2SeqAttackInputData( - logits_train=iter(logits_train), - logits_test=iter(logits_test), - labels_train=iter(labels_train), - labels_test=iter(labels_test), - vocab_size=vocab_size, - train_size=n_train, - test_size=n_test) - - -def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, - vocab_size): - num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence, - num_sequences) + 1 - batch_logits, batch_labels = [], [] - for num_tokens in num_tokens_in_sequence: - logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size) - batch_logits.append(logits) - batch_labels.append(labels) - return np.array( - batch_logits, dtype=object), np.array( - batch_labels, dtype=object) - - -def _get_sequence_logits_and_labels(num_tokens, vocab_size): - sequence_logits = [] - for _ in range(num_tokens): - token_logits = np.random.random(vocab_size) - token_logits /= token_logits.sum() - sequence_logits.append(token_logits) - sequence_labels = np.random.choice(vocab_size, num_tokens) - return np.array( - sequence_logits, dtype=np.float32), np.array( - sequence_labels, dtype=np.float32) - class RunAttacksTest(absltest.TestCase): @@ -160,42 +98,6 @@ class RunAttacksTest(absltest.TestCase): # If accuracy is already present, simply return it. self.assertIsNone(mia._get_accuracy(None, labels)) - def test_run_seq2seq_attack_size(self): - result = mia.run_seq2seq_attack( - get_seq2seq_test_input( - n_train=10, - n_test=5, - max_seq_in_batch=3, - max_tokens_in_sequence=5, - vocab_size=2)) - - self.assertLen(result.single_attack_results, 1) - - def test_run_seq2seq_attack_trained_sets_attack_type(self): - result = mia.run_seq2seq_attack( - get_seq2seq_test_input( - n_train=10, - n_test=5, - max_seq_in_batch=3, - max_tokens_in_sequence=5, - vocab_size=2)) - seq2seq_result = list(result.single_attack_results)[0] - self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION) - - def test_run_seq2seq_attack_calculates_correct_auc(self): - result = mia.run_seq2seq_attack( - get_seq2seq_test_input( - n_train=20, - n_test=10, - max_seq_in_batch=3, - max_tokens_in_sequence=5, - vocab_size=3, - seed=12345), - balance_attacker_training=False) - seq2seq_result = list(result.single_attack_results)[0] - np.testing.assert_almost_equal( - seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) - def test_run_compute_privacy_risk_score_correct_score(self): result = mia._compute_privacy_risk_score( AttackInputData(