In secret generation for secret sharer, use np.random.RandomState. Restructure generate_secrets.

PiperOrigin-RevId: 430082580
This commit is contained in:
Shuang Song 2022-02-21 13:53:43 -08:00 committed by A. Unique TensorFlower
parent 89de03e0db
commit 04dd758c8a
3 changed files with 99 additions and 67 deletions

View file

@ -16,33 +16,32 @@
import dataclasses
import itertools
import string
from typing import Dict, List
from typing import Any, Dict, MutableSequence, Sequence
import numpy as np
def generate_random_sequences(vocab: List[str],
def generate_random_sequences(vocab: Sequence[str],
pattern: str,
n: int,
seed: int = 1) -> List[str]:
"""Generate random sequences.
seed: int = 1) -> MutableSequence[str]:
"""Generates random sequences.
Args:
vocab: a list, the vocabulary for the sequences
vocab: the vocabulary for the sequences
pattern: the pattern of the sequence. The length of the sequence will be
inferred from the pattern.
n: number of sequences to generate
seed: random seed for numpy.random
Returns:
A list of different random sequences from the given vocabulary
A sequence of different random sequences from the given vocabulary
"""
def count_placeholder(pattern):
return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
length = count_placeholder(pattern)
np.random.seed(seed)
rng = np.random.RandomState(seed)
vocab_size = len(vocab)
if vocab_size**length <= n:
# Generate all possible sequences of the length
@ -56,60 +55,71 @@ def generate_random_sequences(vocab: List[str],
idx = np.empty((0, length), dtype=int)
while idx.shape[0] < n:
# Generate a new set of indices
idx_new = np.random.randint(vocab_size, size=(n, length))
idx_new = rng.randint(vocab_size, size=(n, length))
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
idx = np.unique(idx, axis=0) # Remove duplicated indices
idx = idx[:n]
seq = np.array(vocab)[idx]
# Join each row to get the sequence
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
seq = seq[np.random.permutation(n)]
seq = seq[rng.permutation(n)]
return list(seq)
@dataclasses.dataclass
class TextSecretProperties:
  """Properties of a text secret.

  Attributes:
    vocab: the vocabulary for the secrets, i.e. the strings that random
      sequences are drawn from.
    pattern: the pattern of the secrets, a format string whose `{}`
      placeholders are filled with entries from `vocab`.
  """
  vocab: Sequence[str]
  pattern: str
@dataclasses.dataclass
class SecretConfig:
"""Configuration of secret for secrets sharer.
vocab: a list, the vocabulary for the secrets
pattern: the pattern of the secrets
num_repetitions: a list, number of repetitions for the secrets
num_secrets_for_repetitions: a list, number of secrets to be used for
different number of repetitions
num_repetitions: numbers of repetitions for the secrets
num_secrets_for_repetitions: numbers of secrets to be used for each
number of repetitions
num_references: number of references sequences, i.e. random sequences that
will not be inserted into training data
name: name that identifies the secrets set
properties: any properties of the secret, e.g. the vocabulary, the pattern
"""
vocab: List[str]
pattern: str
num_repetitions: List[int]
num_secrets_for_repetitions: List[int]
num_repetitions: Sequence[int]
num_secrets_for_repetitions: Sequence[int]
num_references: int
name: str = ''
properties: Any = None
@dataclasses.dataclass
class Secrets:
"""Secrets for secrets sharer.
class SecretsSet:
"""A secrets set for secrets sharer.
config: configuration of the secrets
secrets: a dictionary, key is the number of repetitions, value is a list of
different secrets
references: a list of references
secrets: a dictionary, key is the number of repetitions, value is a sequence
of different secrets
references: a sequence of references
"""
config: SecretConfig
secrets: Dict[int, List[str]]
references: List[str]
secrets: Dict[int, Sequence[Any]]
references: Sequence[Any]
def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
"""Construct a secret instance.
def construct_secret(secret_config: SecretConfig,
seqs: Sequence[Any]) -> SecretsSet:
"""Constructs a SecretsSet instance given a sequence of samples.
Args:
secret_config: configuration of secret.
seqs: a list of random sequences that will be used for secrets and
references.
seqs: a sequence of samples that will be used for secrets and references.
Returns:
a secret instance.
a SecretsSet instance.
"""
if len(seqs) < sum(
secret_config.num_secrets_for_repetitions) + secret_config.num_references:
@ -120,32 +130,51 @@ def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
secrets[num_repetition] = seqs[i:i + num_secrets]
i += num_secrets
return Secrets(
return SecretsSet(
config=secret_config,
secrets=secrets,
references=seqs[-secret_config.num_references:])
def generate_secrets_and_references(secret_configs: List[SecretConfig],
seed: int = 0) -> List[Secrets]:
"""Generate a list of secret instances given a list of configurations.
def generate_text_secrets_and_references(
secret_configs: Sequence[SecretConfig],
seed: int = 0) -> MutableSequence[SecretsSet]:
"""Generates a sequence of text secret sets given a sequence of configurations.
Args:
secret_configs: a list of secret configurations.
secret_configs: a sequence of text secret configurations.
seed: random seed.
Returns:
A list of secret instances.
A sequence of SecretsSet instances.
"""
secrets = []
secrets_sets = []
for i, secret_config in enumerate(secret_configs):
n = secret_config.num_references + sum(
secret_config.num_secrets_for_repetitions)
seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern,
n, seed + i)
seqs = generate_random_sequences(secret_config.properties.vocab,
secret_config.properties.pattern, n,
seed + i)
if len(seqs) < n:
raise ValueError(
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
)
secrets.append(construct_secret(secret_config, seqs))
return secrets
secrets_sets.append(construct_secret(secret_config, seqs))
return secrets_sets
def construct_secret_dataset(
    secrets_sets: Sequence[SecretsSet]) -> MutableSequence[Any]:
  """Repeats secrets for the required number of times to get a secret dataset.

  Args:
    secrets_sets: a sequence of secret sets.

  Returns:
    A sequence of samples. For every secrets set, and for every number of
    repetitions in its `secrets` dictionary, the corresponding group of
    secrets is appended that many times (the whole group is repeated, so the
    order is s1, s2, ..., s1, s2, ...).
  """
  secrets_dataset = []
  for secrets_set in secrets_sets:
    # `secrets` maps a number of repetitions to the secrets that use it.
    for r, seqs in secrets_set.secrets.items():
      # Copy into a list first so that any sequence type can be repeated and
      # concatenated onto the output list.
      secrets_dataset += list(seqs) * r
  return secrets_dataset

