In secret generation for secret sharer, use np.random.RandomState. Restructure generate_secrets.

PiperOrigin-RevId: 430082580
This commit is contained in:
Shuang Song 2022-02-21 13:53:43 -08:00 committed by A. Unique TensorFlower
parent 89de03e0db
commit 04dd758c8a
3 changed files with 99 additions and 67 deletions

View file

@ -16,33 +16,32 @@
import dataclasses import dataclasses
import itertools import itertools
import string import string
from typing import Dict, List from typing import Any, Dict, MutableSequence, Sequence
import numpy as np import numpy as np
def generate_random_sequences(vocab: List[str], def generate_random_sequences(vocab: Sequence[str],
pattern: str, pattern: str,
n: int, n: int,
seed: int = 1) -> List[str]: seed: int = 1) -> MutableSequence[str]:
"""Generate random sequences. """Generates random sequences.
Args: Args:
vocab: a list, the vocabulary for the sequences vocab: the vocabulary for the sequences
pattern: the pattern of the sequence. The length of the sequence will be pattern: the pattern of the sequence. The length of the sequence will be
inferred from the pattern. inferred from the pattern.
n: number of sequences to generate n: number of sequences to generate
seed: random seed for numpy.random seed: random seed for numpy.random
Returns: Returns:
A list of different random sequences from the given vocabulary A sequence of different random sequences from the given vocabulary
""" """
def count_placeholder(pattern): def count_placeholder(pattern):
return sum([x[1] is not None for x in string.Formatter().parse(pattern)]) return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
length = count_placeholder(pattern) length = count_placeholder(pattern)
np.random.seed(seed) rng = np.random.RandomState(seed)
vocab_size = len(vocab) vocab_size = len(vocab)
if vocab_size**length <= n: if vocab_size**length <= n:
# Generate all possible sequences of the length # Generate all possible sequences of the length
@ -56,60 +55,71 @@ def generate_random_sequences(vocab: List[str],
idx = np.empty((0, length), dtype=int) idx = np.empty((0, length), dtype=int)
while idx.shape[0] < n: while idx.shape[0] < n:
# Generate a new set of indices # Generate a new set of indices
idx_new = np.random.randint(vocab_size, size=(n, length)) idx_new = rng.randint(vocab_size, size=(n, length))
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
idx = np.unique(idx, axis=0) # Remove duplicated indices idx = np.unique(idx, axis=0) # Remove duplicated indices
idx = idx[:n] idx = idx[:n]
seq = np.array(vocab)[idx] seq = np.array(vocab)[idx]
# Join each row to get the sequence # Join each row to get the sequence
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq) seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
seq = seq[np.random.permutation(n)] seq = seq[rng.permutation(n)]
return list(seq) return list(seq)
@dataclasses.dataclass
class TextSecretProperties:
"""Properties of text secret.
vocab: the vocabulary for the secrets
pattern: the pattern of the secrets
"""
vocab: Sequence[str]
pattern: str
@dataclasses.dataclass @dataclasses.dataclass
class SecretConfig: class SecretConfig:
"""Configuration of secret for secrets sharer. """Configuration of secret for secrets sharer.
vocab: a list, the vocabulary for the secrets num_repetitions: numbers of repetitions for the secrets
pattern: the pattern of the secrets num_secrets_for_repetitions: numbers of secrets to be used for each
num_repetitions: a list, number of repetitions for the secrets number of repetitions
num_secrets_for_repetitions: a list, number of secrets to be used for
different number of repetitions
num_references: number of references sequences, i.e. random sequences that num_references: number of references sequences, i.e. random sequences that
will not be inserted into training data will not be inserted into training data
name: name that identifies the secrets set
properties: any properties of the secret, e.g. the vocabulary, the pattern
""" """
vocab: List[str] num_repetitions: Sequence[int]
pattern: str num_secrets_for_repetitions: Sequence[int]
num_repetitions: List[int]
num_secrets_for_repetitions: List[int]
num_references: int num_references: int
name: str = ''
properties: Any = None
@dataclasses.dataclass @dataclasses.dataclass
class Secrets: class SecretsSet:
"""Secrets for secrets sharer. """A secrets set for secrets sharer.
config: configuration of the secrets config: configuration of the secrets
secrets: a dictionary, key is the number of repetitions, value is a list of secrets: a dictionary, key is the number of repetitions, value is a sequence
different secrets of different secrets
references: a list of references references: a sequence of references
""" """
config: SecretConfig config: SecretConfig
secrets: Dict[int, List[str]] secrets: Dict[int, Sequence[Any]]
references: List[str] references: Sequence[Any]
def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets: def construct_secret(secret_config: SecretConfig,
"""Construct a secret instance. seqs: Sequence[Any]) -> SecretsSet:
"""Constructs a SecretsSet instance given a sequence of samples.
Args: Args:
secret_config: configuration of secret. secret_config: configuration of secret.
seqs: a list of random sequences that will be used for secrets and seqs: a sequence of samples that will be used for secrets and references.
references.
Returns: Returns:
a secret instance. a SecretsSet instance.
""" """
if len(seqs) < sum( if len(seqs) < sum(
secret_config.num_secrets_for_repetitions) + secret_config.num_references: secret_config.num_secrets_for_repetitions) + secret_config.num_references:
@ -120,32 +130,51 @@ def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions): secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
secrets[num_repetition] = seqs[i:i + num_secrets] secrets[num_repetition] = seqs[i:i + num_secrets]
i += num_secrets i += num_secrets
return Secrets( return SecretsSet(
config=secret_config, config=secret_config,
secrets=secrets, secrets=secrets,
references=seqs[-secret_config.num_references:]) references=seqs[-secret_config.num_references:])
def generate_secrets_and_references(secret_configs: List[SecretConfig], def generate_text_secrets_and_references(
seed: int = 0) -> List[Secrets]: secret_configs: Sequence[SecretConfig],
"""Generate a list of secret instances given a list of configurations. seed: int = 0) -> MutableSequence[SecretsSet]:
"""Generates a sequence of text secret sets given a sequence of configurations.
Args: Args:
secret_configs: a list of secret configurations. secret_configs: a sequence of text secret configurations.
seed: random seed. seed: random seed.
Returns: Returns:
A list of secret instances. A sequence of SecretsSet instances.
""" """
secrets = [] secrets_sets = []
for i, secret_config in enumerate(secret_configs): for i, secret_config in enumerate(secret_configs):
n = secret_config.num_references + sum( n = secret_config.num_references + sum(
secret_config.num_secrets_for_repetitions) secret_config.num_secrets_for_repetitions)
seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern, seqs = generate_random_sequences(secret_config.properties.vocab,
n, seed + i) secret_config.properties.pattern, n,
seed + i)
if len(seqs) < n: if len(seqs) < n:
raise ValueError( raise ValueError(
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.' f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
) )
secrets.append(construct_secret(secret_config, seqs)) secrets_sets.append(construct_secret(secret_config, seqs))
return secrets return secrets_sets
def construct_secret_dataset(
secrets_sets: Sequence[SecretsSet]) -> MutableSequence[Any]:
"""Repeats secrets for the required number of times to get a secret dataset.
Args:
secrets_sets: a sequence of secert sets.
Returns:
A sequence of samples.
"""
secrets_dataset = []
for secrets_set in secrets_sets:
for r, seqs in secrets_set.secrets.items():
secrets_dataset += list(seqs) * r
return secrets_dataset

