In secret generation for secret sharer, use np.random.RandomState. Restructure generate_secrets.
PiperOrigin-RevId: 430082580
This commit is contained in:
parent
89de03e0db
commit
04dd758c8a
3 changed files with 99 additions and 67 deletions
|
@ -16,33 +16,32 @@
|
|||
import dataclasses
|
||||
import itertools
|
||||
import string
|
||||
from typing import Dict, List
|
||||
|
||||
from typing import Any, Dict, MutableSequence, Sequence
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_random_sequences(vocab: List[str],
|
||||
def generate_random_sequences(vocab: Sequence[str],
|
||||
pattern: str,
|
||||
n: int,
|
||||
seed: int = 1) -> List[str]:
|
||||
"""Generate random sequences.
|
||||
seed: int = 1) -> MutableSequence[str]:
|
||||
"""Generates random sequences.
|
||||
|
||||
Args:
|
||||
vocab: a list, the vocabulary for the sequences
|
||||
vocab: the vocabulary for the sequences
|
||||
pattern: the pattern of the sequence. The length of the sequence will be
|
||||
inferred from the pattern.
|
||||
n: number of sequences to generate
|
||||
seed: random seed for numpy.random
|
||||
|
||||
Returns:
|
||||
A list of different random sequences from the given vocabulary
|
||||
A sequence of different random sequences from the given vocabulary
|
||||
"""
|
||||
|
||||
def count_placeholder(pattern):
|
||||
return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
|
||||
|
||||
length = count_placeholder(pattern)
|
||||
np.random.seed(seed)
|
||||
rng = np.random.RandomState(seed)
|
||||
vocab_size = len(vocab)
|
||||
if vocab_size**length <= n:
|
||||
# Generate all possible sequences of the length
|
||||
|
@ -56,60 +55,71 @@ def generate_random_sequences(vocab: List[str],
|
|||
idx = np.empty((0, length), dtype=int)
|
||||
while idx.shape[0] < n:
|
||||
# Generate a new set of indices
|
||||
idx_new = np.random.randint(vocab_size, size=(n, length))
|
||||
idx_new = rng.randint(vocab_size, size=(n, length))
|
||||
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
|
||||
idx = np.unique(idx, axis=0) # Remove duplicated indices
|
||||
idx = idx[:n]
|
||||
seq = np.array(vocab)[idx]
|
||||
# Join each row to get the sequence
|
||||
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
|
||||
seq = seq[np.random.permutation(n)]
|
||||
seq = seq[rng.permutation(n)]
|
||||
return list(seq)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TextSecretProperties:
|
||||
"""Properties of text secret.
|
||||
|
||||
vocab: the vocabulary for the secrets
|
||||
pattern: the pattern of the secrets
|
||||
"""
|
||||
vocab: Sequence[str]
|
||||
pattern: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SecretConfig:
|
||||
"""Configuration of secret for secrets sharer.
|
||||
|
||||
vocab: a list, the vocabulary for the secrets
|
||||
pattern: the pattern of the secrets
|
||||
num_repetitions: a list, number of repetitions for the secrets
|
||||
num_secrets_for_repetitions: a list, number of secrets to be used for
|
||||
different number of repetitions
|
||||
num_repetitions: numbers of repetitions for the secrets
|
||||
num_secrets_for_repetitions: numbers of secrets to be used for each
|
||||
number of repetitions
|
||||
num_references: number of references sequences, i.e. random sequences that
|
||||
will not be inserted into training data
|
||||
name: name that identifies the secrets set
|
||||
properties: any properties of the secret, e.g. the vocabulary, the pattern
|
||||
"""
|
||||
vocab: List[str]
|
||||
pattern: str
|
||||
num_repetitions: List[int]
|
||||
num_secrets_for_repetitions: List[int]
|
||||
num_repetitions: Sequence[int]
|
||||
num_secrets_for_repetitions: Sequence[int]
|
||||
num_references: int
|
||||
name: str = ''
|
||||
properties: Any = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Secrets:
|
||||
"""Secrets for secrets sharer.
|
||||
class SecretsSet:
|
||||
"""A secrets set for secrets sharer.
|
||||
|
||||
config: configuration of the secrets
|
||||
secrets: a dictionary, key is the number of repetitions, value is a list of
|
||||
different secrets
|
||||
references: a list of references
|
||||
secrets: a dictionary, key is the number of repetitions, value is a sequence
|
||||
of different secrets
|
||||
references: a sequence of references
|
||||
"""
|
||||
config: SecretConfig
|
||||
secrets: Dict[int, List[str]]
|
||||
references: List[str]
|
||||
secrets: Dict[int, Sequence[Any]]
|
||||
references: Sequence[Any]
|
||||
|
||||
|
||||
def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
|
||||
"""Construct a secret instance.
|
||||
def construct_secret(secret_config: SecretConfig,
|
||||
seqs: Sequence[Any]) -> SecretsSet:
|
||||
"""Constructs a SecretsSet instance given a sequence of samples.
|
||||
|
||||
Args:
|
||||
secret_config: configuration of secret.
|
||||
seqs: a list of random sequences that will be used for secrets and
|
||||
references.
|
||||
seqs: a sequence of samples that will be used for secrets and references.
|
||||
|
||||
Returns:
|
||||
a secret instance.
|
||||
a SecretsSet instance.
|
||||
"""
|
||||
if len(seqs) < sum(
|
||||
secret_config.num_secrets_for_repetitions) + secret_config.num_references:
|
||||
|
@ -120,32 +130,51 @@ def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
|
|||
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
|
||||
secrets[num_repetition] = seqs[i:i + num_secrets]
|
||||
i += num_secrets
|
||||
return Secrets(
|
||||
return SecretsSet(
|
||||
config=secret_config,
|
||||
secrets=secrets,
|
||||
references=seqs[-secret_config.num_references:])
|
||||
|
||||
|
||||
def generate_secrets_and_references(secret_configs: List[SecretConfig],
|
||||
seed: int = 0) -> List[Secrets]:
|
||||
"""Generate a list of secret instances given a list of configurations.
|
||||
def generate_text_secrets_and_references(
|
||||
secret_configs: Sequence[SecretConfig],
|
||||
seed: int = 0) -> MutableSequence[SecretsSet]:
|
||||
"""Generates a sequence of text secret sets given a sequence of configurations.
|
||||
|
||||
Args:
|
||||
secret_configs: a list of secret configurations.
|
||||
secret_configs: a sequence of text secret configurations.
|
||||
seed: random seed.
|
||||
|
||||
Returns:
|
||||
A list of secret instances.
|
||||
A sequence of SecretsSet instances.
|
||||
"""
|
||||
secrets = []
|
||||
secrets_sets = []
|
||||
for i, secret_config in enumerate(secret_configs):
|
||||
n = secret_config.num_references + sum(
|
||||
secret_config.num_secrets_for_repetitions)
|
||||
seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern,
|
||||
n, seed + i)
|
||||
seqs = generate_random_sequences(secret_config.properties.vocab,
|
||||
secret_config.properties.pattern, n,
|
||||
seed + i)
|
||||
if len(seqs) < n:
|
||||
raise ValueError(
|
||||
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
|
||||
)
|
||||
secrets.append(construct_secret(secret_config, seqs))
|
||||
return secrets
|
||||
secrets_sets.append(construct_secret(secret_config, seqs))
|
||||
return secrets_sets
|
||||
|
||||
|
||||
def construct_secret_dataset(
|
||||
secrets_sets: Sequence[SecretsSet]) -> MutableSequence[Any]:
|
||||
"""Repeats secrets for the required number of times to get a secret dataset.
|
||||
|
||||
Args:
|
||||
secrets_sets: a sequence of secert sets.
|
||||
|
||||
Returns:
|
||||
A sequence of samples.
|
||||
"""
|
||||
secrets_dataset = []
|
||||
for secrets_set in secrets_sets:
|
||||
for r, seqs in secrets_set.secrets.items():
|
||||
secrets_dataset += list(seqs) * r
|
||||
return secrets_dataset
|
||||
|
|
|
@ -13,10 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret
|
||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences
|
||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references
|
||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig
|
||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer import generate_secrets as gs
|
||||
|
||||
|
||||
class UtilsTest(absltest.TestCase):
|
||||
|
@ -28,14 +25,15 @@ class UtilsTest(absltest.TestCase):
|
|||
def test_generate_random_sequences(self):
|
||||
"""Test generate_random_sequences."""
|
||||
# Test when n is larger than total number of possible sequences.
|
||||
seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
|
||||
seqs = gs.generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
|
||||
expected_seqs = [
|
||||
'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A'
|
||||
]
|
||||
self.assertEqual(seqs, expected_seqs)
|
||||
|
||||
# Test when n is smaller than total number of possible sequences.
|
||||
seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9)
|
||||
seqs = gs.generate_random_sequences(
|
||||
list('01234'), 'prefix {}{}{}?', 8, seed=9)
|
||||
expected_seqs = [
|
||||
'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
|
||||
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?'
|
||||
|
@ -43,14 +41,14 @@ class UtilsTest(absltest.TestCase):
|
|||
self.assertEqual(seqs, expected_seqs)
|
||||
|
||||
def test_construct_secret(self):
|
||||
secret_config = SecretConfig(
|
||||
vocab=None,
|
||||
pattern='',
|
||||
secret_config = gs.SecretConfig(
|
||||
num_repetitions=[1, 2, 8],
|
||||
num_secrets_for_repetitions=[2, 3, 1],
|
||||
num_references=3)
|
||||
num_references=3,
|
||||
name='random secrets',
|
||||
properties=gs.TextSecretProperties(vocab=None, pattern=''))
|
||||
seqs = list('0123456789')
|
||||
secrets = construct_secret(secret_config, seqs)
|
||||
secrets = gs.construct_secret(secret_config, seqs)
|
||||
self.assertEqual(secrets.config, secret_config)
|
||||
self.assertDictEqual(secrets.secrets, {
|
||||
1: ['0', '1'],
|
||||
|
@ -61,24 +59,29 @@ class UtilsTest(absltest.TestCase):
|
|||
|
||||
# Test when the number of elements in seqs is not enough.
|
||||
seqs = list('01234567')
|
||||
self.assertRaises(ValueError, construct_secret, secret_config, seqs)
|
||||
self.assertRaises(ValueError, gs.construct_secret, secret_config, seqs)
|
||||
|
||||
def test_generate_secrets_and_references(self):
|
||||
secret_configs = [
|
||||
SecretConfig(
|
||||
vocab=['w1', 'w2', 'w3'],
|
||||
pattern='{} {} suf',
|
||||
gs.SecretConfig(
|
||||
num_repetitions=[1, 12],
|
||||
num_secrets_for_repetitions=[2, 1],
|
||||
num_references=3),
|
||||
SecretConfig(
|
||||
vocab=['W 1', 'W 2', 'W 3'],
|
||||
pattern='{}-{}',
|
||||
num_references=3,
|
||||
name='secret1',
|
||||
properties=gs.TextSecretProperties(
|
||||
vocab=['w1', 'w2', 'w3'], pattern='{} {} suf'),
|
||||
),
|
||||
gs.SecretConfig(
|
||||
num_repetitions=[1, 2, 8],
|
||||
num_secrets_for_repetitions=[2, 3, 1],
|
||||
num_references=3)
|
||||
num_references=3,
|
||||
name='secert2',
|
||||
properties=gs.TextSecretProperties(
|
||||
vocab=['W 1', 'W 2', 'W 3'],
|
||||
pattern='{}-{}',
|
||||
))
|
||||
]
|
||||
secrets = generate_secrets_and_references(secret_configs, seed=27)
|
||||
secrets = gs.generate_text_secrets_and_references(secret_configs, seed=27)
|
||||
self.assertEqual(secrets[0].config, secret_configs[0])
|
||||
self.assertDictEqual(secrets[0].secrets, {
|
||||
1: ['w3 w2 suf', 'w2 w1 suf'],
|
||||
|
|
|
@ -115,7 +115,7 @@
|
|||
"import tensorflow as tf\n",
|
||||
"from official.utils.misc import keras_utils\n",
|
||||
"\n",
|
||||
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, generate_secrets_and_references, construct_secret\n",
|
||||
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, TextSecretProperties, generate_text_secrets_and_references, construct_secret\n",
|
||||
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
|
||||
]
|
||||
},
|
||||
|
@ -430,10 +430,10 @@
|
|||
"num_repetitions = [1, 10, 100]\n",
|
||||
"num_secrets_for_repetitions = [20] * len(num_repetitions)\n",
|
||||
"num_references = 65536\n",
|
||||
"secret_configs = [SecretConfig(vocab, pattern, num_repetitions,\n",
|
||||
" num_secrets_for_repetitions, num_references)\n",
|
||||
"secret_configs = [SecretConfig(num_repetitions, num_secrets_for_repetitions, num_references,\n",
|
||||
" properties=TextSecretProperties(vocab, pattern))\n",
|
||||
" for vocab, pattern in zip(vocabs, patterns)]\n",
|
||||
"secrets = generate_secrets_and_references(secret_configs)\n",
|
||||
"secrets = generate_text_secrets_and_references(secret_configs)\n",
|
||||
"\n",
|
||||
"# Let's look at the variable \"secrets\"\n",
|
||||
"print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",
|
||||
|
|
Loading…
Reference in a new issue