View file

@ -13,10 +13,7 @@
# limitations under the License.
from absl.testing import absltest
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig
from tensorflow_privacy.privacy.privacy_tests.secret_sharer import generate_secrets as gs
class UtilsTest(absltest.TestCase):
@ -28,14 +25,15 @@ class UtilsTest(absltest.TestCase):
def test_generate_random_sequences(self):
"""Test generate_random_sequences."""
# Test when n is larger than total number of possible sequences.
seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
seqs = gs.generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
expected_seqs = [
'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A'
]
self.assertEqual(seqs, expected_seqs)
# Test when n is smaller than total number of possible sequences.
seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9)
seqs = gs.generate_random_sequences(
list('01234'), 'prefix {}{}{}?', 8, seed=9)
expected_seqs = [
'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?'
@ -43,14 +41,14 @@ class UtilsTest(absltest.TestCase):
self.assertEqual(seqs, expected_seqs)
def test_construct_secret(self):
secret_config = SecretConfig(
vocab=None,
pattern='',
secret_config = gs.SecretConfig(
num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1],
num_references=3)
num_references=3,
name='random secrets',
properties=gs.TextSecretProperties(vocab=None, pattern=''))
seqs = list('0123456789')
secrets = construct_secret(secret_config, seqs)
secrets = gs.construct_secret(secret_config, seqs)
self.assertEqual(secrets.config, secret_config)
self.assertDictEqual(secrets.secrets, {
1: ['0', '1'],
@ -61,24 +59,29 @@ class UtilsTest(absltest.TestCase):
# Test when the number of elements in seqs is not enough.
seqs = list('01234567')
self.assertRaises(ValueError, construct_secret, secret_config, seqs)
self.assertRaises(ValueError, gs.construct_secret, secret_config, seqs)
def test_generate_secrets_and_references(self):
secret_configs = [
SecretConfig(
vocab=['w1', 'w2', 'w3'],
pattern='{} {} suf',
gs.SecretConfig(
num_repetitions=[1, 12],
num_secrets_for_repetitions=[2, 1],
num_references=3),
SecretConfig(
vocab=['W 1', 'W 2', 'W 3'],
pattern='{}-{}',
num_references=3,
name='secret1',
properties=gs.TextSecretProperties(
vocab=['w1', 'w2', 'w3'], pattern='{} {} suf'),
),
gs.SecretConfig(
num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1],
num_references=3)
num_references=3,
name='secert2',
properties=gs.TextSecretProperties(
vocab=['W 1', 'W 2', 'W 3'],
pattern='{}-{}',
))
]
secrets = generate_secrets_and_references(secret_configs, seed=27)
secrets = gs.generate_text_secrets_and_references(secret_configs, seed=27)
self.assertEqual(secrets[0].config, secret_configs[0])
self.assertDictEqual(secrets[0].secrets, {
1: ['w3 w2 suf', 'w2 w1 suf'],

View file

@ -115,7 +115,7 @@
"import tensorflow as tf\n",
"from official.utils.misc import keras_utils\n",
"\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, generate_secrets_and_references, construct_secret\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, TextSecretProperties, generate_text_secrets_and_references, construct_secret\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
]
},
@ -430,10 +430,10 @@
"num_repetitions = [1, 10, 100]\n",
"num_secrets_for_repetitions = [20] * len(num_repetitions)\n",
"num_references = 65536\n",
"secret_configs = [SecretConfig(vocab, pattern, num_repetitions,\n",
" num_secrets_for_repetitions, num_references)\n",
"secret_configs = [SecretConfig(num_repetitions, num_secrets_for_repetitions, num_references,\n",
" properties=TextSecretProperties(vocab, pattern))\n",
" for vocab, pattern in zip(vocabs, patterns)]\n",
"secrets = generate_secrets_and_references(secret_configs)\n",
"secrets = generate_text_secrets_and_references(secret_configs)\n",
"\n",
"# Let's look at the variable \"secrets\"\n",
"print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",