View file

@ -13,10 +13,7 @@
# limitations under the License. # limitations under the License.
from absl.testing import absltest from absl.testing import absltest
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret from tensorflow_privacy.privacy.privacy_tests.secret_sharer import generate_secrets as gs
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig
class UtilsTest(absltest.TestCase): class UtilsTest(absltest.TestCase):
@ -28,14 +25,15 @@ class UtilsTest(absltest.TestCase):
def test_generate_random_sequences(self): def test_generate_random_sequences(self):
"""Test generate_random_sequences.""" """Test generate_random_sequences."""
# Test when n is larger than total number of possible sequences. # Test when n is larger than total number of possible sequences.
seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27) seqs = gs.generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
expected_seqs = [ expected_seqs = [
'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A' 'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A'
] ]
self.assertEqual(seqs, expected_seqs) self.assertEqual(seqs, expected_seqs)
# Test when n is smaller than total number of possible sequences. # Test when n is smaller than total number of possible sequences.
seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9) seqs = gs.generate_random_sequences(
list('01234'), 'prefix {}{}{}?', 8, seed=9)
expected_seqs = [ expected_seqs = [
'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?', 'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?' 'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?'
@ -43,14 +41,14 @@ class UtilsTest(absltest.TestCase):
self.assertEqual(seqs, expected_seqs) self.assertEqual(seqs, expected_seqs)
def test_construct_secret(self): def test_construct_secret(self):
secret_config = SecretConfig( secret_config = gs.SecretConfig(
vocab=None,
pattern='',
num_repetitions=[1, 2, 8], num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1], num_secrets_for_repetitions=[2, 3, 1],
num_references=3) num_references=3,
name='random secrets',
properties=gs.TextSecretProperties(vocab=None, pattern=''))
seqs = list('0123456789') seqs = list('0123456789')
secrets = construct_secret(secret_config, seqs) secrets = gs.construct_secret(secret_config, seqs)
self.assertEqual(secrets.config, secret_config) self.assertEqual(secrets.config, secret_config)
self.assertDictEqual(secrets.secrets, { self.assertDictEqual(secrets.secrets, {
1: ['0', '1'], 1: ['0', '1'],
@ -61,24 +59,29 @@ class UtilsTest(absltest.TestCase):
# Test when the number of elements in seqs is not enough. # Test when the number of elements in seqs is not enough.
seqs = list('01234567') seqs = list('01234567')
self.assertRaises(ValueError, construct_secret, secret_config, seqs) self.assertRaises(ValueError, gs.construct_secret, secret_config, seqs)
def test_generate_secrets_and_references(self): def test_generate_secrets_and_references(self):
secret_configs = [ secret_configs = [
SecretConfig( gs.SecretConfig(
vocab=['w1', 'w2', 'w3'],
pattern='{} {} suf',
num_repetitions=[1, 12], num_repetitions=[1, 12],
num_secrets_for_repetitions=[2, 1], num_secrets_for_repetitions=[2, 1],
num_references=3), num_references=3,
SecretConfig( name='secret1',
vocab=['W 1', 'W 2', 'W 3'], properties=gs.TextSecretProperties(
pattern='{}-{}', vocab=['w1', 'w2', 'w3'], pattern='{} {} suf'),
),
gs.SecretConfig(
num_repetitions=[1, 2, 8], num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1], num_secrets_for_repetitions=[2, 3, 1],
num_references=3) num_references=3,
name='secert2',
properties=gs.TextSecretProperties(
vocab=['W 1', 'W 2', 'W 3'],
pattern='{}-{}',
))
] ]
secrets = generate_secrets_and_references(secret_configs, seed=27) secrets = gs.generate_text_secrets_and_references(secret_configs, seed=27)
self.assertEqual(secrets[0].config, secret_configs[0]) self.assertEqual(secrets[0].config, secret_configs[0])
self.assertDictEqual(secrets[0].secrets, { self.assertDictEqual(secrets[0].secrets, {
1: ['w3 w2 suf', 'w2 w1 suf'], 1: ['w3 w2 suf', 'w2 w1 suf'],

View file

@ -115,7 +115,7 @@
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from official.utils.misc import keras_utils\n", "from official.utils.misc import keras_utils\n",
"\n", "\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, generate_secrets_and_references, construct_secret\n", "from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, TextSecretProperties, generate_text_secrets_and_references, construct_secret\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation" "from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
] ]
}, },
@ -430,10 +430,10 @@
"num_repetitions = [1, 10, 100]\n", "num_repetitions = [1, 10, 100]\n",
"num_secrets_for_repetitions = [20] * len(num_repetitions)\n", "num_secrets_for_repetitions = [20] * len(num_repetitions)\n",
"num_references = 65536\n", "num_references = 65536\n",
"secret_configs = [SecretConfig(vocab, pattern, num_repetitions,\n", "secret_configs = [SecretConfig(num_repetitions, num_secrets_for_repetitions, num_references,\n",
" num_secrets_for_repetitions, num_references)\n", " properties=TextSecretProperties(vocab, pattern))\n",
" for vocab, pattern in zip(vocabs, patterns)]\n", " for vocab, pattern in zip(vocabs, patterns)]\n",
"secrets = generate_secrets_and_references(secret_configs)\n", "secrets = generate_text_secrets_and_references(secret_configs)\n",
"\n", "\n",
"# Let's look at the variable \"secrets\"\n", "# Let's look at the variable \"secrets\"\n",
"print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n", "print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",