From 04dd758c8aeef7556226bcaca59ae1f349443dec Mon Sep 17 00:00:00 2001 From: Shuang Song Date: Mon, 21 Feb 2022 13:53:43 -0800 Subject: [PATCH] In secret generation for secret sharer, use np.random.RandomState. Restructure generate_secrets. PiperOrigin-RevId: 430082580 --- .../secret_sharer/generate_secrets.py | 113 +++++++++++------- .../secret_sharer/generate_secrets_test.py | 45 +++---- .../secret_sharer/secret_sharer_example.ipynb | 8 +- 3 files changed, 99 insertions(+), 67 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets.py b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets.py index 7fdd961..945d4d2 100644 --- a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets.py +++ b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets.py @@ -16,33 +16,32 @@ import dataclasses import itertools import string -from typing import Dict, List - +from typing import Any, Dict, MutableSequence, Sequence import numpy as np -def generate_random_sequences(vocab: List[str], +def generate_random_sequences(vocab: Sequence[str], pattern: str, n: int, - seed: int = 1) -> List[str]: - """Generate random sequences. + seed: int = 1) -> MutableSequence[str]: + """Generates random sequences. Args: - vocab: a list, the vocabulary for the sequences + vocab: the vocabulary for the sequences pattern: the pattern of the sequence. The length of the sequence will be inferred from the pattern. n: number of sequences to generate seed: random seed for numpy.random Returns: - A list of different random sequences from the given vocabulary + A sequence of different random sequences from the given vocabulary """ def count_placeholder(pattern): return sum([x[1] is not None for x in string.Formatter().parse(pattern)]) length = count_placeholder(pattern) - np.random.seed(seed) + rng = np.random.RandomState(seed) vocab_size = len(vocab) if vocab_size**length <= n: # Generate all possible sequences of the length @@ -56,60 +55,71 @@ def generate_random_sequences(vocab: List[str], idx = np.empty((0, length), dtype=int) while idx.shape[0] < n: # Generate a new set of indices - idx_new = np.random.randint(vocab_size, size=(n, length)) + idx_new = rng.randint(vocab_size, size=(n, length)) idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices idx = np.unique(idx, axis=0) # Remove duplicated indices idx = idx[:n] seq = np.array(vocab)[idx] # Join each row to get the sequence seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq) - seq = seq[np.random.permutation(n)] + seq = seq[rng.permutation(n)] return list(seq) +@dataclasses.dataclass +class TextSecretProperties: + """Properties of text secret. + + vocab: the vocabulary for the secrets + pattern: the pattern of the secrets + """ + vocab: Sequence[str] + pattern: str + + @dataclasses.dataclass class SecretConfig: """Configuration of secret for secrets sharer. - vocab: a list, the vocabulary for the secrets - pattern: the pattern of the secrets - num_repetitions: a list, number of repetitions for the secrets - num_secrets_for_repetitions: a list, number of secrets to be used for - different number of repetitions + num_repetitions: numbers of repetitions for the secrets + num_secrets_for_repetitions: numbers of secrets to be used for each + number of repetitions num_references: number of references sequences, i.e. random sequences that will not be inserted into training data + name: name that identifies the secrets set + properties: any properties of the secret, e.g. the vocabulary, the pattern """ - vocab: List[str] - pattern: str - num_repetitions: List[int] - num_secrets_for_repetitions: List[int] + num_repetitions: Sequence[int] + num_secrets_for_repetitions: Sequence[int] num_references: int + name: str = '' + properties: Any = None @dataclasses.dataclass -class Secrets: - """Secrets for secrets sharer. +class SecretsSet: + """A secrets set for secrets sharer. config: configuration of the secrets - secrets: a dictionary, key is the number of repetitions, value is a list of - different secrets - references: a list of references + secrets: a dictionary, key is the number of repetitions, value is a sequence + of different secrets + references: a sequence of references """ config: SecretConfig - secrets: Dict[int, List[str]] - references: List[str] + secrets: Dict[int, Sequence[Any]] + references: Sequence[Any] -def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets: - """Construct a secret instance. +def construct_secret(secret_config: SecretConfig, + seqs: Sequence[Any]) -> SecretsSet: + """Constructs a SecretsSet instance given a sequence of samples. Args: secret_config: configuration of secret. - seqs: a list of random sequences that will be used for secrets and - references. + seqs: a sequence of samples that will be used for secrets and references. Returns: - a secret instance. + a SecretsSet instance. """ if len(seqs) < sum( secret_config.num_secrets_for_repetitions) + secret_config.num_references: @@ -120,32 +130,51 @@ def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets: secret_config.num_repetitions, secret_config.num_secrets_for_repetitions): secrets[num_repetition] = seqs[i:i + num_secrets] i += num_secrets - return Secrets( + return SecretsSet( config=secret_config, secrets=secrets, references=seqs[-secret_config.num_references:]) -def generate_secrets_and_references(secret_configs: List[SecretConfig], - seed: int = 0) -> List[Secrets]: - """Generate a list of secret instances given a list of configurations. +def generate_text_secrets_and_references( + secret_configs: Sequence[SecretConfig], + seed: int = 0) -> MutableSequence[SecretsSet]: + """Generates a sequence of text secret sets given a sequence of configurations. Args: - secret_configs: a list of secret configurations. + secret_configs: a sequence of text secret configurations. seed: random seed. Returns: - A list of secret instances. + A sequence of SecretsSet instances. """ - secrets = [] + secrets_sets = [] for i, secret_config in enumerate(secret_configs): n = secret_config.num_references + sum( secret_config.num_secrets_for_repetitions) - seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern, - n, seed + i) + seqs = generate_random_sequences(secret_config.properties.vocab, + secret_config.properties.pattern, n, + seed + i) if len(seqs) < n: raise ValueError( f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.' ) - secrets.append(construct_secret(secret_config, seqs)) - return secrets + secrets_sets.append(construct_secret(secret_config, seqs)) + return secrets_sets + + +def construct_secret_dataset( + secrets_sets: Sequence[SecretsSet]) -> MutableSequence[Any]: + """Repeats secrets for the required number of times to get a secret dataset. + + Args: + secrets_sets: a sequence of secert sets. + + Returns: + A sequence of samples. + """ + secrets_dataset = [] + for secrets_set in secrets_sets: + for r, seqs in secrets_set.secrets.items(): + secrets_dataset += list(seqs) * r + return secrets_dataset diff --git a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets_test.py b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets_test.py index 77fd5a8..9996d52 100644 --- a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets_test.py +++ b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/generate_secrets_test.py @@ -13,10 +13,7 @@ # limitations under the License. from absl.testing import absltest -from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret -from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences -from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references -from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig +from tensorflow_privacy.privacy.privacy_tests.secret_sharer import generate_secrets as gs class UtilsTest(absltest.TestCase): @@ -28,14 +25,15 @@ class UtilsTest(absltest.TestCase): def test_generate_random_sequences(self): """Test generate_random_sequences.""" # Test when n is larger than total number of possible sequences. - seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27) + seqs = gs.generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27) expected_seqs = [ 'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A' ] self.assertEqual(seqs, expected_seqs) # Test when n is smaller than total number of possible sequences. - seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9) + seqs = gs.generate_random_sequences( + list('01234'), 'prefix {}{}{}?', 8, seed=9) expected_seqs = [ 'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?', 'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?' @@ -43,14 +41,14 @@ class UtilsTest(absltest.TestCase): self.assertEqual(seqs, expected_seqs) def test_construct_secret(self): - secret_config = SecretConfig( - vocab=None, - pattern='', + secret_config = gs.SecretConfig( num_repetitions=[1, 2, 8], num_secrets_for_repetitions=[2, 3, 1], - num_references=3) + num_references=3, + name='random secrets', + properties=gs.TextSecretProperties(vocab=None, pattern='')) seqs = list('0123456789') - secrets = construct_secret(secret_config, seqs) + secrets = gs.construct_secret(secret_config, seqs) self.assertEqual(secrets.config, secret_config) self.assertDictEqual(secrets.secrets, { 1: ['0', '1'], @@ -61,24 +59,29 @@ class UtilsTest(absltest.TestCase): # Test when the number of elements in seqs is not enough. seqs = list('01234567') - self.assertRaises(ValueError, construct_secret, secret_config, seqs) + self.assertRaises(ValueError, gs.construct_secret, secret_config, seqs) def test_generate_secrets_and_references(self): secret_configs = [ - SecretConfig( - vocab=['w1', 'w2', 'w3'], - pattern='{} {} suf', + gs.SecretConfig( num_repetitions=[1, 12], num_secrets_for_repetitions=[2, 1], - num_references=3), - SecretConfig( - vocab=['W 1', 'W 2', 'W 3'], - pattern='{}-{}', + num_references=3, + name='secret1', + properties=gs.TextSecretProperties( + vocab=['w1', 'w2', 'w3'], pattern='{} {} suf'), + ), + gs.SecretConfig( num_repetitions=[1, 2, 8], num_secrets_for_repetitions=[2, 3, 1], - num_references=3) + num_references=3, + name='secert2', + properties=gs.TextSecretProperties( + vocab=['W 1', 'W 2', 'W 3'], + pattern='{}-{}', + )) ] - secrets = generate_secrets_and_references(secret_configs, seed=27) + secrets = gs.generate_text_secrets_and_references(secret_configs, seed=27) self.assertEqual(secrets[0].config, secret_configs[0]) self.assertDictEqual(secrets[0].secrets, { 1: ['w3 w2 suf', 'w2 w1 suf'], diff --git a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/secret_sharer_example.ipynb b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/secret_sharer_example.ipynb index cae066a..f7c38fd 100644 --- a/tensorflow_privacy/privacy/privacy_tests/secret_sharer/secret_sharer_example.ipynb +++ b/tensorflow_privacy/privacy/privacy_tests/secret_sharer/secret_sharer_example.ipynb @@ -115,7 +115,7 @@ "import tensorflow as tf\n", "from official.utils.misc import keras_utils\n", "\n", - "from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, generate_secrets_and_references, construct_secret\n", + "from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, TextSecretProperties, generate_text_secrets_and_references, construct_secret\n", "from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation" ] }, @@ -430,10 +430,10 @@ "num_repetitions = [1, 10, 100]\n", "num_secrets_for_repetitions = [20] * len(num_repetitions)\n", "num_references = 65536\n", - "secret_configs = [SecretConfig(vocab, pattern, num_repetitions,\n", - " num_secrets_for_repetitions, num_references)\n", + "secret_configs = [SecretConfig(num_repetitions, num_secrets_for_repetitions, num_references,\n", + " properties=TextSecretProperties(vocab, pattern))\n", " for vocab, pattern in zip(vocabs, patterns)]\n", - "secrets = generate_secrets_and_references(secret_configs)\n", + "secrets = generate_text_secrets_and_references(secret_configs)\n", "\n", "# Let's look at the variable \"secrets\"\n", "print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",