In secret generation for secret sharer, use np.random.RandomState. Restructure generate_secrets.
PiperOrigin-RevId: 430082580
This commit is contained in:
parent
89de03e0db
commit
04dd758c8a
3 changed files with 99 additions and 67 deletions
|
@ -16,33 +16,32 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import itertools
|
import itertools
|
||||||
import string
|
import string
|
||||||
from typing import Dict, List
|
from typing import Any, Dict, MutableSequence, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def generate_random_sequences(vocab: List[str],
|
def generate_random_sequences(vocab: Sequence[str],
|
||||||
pattern: str,
|
pattern: str,
|
||||||
n: int,
|
n: int,
|
||||||
seed: int = 1) -> List[str]:
|
seed: int = 1) -> MutableSequence[str]:
|
||||||
"""Generate random sequences.
|
"""Generates random sequences.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab: a list, the vocabulary for the sequences
|
vocab: the vocabulary for the sequences
|
||||||
pattern: the pattern of the sequence. The length of the sequence will be
|
pattern: the pattern of the sequence. The length of the sequence will be
|
||||||
inferred from the pattern.
|
inferred from the pattern.
|
||||||
n: number of sequences to generate
|
n: number of sequences to generate
|
||||||
seed: random seed for numpy.random
|
seed: random seed for numpy.random
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of different random sequences from the given vocabulary
|
A sequence of different random sequences from the given vocabulary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def count_placeholder(pattern):
|
def count_placeholder(pattern):
|
||||||
return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
|
return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
|
||||||
|
|
||||||
length = count_placeholder(pattern)
|
length = count_placeholder(pattern)
|
||||||
np.random.seed(seed)
|
rng = np.random.RandomState(seed)
|
||||||
vocab_size = len(vocab)
|
vocab_size = len(vocab)
|
||||||
if vocab_size**length <= n:
|
if vocab_size**length <= n:
|
||||||
# Generate all possible sequences of the length
|
# Generate all possible sequences of the length
|
||||||
|
@ -56,60 +55,71 @@ def generate_random_sequences(vocab: List[str],
|
||||||
idx = np.empty((0, length), dtype=int)
|
idx = np.empty((0, length), dtype=int)
|
||||||
while idx.shape[0] < n:
|
while idx.shape[0] < n:
|
||||||
# Generate a new set of indices
|
# Generate a new set of indices
|
||||||
idx_new = np.random.randint(vocab_size, size=(n, length))
|
idx_new = rng.randint(vocab_size, size=(n, length))
|
||||||
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
|
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
|
||||||
idx = np.unique(idx, axis=0) # Remove duplicated indices
|
idx = np.unique(idx, axis=0) # Remove duplicated indices
|
||||||
idx = idx[:n]
|
idx = idx[:n]
|
||||||
seq = np.array(vocab)[idx]
|
seq = np.array(vocab)[idx]
|
||||||
# Join each row to get the sequence
|
# Join each row to get the sequence
|
||||||
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
|
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
|
||||||
seq = seq[np.random.permutation(n)]
|
seq = seq[rng.permutation(n)]
|
||||||
return list(seq)
|
return list(seq)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TextSecretProperties:
|
||||||
|
"""Properties of text secret.
|
||||||
|
|
||||||
|
vocab: the vocabulary for the secrets
|
||||||
|
pattern: the pattern of the secrets
|
||||||
|
"""
|
||||||
|
vocab: Sequence[str]
|
||||||
|
pattern: str
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SecretConfig:
|
class SecretConfig:
|
||||||
"""Configuration of secret for secrets sharer.
|
"""Configuration of secret for secrets sharer.
|
||||||
|
|
||||||
vocab: a list, the vocabulary for the secrets
|
num_repetitions: numbers of repetitions for the secrets
|
||||||
pattern: the pattern of the secrets
|
num_secrets_for_repetitions: numbers of secrets to be used for each
|
||||||
num_repetitions: a list, number of repetitions for the secrets
|
number of repetitions
|
||||||
num_secrets_for_repetitions: a list, number of secrets to be used for
|
|
||||||
different number of repetitions
|
|
||||||
num_references: number of references sequences, i.e. random sequences that
|
num_references: number of references sequences, i.e. random sequences that
|
||||||
will not be inserted into training data
|
will not be inserted into training data
|
||||||
|
name: name that identifies the secrets set
|
||||||
|
properties: any properties of the secret, e.g. the vocabulary, the pattern
|
||||||
"""
|
"""
|
||||||
vocab: List[str]
|
num_repetitions: Sequence[int]
|
||||||
pattern: str
|
num_secrets_for_repetitions: Sequence[int]
|
||||||
num_repetitions: List[int]
|
|
||||||
num_secrets_for_repetitions: List[int]
|
|
||||||
num_references: int
|
num_references: int
|
||||||
|
name: str = ''
|
||||||
|
properties: Any = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Secrets:
|
class SecretsSet:
|
||||||
"""Secrets for secrets sharer.
|
"""A secrets set for secrets sharer.
|
||||||
|
|
||||||
config: configuration of the secrets
|
config: configuration of the secrets
|
||||||
secrets: a dictionary, key is the number of repetitions, value is a list of
|
secrets: a dictionary, key is the number of repetitions, value is a sequence
|
||||||
different secrets
|
of different secrets
|
||||||
references: a list of references
|
references: a sequence of references
|
||||||
"""
|
"""
|
||||||
config: SecretConfig
|
config: SecretConfig
|
||||||
secrets: Dict[int, List[str]]
|
secrets: Dict[int, Sequence[Any]]
|
||||||
references: List[str]
|
references: Sequence[Any]
|
||||||
|
|
||||||
|
|
||||||
def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
|
def construct_secret(secret_config: SecretConfig,
|
||||||
"""Construct a secret instance.
|
seqs: Sequence[Any]) -> SecretsSet:
|
||||||
|
"""Constructs a SecretsSet instance given a sequence of samples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
secret_config: configuration of secret.
|
secret_config: configuration of secret.
|
||||||
seqs: a list of random sequences that will be used for secrets and
|
seqs: a sequence of samples that will be used for secrets and references.
|
||||||
references.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a secret instance.
|
a SecretsSet instance.
|
||||||
"""
|
"""
|
||||||
if len(seqs) < sum(
|
if len(seqs) < sum(
|
||||||
secret_config.num_secrets_for_repetitions) + secret_config.num_references:
|
secret_config.num_secrets_for_repetitions) + secret_config.num_references:
|
||||||
|
@ -120,32 +130,51 @@ def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
|
||||||
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
|
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
|
||||||
secrets[num_repetition] = seqs[i:i + num_secrets]
|
secrets[num_repetition] = seqs[i:i + num_secrets]
|
||||||
i += num_secrets
|
i += num_secrets
|
||||||
return Secrets(
|
return SecretsSet(
|
||||||
config=secret_config,
|
config=secret_config,
|
||||||
secrets=secrets,
|
secrets=secrets,
|
||||||
references=seqs[-secret_config.num_references:])
|
references=seqs[-secret_config.num_references:])
|
||||||
|
|
||||||
|
|
||||||
def generate_secrets_and_references(secret_configs: List[SecretConfig],
|
def generate_text_secrets_and_references(
|
||||||
seed: int = 0) -> List[Secrets]:
|
secret_configs: Sequence[SecretConfig],
|
||||||
"""Generate a list of secret instances given a list of configurations.
|
seed: int = 0) -> MutableSequence[SecretsSet]:
|
||||||
|
"""Generates a sequence of text secret sets given a sequence of configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
secret_configs: a list of secret configurations.
|
secret_configs: a sequence of text secret configurations.
|
||||||
seed: random seed.
|
seed: random seed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of secret instances.
|
A sequence of SecretsSet instances.
|
||||||
"""
|
"""
|
||||||
secrets = []
|
secrets_sets = []
|
||||||
for i, secret_config in enumerate(secret_configs):
|
for i, secret_config in enumerate(secret_configs):
|
||||||
n = secret_config.num_references + sum(
|
n = secret_config.num_references + sum(
|
||||||
secret_config.num_secrets_for_repetitions)
|
secret_config.num_secrets_for_repetitions)
|
||||||
seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern,
|
seqs = generate_random_sequences(secret_config.properties.vocab,
|
||||||
n, seed + i)
|
secret_config.properties.pattern, n,
|
||||||
|
seed + i)
|
||||||
if len(seqs) < n:
|
if len(seqs) < n:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
|
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
|
||||||
)
|
)
|
||||||
secrets.append(construct_secret(secret_config, seqs))
|
secrets_sets.append(construct_secret(secret_config, seqs))
|
||||||
return secrets
|
return secrets_sets
|
||||||
|
|
||||||
|
|
||||||
|
def construct_secret_dataset(
|
||||||
|
secrets_sets: Sequence[SecretsSet]) -> MutableSequence[Any]:
|
||||||
|
"""Repeats secrets for the required number of times to get a secret dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secrets_sets: a sequence of secert sets.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of samples.
|
||||||
|
"""
|
||||||
|
secrets_dataset = []
|
||||||
|
for secrets_set in secrets_sets:
|
||||||
|
for r, seqs in secrets_set.secrets.items():
|
||||||
|
secrets_dataset += list(seqs) * r
|
||||||
|
return secrets_dataset
|
||||||
|
|
|
@ -13,10 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret
|
from tensorflow_privacy.privacy.privacy_tests.secret_sharer import generate_secrets as gs
|
||||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig
|
|
||||||
|
|
||||||
|
|
||||||
class UtilsTest(absltest.TestCase):
|
class UtilsTest(absltest.TestCase):
|
||||||
|
@ -28,14 +25,15 @@ class UtilsTest(absltest.TestCase):
|
||||||
def test_generate_random_sequences(self):
|
def test_generate_random_sequences(self):
|
||||||
"""Test generate_random_sequences."""
|
"""Test generate_random_sequences."""
|
||||||
# Test when n is larger than total number of possible sequences.
|
# Test when n is larger than total number of possible sequences.
|
||||||
seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
|
seqs = gs.generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
|
||||||
expected_seqs = [
|
expected_seqs = [
|
||||||
'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A'
|
'A+c', 'c+c', 'b+b', 'A+b', 'b+c', 'c+A', 'c+b', 'A+A', 'b+A'
|
||||||
]
|
]
|
||||||
self.assertEqual(seqs, expected_seqs)
|
self.assertEqual(seqs, expected_seqs)
|
||||||
|
|
||||||
# Test when n is smaller than total number of possible sequences.
|
# Test when n is smaller than total number of possible sequences.
|
||||||
seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9)
|
seqs = gs.generate_random_sequences(
|
||||||
|
list('01234'), 'prefix {}{}{}?', 8, seed=9)
|
||||||
expected_seqs = [
|
expected_seqs = [
|
||||||
'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
|
'prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
|
||||||
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?'
|
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?'
|
||||||
|
@ -43,14 +41,14 @@ class UtilsTest(absltest.TestCase):
|
||||||
self.assertEqual(seqs, expected_seqs)
|
self.assertEqual(seqs, expected_seqs)
|
||||||
|
|
||||||
def test_construct_secret(self):
|
def test_construct_secret(self):
|
||||||
secret_config = SecretConfig(
|
secret_config = gs.SecretConfig(
|
||||||
vocab=None,
|
|
||||||
pattern='',
|
|
||||||
num_repetitions=[1, 2, 8],
|
num_repetitions=[1, 2, 8],
|
||||||
num_secrets_for_repetitions=[2, 3, 1],
|
num_secrets_for_repetitions=[2, 3, 1],
|
||||||
num_references=3)
|
num_references=3,
|
||||||
|
name='random secrets',
|
||||||
|
properties=gs.TextSecretProperties(vocab=None, pattern=''))
|
||||||
seqs = list('0123456789')
|
seqs = list('0123456789')
|
||||||
secrets = construct_secret(secret_config, seqs)
|
secrets = gs.construct_secret(secret_config, seqs)
|
||||||
self.assertEqual(secrets.config, secret_config)
|
self.assertEqual(secrets.config, secret_config)
|
||||||
self.assertDictEqual(secrets.secrets, {
|
self.assertDictEqual(secrets.secrets, {
|
||||||
1: ['0', '1'],
|
1: ['0', '1'],
|
||||||
|
@ -61,24 +59,29 @@ class UtilsTest(absltest.TestCase):
|
||||||
|
|
||||||
# Test when the number of elements in seqs is not enough.
|
# Test when the number of elements in seqs is not enough.
|
||||||
seqs = list('01234567')
|
seqs = list('01234567')
|
||||||
self.assertRaises(ValueError, construct_secret, secret_config, seqs)
|
self.assertRaises(ValueError, gs.construct_secret, secret_config, seqs)
|
||||||
|
|
||||||
def test_generate_secrets_and_references(self):
|
def test_generate_secrets_and_references(self):
|
||||||
secret_configs = [
|
secret_configs = [
|
||||||
SecretConfig(
|
gs.SecretConfig(
|
||||||
vocab=['w1', 'w2', 'w3'],
|
|
||||||
pattern='{} {} suf',
|
|
||||||
num_repetitions=[1, 12],
|
num_repetitions=[1, 12],
|
||||||
num_secrets_for_repetitions=[2, 1],
|
num_secrets_for_repetitions=[2, 1],
|
||||||
num_references=3),
|
num_references=3,
|
||||||
SecretConfig(
|
name='secret1',
|
||||||
vocab=['W 1', 'W 2', 'W 3'],
|
properties=gs.TextSecretProperties(
|
||||||
pattern='{}-{}',
|
vocab=['w1', 'w2', 'w3'], pattern='{} {} suf'),
|
||||||
|
),
|
||||||
|
gs.SecretConfig(
|
||||||
num_repetitions=[1, 2, 8],
|
num_repetitions=[1, 2, 8],
|
||||||
num_secrets_for_repetitions=[2, 3, 1],
|
num_secrets_for_repetitions=[2, 3, 1],
|
||||||
num_references=3)
|
num_references=3,
|
||||||
|
name='secert2',
|
||||||
|
properties=gs.TextSecretProperties(
|
||||||
|
vocab=['W 1', 'W 2', 'W 3'],
|
||||||
|
pattern='{}-{}',
|
||||||
|
))
|
||||||
]
|
]
|
||||||
secrets = generate_secrets_and_references(secret_configs, seed=27)
|
secrets = gs.generate_text_secrets_and_references(secret_configs, seed=27)
|
||||||
self.assertEqual(secrets[0].config, secret_configs[0])
|
self.assertEqual(secrets[0].config, secret_configs[0])
|
||||||
self.assertDictEqual(secrets[0].secrets, {
|
self.assertDictEqual(secrets[0].secrets, {
|
||||||
1: ['w3 w2 suf', 'w2 w1 suf'],
|
1: ['w3 w2 suf', 'w2 w1 suf'],
|
||||||
|
|
|
@ -115,7 +115,7 @@
|
||||||
"import tensorflow as tf\n",
|
"import tensorflow as tf\n",
|
||||||
"from official.utils.misc import keras_utils\n",
|
"from official.utils.misc import keras_utils\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, generate_secrets_and_references, construct_secret\n",
|
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig, TextSecretProperties, generate_text_secrets_and_references, construct_secret\n",
|
||||||
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
|
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -430,10 +430,10 @@
|
||||||
"num_repetitions = [1, 10, 100]\n",
|
"num_repetitions = [1, 10, 100]\n",
|
||||||
"num_secrets_for_repetitions = [20] * len(num_repetitions)\n",
|
"num_secrets_for_repetitions = [20] * len(num_repetitions)\n",
|
||||||
"num_references = 65536\n",
|
"num_references = 65536\n",
|
||||||
"secret_configs = [SecretConfig(vocab, pattern, num_repetitions,\n",
|
"secret_configs = [SecretConfig(num_repetitions, num_secrets_for_repetitions, num_references,\n",
|
||||||
" num_secrets_for_repetitions, num_references)\n",
|
" properties=TextSecretProperties(vocab, pattern))\n",
|
||||||
" for vocab, pattern in zip(vocabs, patterns)]\n",
|
" for vocab, pattern in zip(vocabs, patterns)]\n",
|
||||||
"secrets = generate_secrets_and_references(secret_configs)\n",
|
"secrets = generate_text_secrets_and_references(secret_configs)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Let's look at the variable \"secrets\"\n",
|
"# Let's look at the variable \"secrets\"\n",
|
||||||
"print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",
|
"print(f'\"secrets\" is a list and the length is {len(secrets)} because we have four sets of secrets.')\n",
|
||||||
|
|
Loading…
Reference in a new issue