Internal change.

PiperOrigin-RevId: 382171367
This commit is contained in:
Shuang Song 2021-06-29 14:53:54 -07:00 committed by A. Unique TensorFlower
parent 94f03d09f3
commit 0caa10f674
5 changed files with 423 additions and 0 deletions

View file

@ -0,0 +1,37 @@
# Secret Sharer Attack
A good privacy-preserving model learns from the training data, but
doesn't memorize it.
This folder contains codes for conducting the Secret Sharer attack from [this paper](https://arxiv.org/abs/1802.08232).
It is a method to test if a machine learning model memorizes its training data.
The high level idea is to insert some random sequences as “secrets” into the
training data, and then measure if the model has memorized those secrets.
If there is significant memorization, it means that there can be potential
privacy risk.
## How to Use
### Overview of the files
- `generate_secrets.py` contains the code for generating secrets.
- `exposures.py` contains code for evaluating exposures.
- `secret_sharer_example.ipynb` is an example (character-level LSTM) for using
the above code to conduct secret sharer attack.
### Contact / Feedback
Fill out this
[Google form](https://docs.google.com/forms/d/1DPwr3_OfMcqAOA6sdelTVjIZhKxMZkXvs94z16UCDa4/edit)
or reach out to us at tf-privacy@google.com and let us know how youre using
this module. Were keen on hearing your stories, feedback, and suggestions!
## Contributing
If you wish to add novel attacks to the attack library, please check our
[guidelines](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/CONTRIBUTING.md).
## Copyright
Copyright 2021 - Google LLC

View file

@ -0,0 +1,80 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Measuring exposure for secret sharer attack."""
from typing import Dict, List
import numpy as np
from scipy.stats import skewnorm
def compute_exposure_interpolation(
perplexities: Dict[int, List[float]],
perplexities_reference: List[float]) -> Dict[int, List[float]]:
"""Get exposure using interpolation.
Args:
perplexities: a dictionary, key is number of secret repetitions,
value is a list of perplexities
perplexities_reference: a list, perplexities of the random sequences that
did not appear in the training data
Returns:
The exposure of every secret measured using interpolation (not necessarily
in the same order as the input)
"""
repetitions = list(perplexities.keys())
# Concatenate all perplexities, including those for references
perplexities_concat = np.concatenate([perplexities[r] for r in repetitions]
+ [perplexities_reference])
# Concatenate the number of repetitions for each secret
repetitions_concat = np.concatenate(
[[r] * len(perplexities[r]) for r in repetitions]
+ [[0] * len(perplexities_reference)])
# Sort the repetition list according to the corresponding perplexity
idx = np.argsort(perplexities_concat)
repetitions_concat = repetitions_concat[idx]
# In the sorted repetition list, if there are m examples with repetition 0
# (does not appear in training) in front of an example, then its rank is
# (m + 1). To get the number of examples with repetition 0 in front of
# any example, we use the cummulative sum of the indicator vecotr
# (repetitions_concat == 0).
cum_sum = np.cumsum(repetitions_concat == 0)
ranks = {r: cum_sum[repetitions_concat == r] + 1 for r in repetitions}
exposures = {r: np.log2(len(perplexities_reference)) - np.log2(ranks[r])
for r in repetitions}
return exposures
def compute_exposure_extrapolation(
perplexities: Dict[int, List[float]],
perplexities_reference: List[float]) -> Dict[int, List[float]]:
"""Get exposure using extrapolation.
Args:
perplexities: a dictionary, key is number of secret repetitions,
value is a list of perplexities
perplexities_reference: a list, perplexities of the random sequences that
did not appear in the training data
Returns:
The exposure of every secret measured using extrapolation
"""
# Fit a skew normal distribution using the perplexities of the references
snormal_param = skewnorm.fit(perplexities_reference)
# Estimate exposure using the fitted distribution
exposures = {r: -np.log2(skewnorm.cdf(perplexities[r], *snormal_param))
for r in perplexities.keys()}
return exposures

View file

@ -0,0 +1,71 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_privacy.privacy.secret_sharer.exposures."""
from absl.testing import absltest
import numpy as np
from scipy.stats import skewnorm
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_extrapolation
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation
class UtilsTest(absltest.TestCase):
def __init__(self, methodname):
"""Initialize the test class."""
super().__init__(methodname)
def test_exposure_interpolation(self):
"""Test exposure by interpolation."""
perplexities = {1: [0, 0.1], # smallest perplexities
2: [20.0], # largest perplexities
5: [3.5]} # rank = 4
perplexities_reference = [float(x) for x in range(1, 17)]
exposures = compute_exposure_interpolation(perplexities,
perplexities_reference)
num_perplexities_reference = len(perplexities_reference)
exposure_largest = np.log2(num_perplexities_reference)
exposure_smallest = np.log2(num_perplexities_reference) - np.log2(
num_perplexities_reference + 1)
expected_exposures = {
1: np.array([exposure_largest] * 2),
2: np.array([exposure_smallest]),
5: np.array([np.log2(num_perplexities_reference) - np.log2(4)])}
self.assertEqual(exposures.keys(), expected_exposures.keys())
for r in exposures.keys():
np.testing.assert_almost_equal(exposures[r], exposures[r])
def test_exposure_extrapolation(self):
parameters = (4, 0, 1)
perplexities = {1: skewnorm.rvs(*parameters, size=(2,)),
10: skewnorm.rvs(*parameters, size=(5,))}
perplexities_reference = skewnorm.rvs(*parameters, size=(10000,))
exposures = compute_exposure_extrapolation(perplexities,
perplexities_reference)
fitted_parameters = skewnorm.fit(perplexities_reference)
self.assertEqual(exposures.keys(), perplexities.keys())
for r in exposures.keys():
np.testing.assert_almost_equal(
exposures[r],
-np.log2(skewnorm.cdf(perplexities[r], *fitted_parameters)))
if __name__ == '__main__':
absltest.main()

View file

@ -0,0 +1,145 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate random sequences."""
import itertools
import string
from typing import Dict, List
from dataclasses import dataclass
import numpy as np
def generate_random_sequences(vocab: List[str], pattern: str, n: int,
seed: int = 1) -> List[str]:
"""Generate random sequences.
Args:
vocab: a list, the vocabulary for the sequences
pattern: the pattern of the sequence. The length of the sequence will be
inferred from the pattern.
n: number of sequences to generate
seed: random seed for numpy.random
Returns:
A list of different random sequences from the given vocabulary
"""
def count_placeholder(pattern):
return sum([x[1] is not None for x in string.Formatter().parse(pattern)])
length = count_placeholder(pattern)
np.random.seed(seed)
vocab_size = len(vocab)
if vocab_size**length <= n:
# Generate all possible sequences of the length
seq = np.array(list(itertools.product(vocab, repeat=length)))
if vocab_size**length < n:
print(f'The total number of random sequences is less than n={n}.',
f'Will return {vocab_size**length} sequences only.')
n = vocab_size**length
else:
# Generate idx where each row contains the indices for one random sequence
idx = np.empty((0, length), dtype=int)
while idx.shape[0] < n:
# Generate a new set of indices
idx_new = np.random.randint(vocab_size, size=(n, length))
idx = np.concatenate([idx, idx_new], axis=0) # Add to existing indices
idx = np.unique(idx, axis=0) # Remove duplicated indices
idx = idx[:n]
seq = np.array(vocab)[idx]
# Join each row to get the sequence
seq = np.apply_along_axis(lambda x: pattern.format(*list(x)), 1, seq)
seq = seq[np.random.permutation(n)]
return list(seq)
@dataclass
class SecretConfig:
"""Configuration of secret for secrets sharer.
vocab: a list, the vocabulary for the secrets
pattern: the pattern of the secrets
num_repetitions: a list, number of repetitions for the secrets
num_secrets_for_repetitions: a list, number of secrets to be used for
different number of repetitions
num_references: number of references sequences, i.e. random sequences that
will not be inserted into training data
"""
vocab: List[str]
pattern: str
num_repetitions: List[int]
num_secrets_for_repetitions: List[int]
num_references: int
@dataclass
class Secrets:
"""Secrets for secrets sharer.
config: configuration of the secrets
secrets: a dictionary, key is the number of repetitions, value is a list of
different secrets
references: a list of references
"""
config: SecretConfig
secrets: Dict[int, List[str]]
references: List[str]
def construct_secret(secret_config: SecretConfig, seqs: List[str]) -> Secrets:
"""Construct a secret instance.
Args:
secret_config: configuration of secret.
seqs: a list of random sequences that will be used for secrets and
references.
Returns:
a secret instance.
"""
if len(seqs) < sum(
secret_config.num_secrets_for_repetitions) + secret_config.num_references:
raise ValueError('seqs does not contain enough elements.')
secrets = {}
i = 0
for num_repetition, num_secrets in zip(
secret_config.num_repetitions, secret_config.num_secrets_for_repetitions):
secrets[num_repetition] = seqs[i:i + num_secrets]
i += num_secrets
return Secrets(config=secret_config,
secrets=secrets,
references=seqs[-secret_config.num_references:])
def generate_secrets_and_references(secret_configs: List[SecretConfig],
seed: int = 0) -> List[Secrets]:
"""Generate a list of secret instances given a list of configurations.
Args:
secret_configs: a list of secret configurations.
seed: random seed.
Returns:
A list of secret instances.
"""
secrets = []
for i, secret_config in enumerate(secret_configs):
n = secret_config.num_references + sum(
secret_config.num_secrets_for_repetitions)
seqs = generate_random_sequences(secret_config.vocab, secret_config.pattern,
n, seed + i)
if len(seqs) < n:
raise ValueError(
f'generate_random_sequences was not able to generate {n} sequences. Need to increase vocabulary.'
)
secrets.append(construct_secret(secret_config, seqs))
return secrets

View file

@ -0,0 +1,90 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets."""
from absl.testing import absltest
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import construct_secret
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_random_sequences
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import generate_secrets_and_references
from tensorflow_privacy.privacy.privacy_tests.secret_sharer.generate_secrets import SecretConfig
class UtilsTest(absltest.TestCase):
def __init__(self, methodname):
"""Initialize the test class."""
super().__init__(methodname)
def test_generate_random_sequences(self):
"""Test generate_random_sequences."""
# Test when n is larger than total number of possible sequences.
seqs = generate_random_sequences(['A', 'b', 'c'], '{}+{}', 10, seed=27)
expected_seqs = ['A+c', 'c+c', 'b+b', 'A+b', 'b+c',
'c+A', 'c+b', 'A+A', 'b+A']
self.assertEqual(seqs, expected_seqs)
# Test when n is smaller than total number of possible sequences.
seqs = generate_random_sequences(list('01234'), 'prefix {}{}{}?', 8, seed=9)
expected_seqs = ['prefix 143?', 'prefix 031?', 'prefix 302?', 'prefix 042?',
'prefix 404?', 'prefix 024?', 'prefix 021?', 'prefix 403?']
self.assertEqual(seqs, expected_seqs)
def test_construct_secret(self):
secret_config = SecretConfig(vocab=None, pattern='',
num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1],
num_references=3)
seqs = list('0123456789')
secrets = construct_secret(secret_config, seqs)
self.assertEqual(secrets.config, secret_config)
self.assertDictEqual(secrets.secrets, {1: ['0', '1'],
2: ['2', '3', '4'],
8: ['5']})
self.assertEqual(secrets.references, ['7', '8', '9'])
# Test when the number of elements in seqs is not enough.
seqs = list('01234567')
self.assertRaises(ValueError, construct_secret, secret_config, seqs)
def test_generate_secrets_and_references(self):
secret_configs = [
SecretConfig(vocab=['w1', 'w2', 'w3'], pattern='{} {} suf',
num_repetitions=[1, 12],
num_secrets_for_repetitions=[2, 1],
num_references=3),
SecretConfig(vocab=['W 1', 'W 2', 'W 3'], pattern='{}-{}',
num_repetitions=[1, 2, 8],
num_secrets_for_repetitions=[2, 3, 1],
num_references=3)
]
secrets = generate_secrets_and_references(secret_configs, seed=27)
self.assertEqual(secrets[0].config, secret_configs[0])
self.assertDictEqual(secrets[0].secrets, {1: ['w3 w2 suf', 'w2 w1 suf'],
12: ['w1 w1 suf']})
self.assertEqual(secrets[0].references,
['w2 w3 suf', 'w2 w2 suf', 'w3 w1 suf'])
self.assertEqual(secrets[1].config, secret_configs[1])
self.assertDictEqual(secrets[1].secrets,
{1: ['W 3-W 2', 'W 1-W 3'],
2: ['W 3-W 1', 'W 2-W 1', 'W 1-W 1'],
8: ['W 2-W 2']})
self.assertEqual(secrets[1].references,
['W 2-W 3', 'W 3-W 3', 'W 1-W 2'])
if __name__ == '__main__':
absltest.main()