Merge branch 'tensorflow:master' into neuracrypt

Nicholas Carlini 2021-12-14 13:14:29 -08:00 committed by GitHub
commit f8d516c1c7
64 changed files with 5304 additions and 1278 deletions

View file

@ -37,9 +37,6 @@ flags.DEFINE_string('site_path', 'responsible_ai/privacy/api_docs/python/',
'The location of the doc set in the site.')
flags.DEFINE_bool('search_hints', True,
'Include metadata search hints in the generated files.')
flags.DEFINE_bool('gen_report', False,
('Generate an API report containing the health of the'
'docstrings of the public API.'))
FLAGS = flags.FLAGS
@ -85,8 +82,6 @@ def gen_api_docs():
code_url_prefix=FLAGS.code_url_prefix,
site_path=FLAGS.site_path,
search_hints=FLAGS.search_hints,
private_map={},
gen_report=FLAGS.gen_report,
# This callback cleans up a lot of aliases caused by internal imports.
callbacks=[public_api.explicit_package_contents_filter])

View file

@ -1,5 +1,5 @@
# TODO(b/181782485): Switch to the main book for launch - /responsible_ai/_book.yaml
book_path: /responsible_ai/privacy/_book.yaml
book_path: /responsible_ai/_book.yaml
project_path: /responsible_ai/_project.yaml
title: TensorFlow Privacy
description: >
@ -13,16 +13,21 @@ landing_page:
- classname: devsite-landing-row-50
description: >
<p>
Preventing ML models from exposing potentially sensitive information is a critical part of
using AI responsibly. To that end, <i>differentially private stochastic gradient descent
(DP-SGD)</i> is a modification to the standard stochastic gradient descent (SGD) algorithm
in machine learning. </p>
<p>Models trained with DP-SGD have provable differential privacy (DP)
guarantees, mitigating the risk of exposing sensitive training data. Intuitively, a model
trained with differential privacy should not be affected by any single training example in
its data set. DP-SGD techniques can also be used in federated learning to provide user-level
differential privacy. You can learn more about differentially private deep learning in <a
href="https://arxiv.org/pdf/1607.00133.pdf">the original paper</a>.
An important aspect of responsible AI usage is ensuring that ML models are prevented from
exposing potentially sensitive information, such as demographic information or other
attributes in the training dataset that could be used to identify people.
One way to achieve this is by using differentially private stochastic gradient descent
(DP-SGD), which is a modification to the standard stochastic gradient descent (SGD)
algorithm in machine learning.
</p>
<p>
Models trained with DP-SGD have measurable differential privacy (DP) improvements, which
helps mitigate the risk of exposing sensitive training data. Since the purpose of DP is
to help prevent individual data points from being identified, a model trained with DP
should not be affected by any single training example in its training data set. DP-SGD
techniques can also be used in federated learning to provide user-level differential privacy.
You can learn more about differentially private deep learning in
<a href="https://arxiv.org/pdf/1607.00133.pdf">the original paper</a>.
</p>
- code_block: |
@ -58,14 +63,19 @@ landing_page:
items:
- classname: devsite-landing-row-100
description: >
<p>Tensorflow Privacy (TF Privacy) is an open source library developed by teams in Google
Research. The library includes implementations of commonly used TensorFlow Optimizers for
training ML models with DP. The goal is to enable ML practitioners using standard Tensorflow
APIs to train privacy-preserving models by changing only a few lines of code.</p>
<p> The differentially private Optimizers can be used in conjunction with high-level APIs
<p>
Tensorflow Privacy (TF Privacy) is an open source library developed by teams in
Google Research. The library includes implementations of commonly used TensorFlow
Optimizers for training ML models with DP. The goal is to enable ML practitioners
using standard Tensorflow APIs to train privacy-preserving models by changing only a
few lines of code.
</p>
<p>
The differentially private optimizers can be used in conjunction with high-level APIs
that use the Optimizer class, especially Keras. Additionally, you can find differentially
private implementations of some Keras models. All of the Optimizers and models can be found
in the <a href="./privacy/api">API Documentation</a>.</p>
in the <a href="../api_docs/python/tf_privacy">API Documentation</a>.</p>
</p>
- classname: devsite-landing-row-cards
items:

View file

@ -1,8 +1,6 @@
toc:
- title: Overview
path: /responsible_ai/privacy/guide/
- title: Install
path: /responsible_ai/privacy/guide/install
- title: Get Started
path: /responsible_ai/privacy/guide/get_started
- title: Measure Privacy

View file

@ -1,3 +1,90 @@
# Get Started
## Tips
This document assumes you are already familiar with differential privacy, and
have determined that you would like to use TF Privacy to implement differential
privacy guarantees in your model(s). If you're not familiar with differential
privacy, please review
[the overview page](https://tensorflow.org/responsible_ai/privacy/guide). After
installing TF Privacy, get started by following these steps:
## 1. Choose a differentially private version of an existing Optimizer
If you're currently using a TensorFlow
[optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers), you
will most likely want to select an Optimizer with the name `DPKeras*Optimizer`,
such as `DPKerasAdamOptimizer` in TF Privacy.
Optionally, you may try vectorized optimizers like
`tf_privacy.VectorizedDPKerasAdamOptimizer` for a possible speed improvement
(in terms of global steps per second). The use of vectorized optimizers has been
found to provide inconsistent speedups in experiments, but is not yet well
understood. As before, you will most likely want to use an optimizer analogous
to the one you're using now. These vectorized optimizers use Tensorflow's
`vectorized_map` operator, which may not work with some other Tensorflow
operators. If this is the case for you, please
[open an issue on the TF Privacy GitHub repository](https://github.com/tensorflow/privacy/issues).
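For concreteness, here is a minimal sketch of instantiating such an optimizer;
the numeric values are placeholders, not recommendations.
```
# Sketch only: swap a standard Keras optimizer for its DP counterpart.
import tensorflow_privacy as tf_privacy

optimizer = tf_privacy.DPKerasAdamOptimizer(
    l2_norm_clip=1.0,       # C: clipping norm for per-example gradients
    noise_multiplier=1.1,   # sigma: noise stddev relative to the clipping norm
    num_microbatches=1,     # B: microbatches per minibatch
    learning_rate=0.001)

# Optional vectorized variant (may or may not be faster for your model):
# optimizer = tf_privacy.VectorizedDPKerasAdamOptimizer(
#     l2_norm_clip=1.0, noise_multiplier=1.1, num_microbatches=1,
#     learning_rate=0.001)
```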
## 2. Compute loss for your input minibatch
When computing the loss for your input minibatch, make sure it is a vector with
one entry per example, instead of aggregating it into a scalar. This is
necessary since DP-SGD must be able to compute the loss for individual
microbatches.
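As a minimal sketch, a per-example (vector) loss can be obtained by disabling
the reduction on a standard Keras loss:
```
import tensorflow as tf

# Sketch only: `reduction=NONE` keeps one loss value per example, which the
# DP optimizer needs in order to form microbatches.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)
```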
## 3. Train your model
Train your model using the DP Optimizer (step 1) and vectorized loss (step 2).
There are two options for doing this:
* Pass the optimizer and loss as arguments to `Model.compile` before calling
`Model.fit`.
* When writing a custom training loop, use `Optimizer.minimize()` on the
vectorized loss.
Once this is done, it's recommended that you tune your hyperparameters. For a
complete walkthrough, see the
[classification privacy tutorial](../tutorials/classification_privacy.ipynb).
A minimal sketch of the first option is shown below.
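The sketch assumes the `optimizer` and `loss` objects from the previous steps;
`model`, `train_data`, and `train_labels` are placeholder names.
```
# Sketch only: standard Keras training with the DP optimizer and vector loss.
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
model.fit(train_data, train_labels,
          epochs=1,
          # The batch size must be divisible by num_microbatches.
          batch_size=250)
```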
## 4. Tune the DP-SGD hyperparameters
All `tf_privacy` optimizers take three additional hyperparameters:
* `l2_norm_clip` or $C$ - Clipping norm (the maximum Euclidean (L2) norm of
each individual gradient computed per minibatch).
* `noise_multiplier` or $σ$ - Ratio of the standard deviation to the clipping
norm.
* `num_microbatches` or $B$ - Number of microbatches into which each minibatch
is split.
Generally, the lower the effective standard deviation $σC / B$, the better the
performance of the trained model on its evaluation metrics.
The three new DP-SGD hyperparameters have the following effects and tradeoffs:
1. The number of microbatches $B$: Generally, increasing this will improve
utility because it lowers the standard deviation of the noise. However, it
will slow down training in terms of time.
2. The clipping norm $C$: Since the standard deviation of the noise scales with
$C$, it is probably best to set $C$ to be some quantile (e.g. median, 75th
percentile, 90th percentile) of the gradient norms. Having too large a value
of $C$ adds unnecessarily large amounts of noise.
3. The noise multiplier $σ$: Of the three hyperparameters, the amount of
privacy depends only on the noise multiplier. The larger the noise
multiplier, the more privacy is obtained; however, this also comes with a
loss of utility.
These tradeoffs between utility, privacy, and speed in terms of steps/second are
summarized here:
![tradeoffs](./images/getting-started-img.png)
Follow these suggestions to find the optimal hyperparameters:
* Set $C$ to a quantile as recommended above. A value of 1.00 often works
well.
* Set $B$ = 1, for maximum training speed.
* Experiment to find the largest value of $σ$ that still gives acceptable
  utility. Generally, values of 0.01 or lower have been observed to work well.
* Once a suitable value of $σ$ is found, scale both $B$ and $σ$ by a constant
  to achieve a reasonable level of privacy, as illustrated in the short numeric
  sketch after this list.
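As a small numeric illustration of the last suggestion (hypothetical values):
scaling both $B$ and $σ$ by the same factor leaves the effective noise standard
deviation $σC / B$ unchanged, while the larger noise multiplier yields more
privacy.
```
# Hypothetical values, for illustration only.
C = 1.0      # l2_norm_clip
sigma = 0.5  # largest noise multiplier with acceptable utility (example)
B = 1        # num_microbatches

k = 8        # common scale factor applied to both sigma and B
for s, b in [(sigma, B), (k * sigma, k * B)]:
    print(f'sigma={s}, B={b}, effective std={s * C / b}')
# Both configurations have effective std 0.5, but the second one uses a
# larger noise multiplier and therefore gives a stronger privacy guarantee.
```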

(New binary image file added, 109 KiB; not shown.)

View file

@ -1,3 +0,0 @@
# Installation Instructions
## Tips

View file

@ -1,5 +1,47 @@
# Measure Privacy
[TOC]
Differential privacy is a framework for measuring the privacy guarantees
provided by an algorithm and can be expressed using the values ε (epsilon) and δ
(delta). Of the two, ε is more important and more sensitive to the choice of
hyperparameters. Roughly speaking, they mean the following:
## Tips
* ε gives a ceiling on how much the probability of a particular output can
increase by including (or removing) a single training example. You usually
want it to be a small constant (less than 10, or for more stringent privacy
guarantees, less than 1). However, this is only an upper bound, and a large
value of epsilon may still mean good practical privacy.
* δ bounds the probability of an arbitrary change in model behavior. You can
usually set this to a very small number (1e-7 or so) without compromising
utility. A rule of thumb is to set it to be less than the inverse of the
training data size.
The relationship between training hyperparameters and the resulting privacy in
terms of (ε, δ) is complicated and tricky to state explicitly. Our current
recommended approach is at the bottom of the [Get Started page](get_started.md),
which involves finding the maximum noise multiplier one can use while still
having reasonable utility, and then scaling the noise multiplier and number of
microbatches. TensorFlow Privacy provides a tool, `compute_dp_sgd_privacy`, to
compute (ε, δ) based on the noise multiplier σ, the number of training steps
taken, and the fraction of input data consumed at each step. The amount of
privacy increases with the noise multiplier σ and decreases the more times the
data is used in training. Generally, in order to achieve an epsilon of at most
10.0, we need to set the noise multiplier to around 0.3 to 0.5, depending on
the dataset size and number of epochs. See the
[classification privacy tutorial](../tutorials/classification_privacy.ipynb)
for a demonstration of the approach.
For more detail, see
[the original DP-SGD paper](https://arxiv.org/pdf/1607.00133.pdf).
You can use `compute_dp_sgd_privacy` to find out the epsilon given a fixed delta
value for your model (see the
[classification privacy tutorial](../tutorials/classification_privacy.ipynb)):
* `q` : the sampling ratio - the probability of an individual training point
being included in a mini batch (`batch_size/number_of_examples`).
* `noise_multiplier` : A float that governs the amount of noise added during
training. Generally, more noise results in better privacy and lower utility.
* `steps` : The number of global steps taken.
A detailed writeup of the theory behind the computation of epsilon and delta is
available at
[Differential Privacy of the Sampled Gaussian Mechanism](https://arxiv.org/abs/1908.10530).
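As a concrete illustration of the parameters above, here is a hedged sketch
with placeholder values, assuming the `compute_dp_sgd_privacy` helper exported
from `tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib`:
```
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy

n = 60000              # number of training examples
batch_size = 250       # q = batch_size / n is the sampling ratio
noise_multiplier = 1.1
epochs = 15            # steps = epochs * n / batch_size
delta = 1e-5           # rule of thumb: less than 1 / n

# Returns the epsilon for the given delta and the optimal RDP order.
eps, opt_order = compute_dp_sgd_privacy(
    n=n,
    batch_size=batch_size,
    noise_multiplier=noise_multiplier,
    epochs=epochs,
    delta=delta)
print(f'Training satisfies ({eps:.2f}, {delta})-differential privacy.')
```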

View file

@ -1,6 +1,4 @@
toc:
- title: Overview
path: /responsible_ai/privacy/tutorials/
- title: Compute privacy
path: /responsible_ai/privacy/tutorials/classification_privacy
- title: Assess privacy risk

View file

@ -77,7 +77,7 @@
"id": "vsCUvXP0W4j2"
},
"source": [
"[Differential privacy](https://en.wikipedia.org/wiki/Differential_privacy) (DP) is a framework for measuring the privacy guarantees provided by an algorithm. Through the lens of differential privacy, you can design machine learning algorithms that responsibly train models on private data. Learning with differential privacy provides provable guarantees of privacy, mitigating the risk of exposing sensitive training data in machine learning. Intuitively, a model trained with differential privacy should not be affected by any single training example, or small set of training examples, in its data set. This mitigates the risk of exposing sensitive training data in ML."
"[Differential privacy](https://en.wikipedia.org/wiki/Differential_privacy) (DP) is a framework for measuring the privacy guarantees provided by an algorithm. Through the lens of differential privacy, you can design machine learning algorithms that responsibly train models on private data. Learning with differential privacy provides measurable guarantees of privacy, helping to mitigate the risk of exposing sensitive training data in machine learning. Intuitively, a model trained with differential privacy should not be affected by any single training example, or small set of training examples, in its data set. This helps mitigate the risk of exposing sensitive training data in ML."
]
},
{
@ -452,6 +452,7 @@
"colab": {
"collapsed_sections": [],
"name": "classification_privacy.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {

View file

@ -1,3 +0,0 @@
# PROJECT_NAME tutorials
Lorem ipsum dolor sit amet, consectetur adipiscing elit.

View file

@ -95,7 +95,6 @@
"from sklearn import metrics\n",
"\n",
"import tensorflow as tf\n",
"tf.compat.v1.disable_v2_behavior()\n",
"\n",
"import tensorflow_datasets as tfds\n",
"\n",
@ -137,14 +136,25 @@
},
"outputs": [],
"source": [
"from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyMetric\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata\n",
"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec\n",
"from tensorflow_privacy.privacy.membership_inference_attack import privacy_report"
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VpOdtnbPbPXE"
},
"outputs": [],
"source": [
"import tensorflow_privacy"
]
},
{
@ -171,13 +181,13 @@
"dataset = 'cifar10'\n",
"num_classes = 10\n",
"activation = 'relu'\n",
"lr = 0.02\n",
"momentum = 0.9\n",
"batch_size = 250\n",
"epochs_per_report = 5\n",
"num_reports = 10\n",
"# Privacy risks are especially visible with lots of epochs.\n",
"total_epochs = epochs_per_report*num_reports "
"num_conv = 3\n",
"\n",
"batch_size=50\n",
"epochs_per_report = 2\n",
"total_epochs = 50\n",
"\n",
"lr = 0.001"
]
},
{
@ -197,7 +207,7 @@
},
"outputs": [],
"source": [
"#@title Load the data\n",
"#@title\n",
"print('Loading the dataset.')\n",
"train_ds = tfds.as_numpy(\n",
" tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))\n",
@ -212,7 +222,9 @@
"y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)\n",
"y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)\n",
"\n",
"input_shape = x_train.shape[1:]"
"input_shape = x_train.shape[1:]\n",
"\n",
"assert x_train.shape[0] % batch_size == 0, \"The tensorflow_privacy optimizer doesn't handle partial batches\""
]
},
{
@ -232,7 +244,7 @@
},
"outputs": [],
"source": [
"#@title Define the models\n",
"#@title\n",
"def small_cnn(input_shape: Tuple[int],\n",
" num_classes: int,\n",
" num_conv: int,\n",
@ -259,7 +271,13 @@
" model.add(tf.keras.layers.Flatten())\n",
" model.add(tf.keras.layers.Dense(64, activation=activation))\n",
" model.add(tf.keras.layers.Dense(num_classes))\n",
" return model\n"
" \n",
" model.compile(\n",
" loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=lr),\n",
" metrics=['accuracy'])\n",
"\n",
" return model"
]
},
{
@ -268,7 +286,9 @@
"id": "hs0Smn24Dty-"
},
"source": [
"Build two-layer and a three-layer CNN models using that function. Again there's nothing provacy specific about this code. It uses standard models, layers, losses, and optimizers."
"Build two three-layer CNN models using that function.\n",
"\n",
"Configure the first to use a basic SGD optimizer, an the second to use a differentially private optimizer (`tf_privacy.DPKerasAdamOptimizer`), so you can compare the results."
]
},
{
@ -279,16 +299,10 @@
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=momentum)\n",
"loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
"\n",
"three_layer_model = small_cnn(\n",
" input_shape, num_classes, num_conv=3, activation=activation)\n",
"three_layer_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])\n",
"\n",
"two_layer_model = small_cnn(\n",
"model_2layers = small_cnn(\n",
" input_shape, num_classes, num_conv=2, activation=activation)\n",
"two_layer_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])"
"model_3layers = small_cnn(\n",
" input_shape, num_classes, num_conv=3, activation=activation)"
]
},
{
@ -318,42 +332,42 @@
" def __init__(self, epochs_per_report, model_name):\n",
" self.epochs_per_report = epochs_per_report\n",
" self.model_name = model_name\n",
" self.epochs = []\n",
" self.attack_results = [] \n",
" self.attack_results = []\n",
"\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" epoch = epoch+1\n",
"\n",
" def on_epoch_end(self, n, logs=None):\n",
" epoch = n + 1\n",
" if epoch % self.epochs_per_report != 0:\n",
" return\n",
" \n",
" print(f\"\\nRunning privacy report for epoch: {epoch}\")\n",
" self.epochs.append(epoch)\n",
"\n",
" logits_train = model.predict(x_train, batch_size=batch_size)\n",
" logits_test = model.predict(x_test, batch_size=batch_size)\n",
" print(f'\\nRunning privacy report for epoch: {epoch}\\n')\n",
"\n",
" logits_train = self.model.predict(x_train, batch_size=batch_size)\n",
" logits_test = self.model.predict(x_test, batch_size=batch_size)\n",
"\n",
" prob_train = special.softmax(logits_train, axis=1)\n",
" prob_test = special.softmax(logits_test, axis=1)\n",
"\n",
" # Add metadata to generate a privacy report.\n",
" privacy_report_metadata = PrivacyReportMetadata(\n",
" accuracy_train=metrics.accuracy_score(y_train_indices,\n",
" np.argmax(prob_train, axis=1)),\n",
" accuracy_test=metrics.accuracy_score(y_test_indices,\n",
" np.argmax(prob_test, axis=1)),\n",
" # Show the validation accuracy on the plot\n",
" # It's what you send to train_accuracy that gets plotted.\n",
" accuracy_train=logs['val_accuracy'], \n",
" accuracy_test=logs['val_accuracy'],\n",
" epoch_num=epoch,\n",
" model_variant_label=self.model_name)\n",
"\n",
" attack_results = mia.run_attacks(\n",
" AttackInputData(\n",
" labels_train=np.asarray([x[0] for x in y_train_indices]),\n",
" labels_test=np.asarray([x[0] for x in y_test_indices]),\n",
" labels_train=y_train_indices[:, 0],\n",
" labels_test=y_test_indices[:, 0],\n",
" probs_train=prob_train,\n",
" probs_test=prob_test),\n",
" SlicingSpec(entire_dataset=True, by_class=True),\n",
" attack_types=(AttackType.THRESHOLD_ATTACK,\n",
" AttackType.LOGISTIC_REGRESSION),\n",
" privacy_report_metadata=privacy_report_metadata)\n",
"\n",
" self.attack_results.append(attack_results)\n"
]
},
@ -365,7 +379,18 @@
"source": [
"### Train the models\n",
"\n",
"The next code block trains the two models. The `all_reports` list is used to collect all the results from all the models' training runs. The individual reports are tagged witht the `model_name`, so there's no confusion about which model generated which report. "
"The next code block trains the two models. The `all_reports` list is used to collect all the results from all the models' training runs. The individual reports are tagged witht the `model_name`, so there's no confusion about which model generated which report."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o3U76c2Y4irD"
},
"outputs": [],
"source": [
"all_reports = []"
]
},
{
@ -376,19 +401,8 @@
},
"outputs": [],
"source": [
"all_reports = []\n",
"\n",
"models = {\n",
" 'two layer model': two_layer_model,\n",
" 'three layer model': three_layer_model,\n",
"}\n",
"\n",
"for model_name, model in models.items():\n",
" print(f\"\\n\\n\\nFitting {model_name}\\n\")\n",
" callback = PrivacyMetrics(epochs_per_report, \n",
" model_name)\n",
"\n",
" model.fit(\n",
"callback = PrivacyMetrics(epochs_per_report, \"2 Layers\")\n",
"history = model_2layers.fit(\n",
" x_train,\n",
" y_train,\n",
" batch_size=batch_size,\n",
@ -396,8 +410,29 @@
" validation_data=(x_test, y_test),\n",
" callbacks=[callback],\n",
" shuffle=True)\n",
" \n",
" all_reports.extend(callback.attack_results)\n"
"\n",
"all_reports.extend(callback.attack_results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "27qLElOR4y_i"
},
"outputs": [],
"source": [
"callback = PrivacyMetrics(epochs_per_report, \"3 Layers\")\n",
"history = model_3layers.fit(\n",
" x_train,\n",
" y_train,\n",
" batch_size=batch_size,\n",
" epochs=total_epochs,\n",
" validation_data=(x_test, y_test),\n",
" callbacks=[callback],\n",
" shuffle=True)\n",
"\n",
"all_reports.extend(callback.attack_results)"
]
},
{
@ -470,7 +505,10 @@
"source": [
"privacy_metrics = (PrivacyMetric.AUC, PrivacyMetric.ATTACKER_ADVANTAGE)\n",
"utility_privacy_plot = privacy_report.plot_privacy_vs_accuracy(\n",
" results, privacy_metrics=privacy_metrics)"
" results, privacy_metrics=privacy_metrics)\n",
"\n",
"for axis in utility_privacy_plot.axes:\n",
" axis.set_xlabel('Validation accuracy')"
]
},
{
@ -490,8 +528,7 @@
"id": "7u3BAg87v3qv"
},
"source": [
"This is the end of the colab!\n",
"Feel free to analyze your own results."
"This is the end of the tutorial. Feel free to analyze your own results."
]
}
],
@ -500,6 +537,7 @@
"colab": {
"collapsed_sections": [],
"name": "privacy_report.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {

View file

@ -4,3 +4,5 @@ tensorflow-estimator>=2.3.0
attrs>=21.2.0
mpmath
dm-tree~=0.1.1
tensorflow-probability>=0.13.0
tensorflow-datasets>=4.4.0

View file

@ -17,7 +17,7 @@ from setuptools import setup
setup(
name='tensorflow_privacy',
version='0.6.1',
version='0.7.3',
url='https://github.com/tensorflow/privacy',
license='Apache-2.0',
install_requires=[
@ -26,6 +26,8 @@ setup(
'attrs>=21.2.0', # for tree_aggregation_query.py.
'mpmath', # used in tests only
'dm-tree~=0.1.1', # used in tests only
'tensorflow-probability>=0.13.0', # For discrete Gaussian.
'tensorflow-datasets>=4.4.0'
],
# Explicit dependence on TensorFlow is not supported.
# See https://github.com/tensorflow/tensorflow/issues/7166

View file

@ -26,20 +26,35 @@ from tensorflow_privacy.version import __version__ # pylint: disable=g-bad-impo
if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
pass
else:
# TensorFlow v1 imports
from tensorflow_privacy import v1
# DpEvents
from tensorflow_privacy.privacy.analysis.dp_event import DpEvent
from tensorflow_privacy.privacy.analysis.dp_event import NoOpDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import NonPrivateDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import UnsupportedDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import GaussianDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import SelfComposedDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import ComposedDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import PoissonSampledDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import SampledWithReplacementDpEvent
from tensorflow_privacy.privacy.analysis.dp_event import SampledWithoutReplacementDpEvent
# Analysis
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
from tensorflow_privacy.privacy.analysis.privacy_ledger import GaussianSumQueryEntry
from tensorflow_privacy.privacy.analysis.privacy_ledger import PrivacyLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import QueryWithLedger
from tensorflow_privacy.privacy.analysis.privacy_ledger import SampleEntry
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogenous_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_heterogeneous_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_tree_restart
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_rdp_single_tree
from tensorflow_privacy.privacy.analysis.tree_aggregation_accountant import compute_zcdp_single_tree
# DPQuery classes
from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery
from tensorflow_privacy.privacy.dp_query.dp_query import SumAggregationDPQuery
from tensorflow_privacy.privacy.dp_query.discrete_gaussian_query import DiscreteGaussianSumQuery
from tensorflow_privacy.privacy.dp_query.distributed_discrete_gaussian_query import DistributedDiscreteGaussianSumQuery
from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery
from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery
from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery
@ -48,13 +63,15 @@ else:
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import QuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_estimator_query import NoPrivacyQuantileEstimatorQuery
from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query.restart_query import RestartQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeCumulativeSumQuery
from tensorflow_privacy.privacy.dp_query.tree_aggregation_query import TreeResidualSumQuery
from tensorflow_privacy.privacy.dp_query.tree_range_query import TreeRangeSumQuery
# Estimators
from tensorflow_privacy.privacy.estimators.dnn import DNNClassifier
from tensorflow_privacy.privacy.estimators.v1.dnn import DNNClassifier as DNNClassifierV1
# Keras Models
from tensorflow_privacy.privacy.keras_models.dp_keras_model import DPModel
@ -62,14 +79,6 @@ else:
from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class
# Optimizers
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import make_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
@ -80,15 +89,6 @@ else:
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import make_vectorized_keras_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import make_vectorized_optimizer_class
try:
from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel
from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn

View file

@ -32,16 +32,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl import app
from absl import flags
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
# Opting out of loading all sibling packages and their dependencies.
sys.skip_tf_privacy_import = True
FLAGS = flags.FLAGS
flags.DEFINE_integer('N', None, 'Total number of examples')

View file

@ -19,13 +19,9 @@ from __future__ import division
from __future__ import print_function
import math
import sys
from absl import app
# Opting out of loading all sibling packages and their dependencies.
sys.skip_tf_privacy_import = True
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

View file

@ -34,16 +34,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl import app
from absl import flags
from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise
# Opting out of loading all sibling packages and their dependencies.
sys.skip_tf_privacy_import = True
FLAGS = flags.FLAGS
flags.DEFINE_integer('N', None, 'Total number of examples')

View file

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import math
import sys
from absl import app
from scipy.optimize import bisect
@ -27,9 +26,6 @@ from scipy.optimize import bisect
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
# Opting out of loading all sibling packages and their dependencies.
sys.skip_tf_privacy_import = True
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
"""Compute and print results of DP-SGD analysis."""

View file

@ -0,0 +1,179 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard DpEvent classes.
A `DpEvent` represents the (hyper)parameters of a differentially
private query, amplification mechanism, or composition, that are necessary
and sufficient for privacy accounting. Various independent implementations of DP
algorithms that are functionally equivalent from an accounting perspective may
correspond to the same `DpEvent`. Similarly, various independent implementations
of accounting algorithms may consume the same `DpEvent`.
All `DpEvents` processed together are assumed to take place on a single dataset
of records. `DpEvents` fall into roughly three categories:
- `DpEvents` that release an output, and incur a privacy cost,
e.g., `GaussianDpEvent`.
- `DpEvents` that select a subset (or subsets) of the dataset, and run nested
`DpEvents` on those subsets, e.g., `PoissonSampledDpEvent`.
- `DpEvents` that represent (possibly sequentially) applying (multiple)
mechanisms to the dataset (or currently active subset). Currently, this is
only `ComposedDpEvent` and `SelfComposedDpEvent`.
Each `DpEvent` should completely document the mathematical behavior and
assumptions of the mechanism it represents so that the writer of an accountant
class can implement the accounting correctly without knowing any other
implementation details of the algorithm that produced it.
New mechanism types should be given a corresponding `DpEvent` class, although
not all accountants will be required to support them. In general,
`PrivacyAccountant` implementations are not required to be aware of all
`DpEvent` classes, but they should support the following basic events and handle
them appropriately: `NoOpDpEvent`, `NonPrivateDpEvent`, `ComposedDpEvent`, and
`SelfComposedDpEvent`. They should return `False` from `supports(event)` for
`UnsupportedDpEvent` or any other event type they have not been designed to
handle.
To ensure that a `PrivacyAccountant` does not accidentally start to return
incorrect results, the following should be enforced:
* `DpEvent` classes and their parameters should never be removed, barring some
extended, onerous deprecation process.
* New parameters cannot be added to existing mechanisms unless they are
optional. That is, old composed `DpEvent` objects that do not include them
must remain valid.
* The meaning of existing mechanisms or parameters must not change. That is,
existing mechanisms should not have their implementations change in ways that
alter their privacy properties; new `DpEvent` classes should be added
instead.
* `PrivacyAccountant` implementations are expected to return `False` from
`supports(event)` when processing unknown mechanisms.
"""
from typing import List
import attr
class DpEvent(object):
"""Represents application of a private mechanism.
A `DpEvent` describes a differentially private mechanism sufficiently for
computing the associated privacy losses, both in isolation and in combination
with other `DpEvent`s.
"""
@attr.s(frozen=True)
class NoOpDpEvent(DpEvent):
"""Represents appplication of an operation with no privacy impact.
A `NoOpDpEvent` is generally never required, but it can be useful as a
placeholder where a `DpEvent` is expected, such as in tests or some live
accounting pipelines.
"""
@attr.s(frozen=True)
class NonPrivateDpEvent(DpEvent):
"""Represents application of a non-private operation.
This `DpEvent` should be used when an operation is performed that does not
satisfy (epsilon, delta)-DP. All `PrivacyAccountant`s should return infinite
epsilon/delta when encountering a `NonPrivateDpEvent`.
"""
@attr.s(frozen=True)
class UnsupportedDpEvent(DpEvent):
"""Represents application of an as-yet unsupported operation.
This `DpEvent` should be used when an operation is performed that does not yet
have any associated DP description, or if the description is temporarily
inaccessible, for example, during development. All `PrivacyAccountant`s should
return `supports(event) == False` for `UnsupportedDpEvent`.
"""
@attr.s(frozen=True, slots=True, auto_attribs=True)
class GaussianDpEvent(DpEvent):
"""Represents an application of the Gaussian mechanism.
For values v_i and noise z ~ N(0, s^2I), this mechanism returns sum_i v_i + z.
If the norms of the values are bounded ||v_i|| <= C, the noise_multiplier is
defined as s / C.
"""
noise_multiplier: float
@attr.s(frozen=True, slots=True, auto_attribs=True)
class SelfComposedDpEvent(DpEvent):
"""Represents repeated application of a mechanism.
The repeated applications may be adaptive, where the query producing each
event depends on the results of prior queries.
This is equivalent to `ComposedDpEvent` that contains a list of length `count`
of identical copies of `event`.
"""
event: DpEvent
count: int
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ComposedDpEvent(DpEvent):
"""Represents application of a series of composed mechanisms.
The composition may be adaptive, where the query producing each event depends
on the results of prior queries.
"""
events: List[DpEvent]
@attr.s(frozen=True, slots=True, auto_attribs=True)
class PoissonSampledDpEvent(DpEvent):
"""Represents an application of Poisson subsampling.
Each record in the dataset is included in the sample independently with
probability `sampling_probability`. Then the `DpEvent` `event` is applied
to the sample of records.
"""
sampling_probability: float
event: DpEvent
@attr.s(frozen=True, slots=True, auto_attribs=True)
class SampledWithReplacementDpEvent(DpEvent):
"""Represents sampling a fixed sized batch of records with replacement.
A sample of `sample_size` (possibly repeated) records is drawn uniformly at
random from the set of possible samples of a source dataset of size
`source_dataset_size`. Then the `DpEvent` `event` is applied to the sample of
records.
"""
source_dataset_size: int
sample_size: int
event: DpEvent
@attr.s(frozen=True, slots=True, auto_attribs=True)
class SampledWithoutReplacementDpEvent(DpEvent):
"""Represents sampling a fixed sized batch of records without replacement.
A sample of `sample_size` unique records is drawn uniformly at random from the
set of possible samples of a source dataset of size `source_dataset_size`.
Then the `DpEvent` `event` is applied to the sample of records.
"""
source_dataset_size: int
sample_size: int
event: DpEvent
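For illustration only (not part of the module above): one epoch of DP-SGD with
Poisson subsampling might be described by composing the events defined here,
using keyword arguments to avoid depending on field order.
```
from tensorflow_privacy.privacy.analysis import dp_event

# One Gaussian sum query with noise multiplier 1.1.
gaussian = dp_event.GaussianDpEvent(noise_multiplier=1.1)
# Each step applies it to a Poisson subsample with probability 0.01.
one_step = dp_event.PoissonSampledDpEvent(
    sampling_probability=0.01, event=gaussian)
# An epoch of 100 such steps, composed with itself.
one_epoch = dp_event.SelfComposedDpEvent(event=one_step, count=100)
```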

View file

@ -0,0 +1,76 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builder class for ComposedDpEvent."""
from tensorflow_privacy.privacy.analysis import dp_event
class DpEventBuilder(object):
"""Constructs a `DpEvent` representing the composition of a series of events.
Two common use cases of the `DpEventBuilder` are 1) for producing and tracking
a ledger of `DpEvent`s during sequential accounting using a
`PrivacyAccountant`, and 2) for building up a description of a composite
mechanism for subsequent batch accounting.
"""
def __init__(self):
# A list of (event, count) pairs.
self._event_counts = []
self._composed_event = None
def compose(self, event: dp_event.DpEvent, count: int = 1):
"""Composes new event into event represented by builder.
Args:
event: The new event to compose.
count: The number of times to compose the event.
"""
if not isinstance(event, dp_event.DpEvent):
raise TypeError('`event` must be a subclass of `DpEvent`. '
f'Found {type(event)}.')
if not isinstance(count, int):
raise TypeError(f'`count` must be an integer. Found {type(count)}.')
if count < 1:
raise ValueError(f'`count` must be positive. Found {count}.')
if isinstance(event, dp_event.NoOpDpEvent):
return
elif isinstance(event, dp_event.SelfComposedDpEvent):
self.compose(event.event, count * event.count)
else:
if self._event_counts and self._event_counts[-1][0] == event:
new_event_count = (event, self._event_counts[-1][1] + count)
self._event_counts[-1] = new_event_count
else:
self._event_counts.append((event, count))
self._composed_event = None
def build(self) -> dp_event.DpEvent:
"""Builds and returns the composed DpEvent represented by the builder."""
if not self._composed_event:
events = []
for event, count in self._event_counts:
if count == 1:
events.append(event)
else:
events.append(dp_event.SelfComposedDpEvent(event, count))
if not events:
self._composed_event = dp_event.NoOpDpEvent()
elif len(events) == 1:
self._composed_event = events[0]
else:
self._composed_event = dp_event.ComposedDpEvent(events)
return self._composed_event

View file

@ -0,0 +1,76 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for DpEventBuilder."""
from absl.testing import absltest
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import dp_event_builder
_gaussian_event = dp_event.GaussianDpEvent(1.0)
_poisson_event = dp_event.PoissonSampledDpEvent(_gaussian_event, 0.1)
_self_composed_event = dp_event.SelfComposedDpEvent(_gaussian_event, 3)
class DpEventBuilderTest(absltest.TestCase):
def test_no_op(self):
builder = dp_event_builder.DpEventBuilder()
self.assertEqual(dp_event.NoOpDpEvent(), builder.build())
def test_single(self):
builder = dp_event_builder.DpEventBuilder()
builder.compose(_gaussian_event)
self.assertEqual(_gaussian_event, builder.build())
def test_compose_no_op(self):
builder = dp_event_builder.DpEventBuilder()
builder.compose(dp_event.NoOpDpEvent())
builder.compose(_gaussian_event)
builder.compose(dp_event.NoOpDpEvent())
self.assertEqual(_gaussian_event, builder.build())
def test_compose_self(self):
builder = dp_event_builder.DpEventBuilder()
builder.compose(_gaussian_event)
builder.compose(_gaussian_event, 2)
self.assertEqual(_self_composed_event, builder.build())
def test_compose_heterogenous(self):
builder = dp_event_builder.DpEventBuilder()
builder.compose(_poisson_event)
builder.compose(_gaussian_event)
builder.compose(_gaussian_event, 2)
builder.compose(_poisson_event)
expected_event = dp_event.ComposedDpEvent(
[_poisson_event, _self_composed_event, _poisson_event])
self.assertEqual(expected_event, builder.build())
def test_compose_composed(self):
builder = dp_event_builder.DpEventBuilder()
composed_event = dp_event.ComposedDpEvent(
[_gaussian_event, _poisson_event, _self_composed_event])
builder.compose(_gaussian_event)
builder.compose(composed_event)
builder.compose(composed_event, 2)
builder.compose(_poisson_event)
builder.compose(_poisson_event)
expected_event = dp_event.ComposedDpEvent([
_gaussian_event,
dp_event.SelfComposedDpEvent(composed_event, 3),
dp_event.SelfComposedDpEvent(_poisson_event, 2)])
self.assertEqual(expected_event, builder.build())
if __name__ == '__main__':
absltest.main()

View file

@ -0,0 +1,127 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PrivacyAccountant abstract base class."""
import abc
import enum
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import dp_event_builder
class NeighboringRelation(enum.Enum):
ADD_OR_REMOVE_ONE = 1
REPLACE_ONE = 2
class UnsupportedEventError(Exception):
"""Exception to raise if _compose is called on unsupported event type."""
class PrivacyAccountant(metaclass=abc.ABCMeta):
"""Abstract base class for privacy accountants."""
def __init__(self, neighboring_relation: NeighboringRelation):
self._neighboring_relation = neighboring_relation
self._ledger = dp_event_builder.DpEventBuilder()
@property
def neighboring_relation(self) -> NeighboringRelation:
"""The neighboring relation used by the accountant.
The neighboring relation is expected to remain constant after
initialization. Subclasses should not override this property or change the
value of the private attribute.
"""
return self._neighboring_relation
@abc.abstractmethod
def supports(self, event: dp_event.DpEvent) -> bool:
"""Checks whether the `DpEvent` can be processed by this accountant.
In general this will require recursively checking the structure of the
`DpEvent`. In particular `ComposedDpEvent` and `SelfComposedDpEvent` should
be recursively examined.
Args:
event: The `DpEvent` to check.
Returns:
True iff this accountant supports processing `event`.
"""
@abc.abstractmethod
def _compose(self, event: dp_event.DpEvent, count: int = 1):
"""Updates internal state to account for application of a `DpEvent`.
Calls to `get_epsilon` or `get_delta` after calling `_compose` will return
values that account for this `DpEvent`.
Args:
event: A `DpEvent` to process.
count: The number of times to compose the event.
"""
def compose(self, event: dp_event.DpEvent, count: int = 1):
"""Updates internal state to account for application of a `DpEvent`.
Calls to `get_epsilon` or `get_delta` after calling `compose` will return
values that account for this `DpEvent`.
Args:
event: A `DpEvent` to process.
count: The number of times to compose the event.
Raises:
UnsupportedEventError: `event` is not supported by this
`PrivacyAccountant`.
"""
if not isinstance(event, dp_event.DpEvent):
raise TypeError(f'`event` must be `DpEvent`. Found {type(event)}.')
if not self.supports(event):
raise UnsupportedEventError(f'Unsupported event: {event}.')
self._ledger.compose(event, count)
self._compose(event, count)
@property
def ledger(self) -> dp_event.DpEvent:
"""Returns the (composed) `DpEvent` processed so far by this accountant."""
return self._ledger.build()
@abc.abstractmethod
def get_epsilon(self, target_delta: float) -> float:
"""Gets the current epsilon.
Args:
target_delta: The target delta.
Returns:
The current epsilon, accounting for all composed `DpEvent`s.
"""
def get_delta(self, target_epsilon: float) -> float:
"""Gets the current delta.
An implementer of `PrivacyAccountant` may choose not to override this, in
which case `NotImplementedError` will be raised.
Args:
target_epsilon: The target epsilon.
Returns:
The current delta, accounting for all composed `DpEvent`s.
"""
raise NotImplementedError()
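For illustration only (a hypothetical toy accountant, not part of the library):
a minimal subclass that recognizes only `NoOpDpEvent` compositions and reports
zero privacy loss. A real accountant would accumulate state in `_compose` and
convert it into epsilon in `get_epsilon`.
```
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import privacy_accountant


class NoOpAccountant(privacy_accountant.PrivacyAccountant):
  """Toy accountant: supports only events with no privacy cost."""

  def __init__(self):
    super().__init__(
        privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE)

  def supports(self, event):
    if isinstance(event, dp_event.NoOpDpEvent):
      return True
    if isinstance(event, dp_event.SelfComposedDpEvent):
      return self.supports(event.event)
    if isinstance(event, dp_event.ComposedDpEvent):
      return all(self.supports(e) for e in event.events)
    return False

  def _compose(self, event, count=1):
    pass  # Nothing to track: supported events have no privacy cost.

  def get_epsilon(self, target_delta):
    return 0.0
```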

View file

@ -0,0 +1,101 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstract base class for tests of `PrivacyAccountant` classes.
Checks that a class derived from `PrivacyAccountant` has the correct behavior
for standard `DpEvent` classes.
"""
from typing import Collection
from absl.testing import absltest
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import privacy_accountant
class PrivacyAccountantTest(absltest.TestCase):
def _make_test_accountants(
self) -> Collection[privacy_accountant.PrivacyAccountant]:
"""Makes a list of accountants to test.
Subclasses should define this to return a list of accountants to be tested.
Returns:
A list of accountants to test.
"""
return []
def test_make_test_accountants(self):
self.assertNotEmpty(self._make_test_accountants())
def test_unsupported(self):
class UnknownDpEvent(dp_event.DpEvent):
pass
for accountant in self._make_test_accountants():
for unsupported in [dp_event.UnsupportedDpEvent(), UnknownDpEvent()]:
self.assertFalse(accountant.supports(unsupported))
self.assertFalse(
accountant.supports(dp_event.SelfComposedDpEvent(unsupported, 10)))
self.assertFalse(
accountant.supports(dp_event.ComposedDpEvent([unsupported])))
def test_no_events(self):
for accountant in self._make_test_accountants():
self.assertEqual(accountant.get_epsilon(1e-12), 0)
self.assertEqual(accountant.get_epsilon(0), 0)
self.assertEqual(accountant.get_epsilon(1), 0)
try:
self.assertEqual(accountant.get_delta(1e-12), 0)
self.assertEqual(accountant.get_delta(0), 0)
self.assertEqual(accountant.get_delta(float('inf')), 0)
except NotImplementedError:
# Implementing `get_delta` is optional.
pass
def test_no_op(self):
for accountant in self._make_test_accountants():
event = dp_event.NoOpDpEvent()
self.assertTrue(accountant.supports(event))
accountant._compose(event)
self.assertEqual(accountant.get_epsilon(1e-12), 0)
self.assertEqual(accountant.get_epsilon(0), 0)
self.assertEqual(accountant.get_epsilon(1), 0)
try:
self.assertEqual(accountant.get_delta(1e-12), 0)
self.assertEqual(accountant.get_delta(0), 0)
self.assertEqual(accountant.get_delta(float('inf')), 0)
except NotImplementedError:
# Implementing `get_delta` is optional.
pass
def test_non_private(self):
for accountant in self._make_test_accountants():
event = dp_event.NonPrivateDpEvent()
self.assertTrue(accountant.supports(event))
accountant._compose(event)
self.assertEqual(accountant.get_epsilon(0.99), float('inf'))
self.assertEqual(accountant.get_epsilon(0), float('inf'))
self.assertEqual(accountant.get_epsilon(1), float('inf'))
try:
self.assertEqual(accountant.get_delta(100), 1)
self.assertEqual(accountant.get_delta(0), 1)
self.assertEqual(accountant.get_delta(float('inf')), 1)
except NotImplementedError:
# Implementing `get_delta` is optional.
pass

View file

@ -1,299 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PrivacyLedger class for keeping a record of private queries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import tensor_buffer
from tensorflow_privacy.privacy.dp_query import dp_query
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name
'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev'])
def format_ledger(sample_array, query_array):
"""Converts array representation into a list of SampleEntries."""
samples = []
query_pos = 0
sample_pos = 0
for sample in sample_array:
population_size, selection_probability, num_queries = sample
queries = []
for _ in range(int(num_queries)):
query = query_array[query_pos]
assert int(query[0]) == sample_pos
queries.append(GaussianSumQueryEntry(*query[1:]))
query_pos += 1
samples.append(SampleEntry(population_size, selection_probability, queries))
sample_pos += 1
return samples
class PrivacyLedger(object):
"""Class for keeping a record of private queries.
The PrivacyLedger keeps a record of all queries executed over a given dataset
for the purpose of computing privacy guarantees. To use it, it must be
associated with a `DPQuery` object via a `QueryWithLedger`.
The current implementation works only with DPQueries that consist of composing
Gaussian sum mechanism with Poisson subsampling.
Example usage:
```
import tensorflow_privacy as tfp
dp_query = tfp.QueryWithLedger(
tensorflow_privacy.GaussianSumQuery(
l2_norm_clip=1.0, stddev=1.0),
population_size=10000,
selection_probability=0.01)
# Use dp_query here in training loop.
formatted_ledger = dp_query.ledger.get_formatted_ledger_eager()
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
list(range(5, 64)) + [128, 256, 512])
total_rdp = tfp.compute_rdp_from_ledger(formatted_ledger, orders)
epsilon = tfp.get_privacy_spent(orders, total_rdp, target_delta=1e-5)
```
"""
def __init__(self,
population_size,
selection_probability):
"""Initializes the PrivacyLedger.
Args:
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch.
selection_probability: A floating point value (may be variable) specifying
the probability each record is included in a sample.
Raises:
ValueError: If `selection_probability` is 0.
"""
self._population_size = population_size
self._selection_probability = selection_probability
if tf.executing_eagerly():
if tf.equal(selection_probability, 0):
raise ValueError('Selection probability cannot be 0.')
init_capacity = tf.cast(tf.math.ceil(1 / selection_probability), tf.int32)
else:
if selection_probability == 0:
raise ValueError('Selection probability cannot be 0.')
init_capacity = np.int(np.ceil(1 / selection_probability))
# The query buffer stores rows corresponding to GaussianSumQueryEntries.
self._query_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'query')
self._sample_var = tf.Variable(
initial_value=tf.zeros([3]), trainable=False, name='sample')
# The sample buffer stores rows corresponding to SampleEntries.
self._sample_buffer = tensor_buffer.TensorBuffer(
init_capacity, [3], tf.float32, 'sample')
self._sample_count = tf.Variable(
initial_value=0.0, trainable=False, name='sample_count')
self._query_count = tf.Variable(
initial_value=0.0, trainable=False, name='query_count')
self._cs = tf.CriticalSection()
def record_sum_query(self, l2_norm_bound, noise_stddev):
"""Records that a query was issued.
Args:
l2_norm_bound: The maximum l2 norm of the tensor group in the query.
noise_stddev: The standard deviation of the noise applied to the sum.
Returns:
An operation recording the sum query to the ledger. This should be called
for every Gaussian sum query that is issued on a sample.
"""
def _do_record_query():
with tf.control_dependencies(
[tf.assign(self._query_count, self._query_count + 1)]):
return self._query_buffer.append(
[self._sample_count, l2_norm_bound, noise_stddev])
return self._cs.execute(_do_record_query)
def finalize_sample(self):
"""Finalizes sample and records sample ledger entry.
This should be called once per application of the mechanism on a sample,
after all sum queries have been recorded.
Returns:
An operation recording the complete mechanism (sampling and sum
estimation) to the ledger.
"""
with tf.control_dependencies([
tf.assign(self._sample_var, [
self._population_size, self._selection_probability,
self._query_count
])
]):
with tf.control_dependencies([
tf.assign(self._sample_count, self._sample_count + 1),
tf.assign(self._query_count, 0)
]):
return self._sample_buffer.append(self._sample_var)
def get_unformatted_ledger(self):
"""Returns the raw sample and query values."""
return self._sample_buffer.values, self._query_buffer.values
def get_formatted_ledger(self, sess):
"""Gets the formatted query ledger.
Args:
sess: The tensorflow session in which the ledger was created.
Returns:
The query ledger as a list of `SampleEntry` instances.
"""
sample_array = sess.run(self._sample_buffer.values)
query_array = sess.run(self._query_buffer.values)
return format_ledger(sample_array, query_array)
def get_formatted_ledger_eager(self):
"""Gets the formatted query ledger.
Returns:
The query ledger as a list of `SampleEntry` instances.
"""
sample_array = self._sample_buffer.values.numpy()
query_array = self._query_buffer.values.numpy()
return format_ledger(sample_array, query_array)
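For reference, a minimal eager-mode sketch of driving the ledger directly, mirroring `PrivacyLedgerTest.test_basic` further below; the values are illustrative only and this snippet is not part of the module source.
```
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger

tf.enable_eager_execution()

ledger = privacy_ledger.PrivacyLedger(
    population_size=10, selection_probability=0.1)
ledger.record_sum_query(l2_norm_bound=5.0, noise_stddev=1.0)
ledger.record_sum_query(l2_norm_bound=2.0, noise_stddev=0.5)
ledger.finalize_sample()

# Each entry is a SampleEntry(population_size, selection_probability, queries).
for sample in ledger.get_formatted_ledger_eager():
  print(sample.population_size, sample.selection_probability, sample.queries)
```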
class QueryWithLedger(dp_query.DPQuery):
"""A class for DP queries that record events to a `PrivacyLedger`.
`QueryWithLedger` should be the top-level query in a structure of queries that
may include sum queries, nested queries, etc. It should simply wrap another
query and contain a reference to the ledger. Any contained queries (including
those contained in the leaves of a nested query) should also contain a
reference to the same ledger object.
Only composed Gaussian sum queries with Poisson subsampling are supported.
This includes `GaussianSumQuery`, `QuantileEstimatorQuery`, and
`QuantileAdaptiveClipSumQuery`, as well as `NestedQuery` or `NormalizedQuery`
objects that contain the previously mentioned query types.
"""
def __init__(self, query,
population_size=None, selection_probability=None,
ledger=None):
"""Initializes the `QueryWithLedger`.
Args:
query: The query whose events should be recorded to the ledger. Any
subqueries (including those in the leaves of a nested query) should also
contain a reference to the same ledger given here.
population_size: An integer (may be variable) specifying the size of the
population, i.e. size of the training data used in each epoch. May be
`None` if `ledger` is specified.
selection_probability: A floating point value (may be variable) specifying
the probability each record is included in a sample under Poisson
subsampling. May be `None` if `ledger` is specified.
ledger: A `PrivacyLedger` to use. Must be specified if either of
`population_size` or `selection_probability` is `None`.
"""
self._query = query
if population_size is not None and selection_probability is not None:
self.set_ledger(PrivacyLedger(population_size, selection_probability))
elif ledger is not None:
self.set_ledger(ledger)
else:
raise ValueError('One of (population_size, selection_probability) or '
'ledger must be specified.')
@property
def ledger(self):
"""Gets the ledger that all inner queries record to."""
return self._ledger
def set_ledger(self, ledger):
"""Sets a new ledger."""
self._ledger = ledger
self._query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._query.initial_global_state()
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return self._query.derive_sample_params(global_state)
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self._query.initial_sample_state(template)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._query.preprocess_record(params, record)
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
return self._query.accumulate_preprocessed_record(
sample_state, preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return self._query.merge_sample_states(sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`.
Besides noising and returning the result of the inner query, ensures that
the sample is recorded to the ledger.
Args:
sample_state: The sample state after all records have been accumulated.
global_state: The global state, storing long-term privacy bookkeeping.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state.
"""
# Ensure sample_state is fully aggregated before calling get_noised_result.
with tf.control_dependencies(tf.nest.flatten(sample_state)):
result, new_global_state = self._query.get_noised_result(
sample_state, global_state)
# Ensure inner queries have recorded before finalizing.
with tf.control_dependencies(tf.nest.flatten(result)):
finalize = self._ledger.finalize_sample()
# Ensure finalizing happens.
with tf.control_dependencies([finalize]):
return tf.nest.map_structure(tf.identity, result), new_global_state
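As a hedged sketch (not part of this file), here is roughly how one sample flows through the `DPQuery` methods wrapped above, similar to what `test_utils.run_query` does in the tests that follow; all values are illustrative.
```
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query

tf.enable_eager_execution()

query = privacy_ledger.QueryWithLedger(
    gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0),
    population_size=1000, selection_probability=0.01)

records = [tf.constant([0.1, 0.2]), tf.constant([-0.3, 0.4])]

global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
sample_state = query.initial_sample_state(records[0])
for record in records:
  preprocessed = query.preprocess_record(params, record)
  sample_state = query.accumulate_preprocessed_record(sample_state, preprocessed)
# Noises the aggregated sum and records the sample to the ledger.
result, global_state = query.get_noised_result(sample_state, global_state)
```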

View file

@@ -1,133 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PrivacyLedger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import nested_query
from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
class PrivacyLedgerTest(tf.test.TestCase):
def test_fail_on_probability_zero(self):
with self.assertRaisesRegexp(ValueError,
'Selection probability cannot be 0.'):
privacy_ledger.PrivacyLedger(10, 0)
def test_basic(self):
ledger = privacy_ledger.PrivacyLedger(10, 0.1)
ledger.record_sum_query(5.0, 1.0)
ledger.record_sum_query(2.0, 0.5)
ledger.finalize_sample()
expected_queries = [[5.0, 1.0], [2.0, 0.5]]
formatted = ledger.get_formatted_ledger_eager()
sample = formatted[0]
self.assertAllClose(sample.population_size, 10.0)
self.assertAllClose(sample.selection_probability, 0.1)
self.assertAllClose(sorted(sample.queries), sorted(expected_queries))
def test_sum_query(self):
record1 = tf.constant([2.0, 0.0])
record2 = tf.constant([-1.0, 1.0])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=0.0)
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries)
def test_nested_query(self):
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query1 = gaussian_query.GaussianSumQuery(l2_norm_clip=4.0, stddev=2.0)
query2 = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=1.0)
query = nested_query.NestedQuery([query1, query2])
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
record1 = [1.0, [12.0, 9.0]]
record2 = [5.0, [1.0, 2.0]]
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
test_utils.run_query(query, [record1, record2])
expected_queries = [[4.0, 2.0], [5.0, 1.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2])
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries))
if __name__ == '__main__':
tf.test.main()

View file

@@ -42,12 +42,10 @@ from __future__ import print_function
import math
import sys
import numpy as np
from scipy import special
import six
########################
# LOG-SPACE ARITHMETIC #
########################
@@ -102,8 +100,8 @@ def _log_print(logx):
def _log_comb(n, k):
return (special.gammaln(n + 1) -
special.gammaln(k + 1) - special.gammaln(n - k + 1))
return (special.gammaln(n + 1) - special.gammaln(k + 1) -
special.gammaln(n - k + 1))
def _compute_log_a_int(q, sigma, alpha):
@@ -215,17 +213,19 @@ def _compute_delta(orders, rdp, eps):
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
logdeltas = [] # work in log space to avoid overflows
for (a, r) in zip(orders_vec, rdp_vec):
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
if r < 0: raise ValueError("Renyi divergence must be >=0.")
if a < 1:
raise ValueError("Renyi divergence order must be >=1.")
if r < 0:
raise ValueError("Renyi divergence must be >=0.")
# For small alpha, we are better off with the bound via KL divergence:
# delta <= sqrt(1-exp(-KL)).
# Take a min of the two bounds.
logdelta = 0.5*math.log1p(-math.exp(-r))
logdelta = 0.5 * math.log1p(-math.exp(-r))
if a > 1.01:
# This bound is not numerically stable as alpha->1.
# Thus we have a min value for alpha.
# The bound is also not useful for small alpha, so doesn't matter.
rdp_bound = (a - 1) * (r - eps + math.log1p(-1/a)) - math.log(a)
rdp_bound = (a - 1) * (r - eps + math.log1p(-1 / a)) - math.log(a)
logdelta = min(logdelta, rdp_bound)
logdeltas.append(logdelta)
@@ -264,8 +264,10 @@ def _compute_eps(orders, rdp, delta):
# Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
eps_vec = []
for (a, r) in zip(orders_vec, rdp_vec):
if a < 1: raise ValueError("Renyi divergence order must be >=1.")
if r < 0: raise ValueError("Renyi divergence must be >=0.")
if a < 1:
raise ValueError("Renyi divergence order must be >=1.")
if r < 0:
raise ValueError("Renyi divergence must be >=0.")
if delta**2 + math.expm1(-r) >= 0:
# In this case, we can simply bound via KL divergence:
@@ -378,7 +380,7 @@ def compute_rdp(q, noise_multiplier, steps, orders):
Args:
q: The sampling rate.
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
to the l2-sensitivity of the function to which it is added.
to the l2-sensitivity of the function to which it is added.
steps: The number of steps.
orders: An array (or a scalar) of RDP orders.
@@ -388,8 +390,8 @@ def compute_rdp(q, noise_multiplier, steps, orders):
if np.isscalar(orders):
rdp = _compute_rdp(q, noise_multiplier, orders)
else:
rdp = np.array([_compute_rdp(q, noise_multiplier, order)
for order in orders])
rdp = np.array(
[_compute_rdp(q, noise_multiplier, order) for order in orders])
return rdp * steps
@@ -537,8 +539,8 @@ def _compute_rdp_sample_without_replacement_int(q, sigma, alpha):
return log_a
def compute_heterogenous_rdp(sampling_probabilities, noise_multipliers,
steps_list, orders):
def compute_heterogeneous_rdp(sampling_probabilities, noise_multipliers,
steps_list, orders):
"""Computes RDP of Heteregoneous Applications of Sampled Gaussian Mechanisms.
Args:
@@ -572,8 +574,8 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
target_eps: If not `None`, the epsilon for which we compute the
corresponding delta.
target_delta: If not `None`, the delta for which we compute the
corresponding epsilon. Exactly one of `target_eps` and `target_delta`
must be `None`.
corresponding epsilon. Exactly one of `target_eps` and `target_delta` must
be `None`.
Returns:
A tuple of epsilon, delta, and the optimal order.
@@ -595,24 +597,3 @@ def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
else:
eps, opt_order = _compute_eps(orders, rdp, target_delta)
return eps, target_delta, opt_order
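For context, a hedged sketch of the typical DP-SGD accounting flow built from `compute_rdp` and `get_privacy_spent`; the hyperparameter values below are illustrative only, not recommendations.
```
from tensorflow_privacy.privacy.analysis import rdp_accountant

# Illustrative DP-SGD hyperparameters.
n, batch_size = 60000, 256
noise_multiplier, epochs = 1.1, 30
q = batch_size / n                  # Poisson sampling rate.
steps = epochs * n // batch_size    # Number of SGD steps.

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = rdp_accountant.compute_rdp(q, noise_multiplier, steps, orders)
eps, delta, opt_order = rdp_accountant.get_privacy_spent(
    orders, rdp, target_delta=1e-5)
print('epsilon = %.2f at delta = %g (optimal order %g)' % (eps, delta, opt_order))
```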
def compute_rdp_from_ledger(ledger, orders):
"""Computes RDP of Sampled Gaussian Mechanism from ledger.
Args:
ledger: A formatted privacy ledger.
orders: An array (or a scalar) of RDP orders.
Returns:
RDP at all orders. Can be `np.inf`.
"""
total_rdp = np.zeros_like(orders, dtype=float)
for sample in ledger:
# Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
# See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
effective_z = sum([
(q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5
total_rdp += compute_rdp(
sample.selection_probability, effective_z, 1, orders)
return total_rdp

View file

@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import sys
from absl.testing import absltest
from absl.testing import parameterized
from mpmath import exp
from mpmath import inf
@@ -31,7 +30,6 @@ from mpmath import quad
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis import rdp_accountant
@@ -87,9 +85,9 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
steps_list = [1, 1]
orders = 20
self.assertEqual(
rdp_accountant.compute_heterogenous_rdp(sampling_probabilities,
noise_multipliers, steps_list,
orders), 0.1)
rdp_accountant.compute_heterogeneous_rdp(sampling_probabilities,
noise_multipliers, steps_list,
orders), 0.1)
def test_compute_rdp_no_data(self):
# q = 0
@@ -121,16 +119,47 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
[6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02, np.inf],
rtol=1e-4)
params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01},
{'q': 1e-6, 'sigma': .1, 'order': 256},
{'q': 1e-5, 'sigma': .1, 'order': 256.1},
{'q': 1e-6, 'sigma': 1, 'order': 27},
{'q': 1e-4, 'sigma': 1., 'order': 1.5},
{'q': 1e-3, 'sigma': 1., 'order': 2},
{'q': .01, 'sigma': 10, 'order': 20},
{'q': .1, 'sigma': 100, 'order': 20.5},
{'q': .99, 'sigma': .1, 'order': 256},
{'q': .999, 'sigma': 100, 'order': 256.1})
params = ({
'q': 1e-7,
'sigma': .1,
'order': 1.01
}, {
'q': 1e-6,
'sigma': .1,
'order': 256
}, {
'q': 1e-5,
'sigma': .1,
'order': 256.1
}, {
'q': 1e-6,
'sigma': 1,
'order': 27
}, {
'q': 1e-4,
'sigma': 1.,
'order': 1.5
}, {
'q': 1e-3,
'sigma': 1.,
'order': 2
}, {
'q': .01,
'sigma': 10,
'order': 20
}, {
'q': .1,
'sigma': 100,
'order': 20.5
}, {
'q': .99,
'sigma': .1,
'order': 256
}, {
'q': .999,
'sigma': 100,
'order': 256.1
})
# pylint:disable=undefined-variable
@parameterized.parameters(p for p in params)
@@ -152,7 +181,8 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(eps, 1.32783806176)
# Second test for Gaussian noise (with no subsampling):
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of orders.
orders = [0.001 * i for i in range(1000, 100000)
] # Pick fine set of orders.
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
# Scale is chosen to obtain exactly (1,1e-6)-DP.
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
@@ -168,7 +198,7 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
self.assertAlmostEqual(delta, 1e-5)
# Second test for Gaussian noise (with no subsampling):
orders = [0.001*i for i in range(1000, 100000)] # Pick fine set of order.
orders = [0.001 * i for i in range(1000, 100000)] # Pick fine set of order.
rdp = rdp_accountant.compute_rdp(1, 4.530877117, 1, orders)
# Scale is chosen to obtain exactly (1,1e-6)-DP.
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_eps=1)
@@ -178,17 +208,13 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14.,
16., 20., 24., 28., 32., 64., 256.)
rdp = rdp_accountant.compute_rdp(q=1e-4,
noise_multiplier=.4,
steps=40000,
orders=orders)
rdp = rdp_accountant.compute_rdp(
q=1e-4, noise_multiplier=.4, steps=40000, orders=orders)
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)
rdp += rdp_accountant.compute_rdp(q=0.1,
noise_multiplier=2,
steps=100,
orders=orders)
rdp += rdp_accountant.compute_rdp(
q=0.1, noise_multiplier=2, steps=100, orders=orders)
eps, _, _ = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-5)
# These tests use the old RDP -> approx DP conversion
# self.assertAlmostEqual(eps, 8.509656, places=5)
@@ -217,43 +243,26 @@ class TestGaussianMoments(tf.test.TestCase, parameterized.TestCase):
def test_get_privacy_spent_gaussian(self):
# Compare the optimal bound for Gaussian with the one derived from RDP.
# Also compare the RDP upper bound with the "standard" upper bound.
orders = [0.1*x for x in range(10, 505)]
eps_vec = [0.1*x for x in range(500)]
orders = [0.1 * x for x in range(10, 505)]
eps_vec = [0.1 * x for x in range(500)]
rdp = rdp_accountant.compute_rdp(1, 1, 1, orders)
for eps in eps_vec:
_, delta, _ = rdp_accountant.get_privacy_spent(orders, rdp,
target_eps=eps)
_, delta, _ = rdp_accountant.get_privacy_spent(
orders, rdp, target_eps=eps)
# For comparison, we compute the optimal guarantee for Gaussian
# using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2).
delta0 = math.erfc((eps-.5)/math.sqrt(2))/2
delta0 = delta0 - math.exp(eps)*math.erfc((eps+.5)/math.sqrt(2))/2
self.assertLessEqual(delta0, delta+1e-300) # need tolerance 10^-300
delta0 = math.erfc((eps - .5) / math.sqrt(2)) / 2
delta0 = delta0 - math.exp(eps) * math.erfc((eps + .5) / math.sqrt(2)) / 2
self.assertLessEqual(delta0, delta + 1e-300) # need tolerance 10^-300
# Compute the "standard" upper bound, which should be an upper bound.
# Note, if orders is too sparse, this will NOT be an upper bound.
if eps >= 0.5:
delta1 = math.exp(-0.5*(eps-0.5)**2)
delta1 = math.exp(-0.5 * (eps - 0.5)**2)
else:
delta1 = 1
self.assertLessEqual(delta, delta1+1e-300)
def test_compute_rdp_from_ledger(self):
orders = range(2, 33)
q = 0.1
n = 1000
l2_norm_clip = 3.14159
noise_stddev = 2.71828
steps = 3
query_entry = privacy_ledger.GaussianSumQueryEntry(
l2_norm_clip, noise_stddev)
ledger = [privacy_ledger.SampleEntry(n, q, [query_entry])] * steps
z = noise_stddev / l2_norm_clip
rdp = rdp_accountant.compute_rdp(q, z, steps, orders)
rdp_from_ledger = rdp_accountant.compute_rdp_from_ledger(ledger, orders)
self.assertSequenceAlmostEqual(rdp, rdp_from_ledger)
self.assertLessEqual(delta, delta1 + 1e-300)
if __name__ == '__main__':
absltest.main()
tf.test.main()

View file

@@ -0,0 +1,614 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Privacy accountant that uses Renyi differential privacy."""
import math
from typing import Collection, Optional
import numpy as np
from scipy import special
import six
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import privacy_accountant
NeighborRel = privacy_accountant.NeighboringRelation
def _log_add(logx, logy):
"""Adds two numbers in the log space."""
a, b = min(logx, logy), max(logx, logy)
if a == -np.inf: # adding 0
return b
# Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
def _log_sub(logx, logy):
"""Subtracts two numbers in the log space. Answer must be non-negative."""
if logx < logy:
raise ValueError('The result of subtraction must be non-negative.')
if logy == -np.inf: # subtracting 0
return logx
if logx == logy:
return -np.inf # 0 is represented as -np.inf in the log space.
try:
# Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
except OverflowError:
return logx
def _log_sub_sign(logx, logy):
"""Returns log(exp(logx)-exp(logy)) and its sign."""
if logx > logy:
s = True
mag = logx + np.log(1 - np.exp(logy - logx))
elif logx < logy:
s = False
mag = logy + np.log(1 - np.exp(logx - logy))
else:
s = True
mag = -np.inf
return s, mag
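A few hedged numeric sanity checks for the log-space helpers above, assuming they are imported from this module (`rdp_privacy_accountant`):
```
import math
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant as rpa

assert math.isclose(rpa._log_add(math.log(2.), math.log(3.)), math.log(5.))
assert math.isclose(
    rpa._log_sub(math.log(3.), math.log(2.)), math.log(1.), abs_tol=1e-12)
sign, mag = rpa._log_sub_sign(math.log(2.), math.log(3.))  # represents 2 - 3
assert not sign and math.isclose(mag, 0., abs_tol=1e-12)   # |2 - 3| = 1, log 1 = 0
```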
def _log_comb(n, k):
"""Computes log of binomial coefficient."""
return (special.gammaln(n + 1) - special.gammaln(k + 1) -
special.gammaln(n - k + 1))
def _compute_log_a_int(q, sigma, alpha):
"""Computes log(A_alpha) for integer alpha, 0 < q < 1."""
assert isinstance(alpha, six.integer_types)
# Initialize with 0 in the log space.
log_a = -np.inf
for i in range(alpha + 1):
log_coef_i = (
_log_comb(alpha, i) + i * math.log(q) + (alpha - i) * math.log(1 - q))
s = log_coef_i + (i * i - i) / (2 * (sigma**2))
log_a = _log_add(log_a, s)
return float(log_a)
def _compute_log_a_frac(q, sigma, alpha):
"""Computes log(A_alpha) for fractional alpha, 0 < q < 1."""
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
# initialized to 0 in the log space:
log_a0, log_a1 = -np.inf, -np.inf
i = 0
z0 = sigma**2 * math.log(1 / q - 1) + .5
while True: # do ... until loop
coef = special.binom(alpha, i)
log_coef = math.log(abs(coef))
j = alpha - i
log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1
if coef > 0:
log_a0 = _log_add(log_a0, log_s0)
log_a1 = _log_add(log_a1, log_s1)
else:
log_a0 = _log_sub(log_a0, log_s0)
log_a1 = _log_sub(log_a1, log_s1)
i += 1
if max(log_s0, log_s1) < -30:
break
return _log_add(log_a0, log_a1)
def _log_erfc(x):
"""Computes log(erfc(x)) with high accuracy for large x."""
try:
return math.log(2) + special.log_ndtr(-x * 2**.5)
except NameError:
# If log_ndtr is not available, approximate as follows:
r = special.erfc(x)
if r == 0.0:
# Using the Laurent series at infinity for the tail of the erfc function:
# erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
# To verify in Mathematica:
# Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
.625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
else:
return math.log(r)
def _compute_delta(orders, rdp, epsilon):
"""Compute delta given a list of RDP values and target epsilon.
Args:
orders: An array of orders.
rdp: An array of RDP guarantees.
epsilon: The target epsilon.
Returns:
Optimal delta.
Raises:
ValueError: If input is malformed.
"""
if epsilon < 0:
raise ValueError(f'Epsilon cannot be negative. Found {epsilon}.')
if len(orders) != len(rdp):
raise ValueError('Input lists must have the same length.')
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
# delta = min( np.exp((rdp - epsilon) * (orders - 1)) )
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4):
logdeltas = [] # work in log space to avoid overflows
for (a, r) in zip(orders, rdp):
if a < 1:
raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.')
if r < 0:
raise ValueError(f'Renyi divergence cannot be negative. Found {r}.')
# For small alpha, we are better off with the bound via KL divergence:
# delta <= sqrt(1-exp(-KL)).
# Take a min of the two bounds.
if r == 0:
logdelta = -np.inf
else:
logdelta = 0.5 * math.log1p(-math.exp(-r))
if a > 1.01:
# This bound is not numerically stable as alpha->1.
# Thus we have a min value for alpha.
# The bound is also not useful for small alpha, so doesn't matter.
rdp_bound = (a - 1) * (r - epsilon + math.log1p(-1 / a)) - math.log(a)
logdelta = min(logdelta, rdp_bound)
logdeltas.append(logdelta)
return min(math.exp(np.min(logdeltas)), 1.)
def _compute_epsilon(orders, rdp, delta):
"""Compute epsilon given a list of RDP values and target delta.
Args:
orders: An array of orders.
rdp: An array of RDP guarantees.
delta: The target delta. Must be >= 0.
Returns:
Optimal epsilon.
Raises:
ValueError: If input is malformed.
"""
if delta < 0:
raise ValueError(f'Delta cannot be negative. Found {delta}.')
if delta == 0:
if all(r == 0 for r in rdp):
return 0
else:
return np.inf
if len(orders) != len(rdp):
raise ValueError('Input lists must have the same length.')
# Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
# epsilon = min( rdp - math.log(delta) / (orders - 1) )
# Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4).
# Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
eps = []
for (a, r) in zip(orders, rdp):
if a < 1:
raise ValueError(f'Renyi divergence order must be at least 1. Found {a}.')
if r < 0:
raise ValueError(f'Renyi divergence cannot be negative. Found {r}.')
if delta**2 + math.expm1(-r) > 0:
# In this case, we can simply bound via KL divergence:
# delta <= sqrt(1-exp(-KL)).
epsilon = 0 # No need to try further computation if we have epsilon = 0.
elif a > 1.01:
# This bound is not numerically stable as alpha->1.
# Thus we have a min value of alpha.
# The bound is also not useful for small alpha, so doesn't matter.
epsilon = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1)
else:
# In this case we can't do anything. E.g., asking for delta = 0.
epsilon = np.inf
eps.append(epsilon)
return max(0, np.min(eps))
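As a hedged sketch, the two conversions above round-trip for a constant RDP curve (which corresponds to pure DP), mirroring `test_compute_epsilon_delta_pure_dp` in the tests below:
```
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant as rpa

orders = list(range(2, 33))
rdp = [1.1] * len(orders)      # A constant RDP curve corresponds to pure DP.
eps = rpa._compute_epsilon(orders, rdp, delta=1e-5)    # ~1.3278
delta = rpa._compute_delta(orders, rdp, epsilon=eps)   # ~1e-5 again
```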
def _stable_inplace_diff_in_log(vec, signs, n=-1):
"""Replaces the first n-1 dims of vec with the log of abs difference operator.
Args:
vec: numpy array of floats with size larger than 'n'
signs: Optional numpy array of bools with the same size as vec in case one
needs to compute partial differences vec and signs jointly describe a
vector of real numbers' sign and abs in log scale.
n: Optonal upper bound on number of differences to compute. If negative, all
differences are computed.
Returns:
The first n-1 dimension of vec and signs will store the log-abs and sign of
the difference.
Raises:
ValueError: If input is malformed.
"""
assert vec.shape == signs.shape
if n < 0:
n = np.max(vec.shape) - 1
else:
assert np.max(vec.shape) >= n + 1
for j in range(0, n, 1):
if signs[j] == signs[j + 1]: # When the signs are the same
# if the signs are both positive, then we can just use the standard one
signs[j], vec[j] = _log_sub_sign(vec[j + 1], vec[j])
# otherwise, we do that but toggle the sign
if not signs[j + 1]:
signs[j] = ~signs[j]
else: # When the signs are different.
vec[j] = _log_add(vec[j], vec[j + 1])
signs[j] = signs[j + 1]
def _get_forward_diffs(fun, n):
"""Computes up to nth order forward difference evaluated at 0.
See Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf
Args:
fun: Function to compute forward differences of.
n: Number of differences to compute.
Returns:
Pair (deltas, signs_deltas) of the log deltas and their signs.
"""
func_vec = np.zeros(n + 3)
signs_func_vec = np.ones(n + 3, dtype=bool)
# ith coordinate of deltas stores log(abs(ith order discrete derivative))
deltas = np.zeros(n + 2)
signs_deltas = np.zeros(n + 2, dtype=bool)
for i in range(1, n + 3, 1):
func_vec[i] = fun(1.0 * (i - 1))
for i in range(0, n + 2, 1):
# Diff in log scale
_stable_inplace_diff_in_log(func_vec, signs_func_vec, n=n + 2 - i)
deltas[i] = func_vec[0]
signs_deltas[i] = signs_func_vec[0]
return deltas, signs_deltas
def _compute_log_a(q, noise_multiplier, alpha):
if float(alpha).is_integer():
return _compute_log_a_int(q, noise_multiplier, int(alpha))
else:
return _compute_log_a_frac(q, noise_multiplier, alpha)
def _compute_rdp_poisson_subsampled_gaussian(q, noise_multiplier, orders):
"""Computes RDP of the Poisson sampled Gaussian mechanism.
Args:
q: The sampling rate.
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
to the l2-sensitivity of the function to which it is added.
orders: An array of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
def compute_one_order(q, alpha):
if np.isinf(alpha) or noise_multiplier == 0:
return np.inf
if q == 0:
return 0
if q == 1.:
return alpha / (2 * noise_multiplier**2)
return _compute_log_a(q, noise_multiplier, alpha) / (alpha - 1)
return np.array([compute_one_order(q, order) for order in orders])
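Two hedged end-point checks for the helper above: with `q = 1` it reduces to the plain Gaussian mechanism, and with `q = 0` the RDP is identically zero (values are illustrative):
```
import numpy as np
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant as rpa

orders = np.array([2., 8., 32.])
sigma = 2.5
np.testing.assert_allclose(
    rpa._compute_rdp_poisson_subsampled_gaussian(1., sigma, orders),
    orders / (2 * sigma**2))
assert np.all(rpa._compute_rdp_poisson_subsampled_gaussian(0., sigma, orders) == 0)
```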
def _compute_rdp_sample_wor_gaussian(q, noise_multiplier, orders):
"""Computes RDP of Gaussian mechanism using sampling without replacement.
This function applies to the following schemes:
1. Sampling w/o replacement: Sample a uniformly random subset of size m = q*n.
2. ``Replace one data point'' version of differential privacy, i.e., n is
considered public information.
Reference: Theorem 27 of https://arxiv.org/pdf/1808.00087.pdf (A strengthened
version applies to the subsampled Gaussian mechanism.)
- Wang, Balle, Kasiviswanathan. "Subsampled Renyi Differential Privacy and
Analytical Moments Accountant." AISTATS'2019.
Args:
q: The sampling proportion = m / n. Assume m is an integer <= n.
noise_multiplier: The ratio of the standard deviation of the Gaussian noise
to the l2-sensitivity of the function to which it is added.
orders: An array of RDP orders.
Returns:
The RDPs at all orders, can be np.inf.
"""
return np.array([
_compute_rdp_sample_wor_gaussian_scalar(q, noise_multiplier, order)
for order in orders
])
def _compute_rdp_sample_wor_gaussian_scalar(q, sigma, alpha):
"""Compute RDP of the Sampled Gaussian mechanism at order alpha.
Args:
q: The sampling proportion = m / n. Assume m is an integer <= n.
sigma: The std of the additive Gaussian noise.
alpha: The order at which RDP is computed.
Returns:
RDP at alpha, can be np.inf.
"""
assert (q <= 1) and (q >= 0) and (alpha >= 1)
if q == 0:
return 0
if q == 1.:
return alpha / (2 * sigma**2)
if np.isinf(alpha):
return np.inf
if float(alpha).is_integer():
return _compute_rdp_sample_wor_gaussian_int(q, sigma, int(alpha)) / (
alpha - 1)
else:
# When alpha is not an integer, we apply Corollary 10 of [WBK19] to interpolate
# the CGF and obtain an upper bound
alpha_f = math.floor(alpha)
alpha_c = math.ceil(alpha)
x = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_f)
y = _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha_c)
t = alpha - alpha_f
return ((1 - t) * x + t * y) / (alpha - 1)
def _compute_rdp_sample_wor_gaussian_int(q, sigma, alpha):
"""Compute log(A_alpha) for integer alpha, subsampling without replacement.
When alpha is at most max_alpha, compute the bound from Theorem 27 exactly;
otherwise compute the bound with the Stirling approximation.
Args:
q: The sampling proportion = m / n. Assume m is an integer <= n.
sigma: The std of the additive Gaussian noise.
alpha: The order at which RDP is computed.
Returns:
RDP at alpha, can be np.inf.
"""
max_alpha = 256
assert isinstance(alpha, six.integer_types)
if np.isinf(alpha):
return np.inf
elif alpha == 1:
return 0
def cgf(x):
# Return rdp(x+1)*x, the rdp of Gaussian mechanism is alpha/(2*sigma**2)
return x * 1.0 * (x + 1) / (2.0 * sigma**2)
def func(x):
# Return the rdp of Gaussian mechanism
return 1.0 * x / (2.0 * sigma**2)
# Initialize with 1 in the log space.
log_a = 0
# Calculates the log term when alpha = 2
log_f2m1 = func(2.0) + np.log(1 - np.exp(-func(2.0)))
if alpha <= max_alpha:
# We need forward differences of exp(cgf)
# The following line is the numerically stable way of implementing it.
# The output is in polar form with logarithmic magnitude
deltas, _ = _get_forward_diffs(cgf, alpha)
# Computing the bound exactly requires bookkeeping of O(alpha**2)
for i in range(2, alpha + 1):
if i == 2:
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
np.log(4) + log_f2m1,
func(2.0) + np.log(2))
elif i > 2:
delta_lo = deltas[int(2 * np.floor(i / 2.0)) - 1]
delta_hi = deltas[int(2 * np.ceil(i / 2.0)) - 1]
s = np.log(4) + 0.5 * (delta_lo + delta_hi)
s = np.minimum(s, np.log(2) + cgf(i - 1))
s += i * np.log(q) + _log_comb(alpha, i)
log_a = _log_add(log_a, s)
return float(log_a)
else:
# Compute the bound with the Stirling approximation. Everything is O(x) now.
for i in range(2, alpha + 1):
if i == 2:
s = 2 * np.log(q) + _log_comb(alpha, 2) + np.minimum(
np.log(4) + log_f2m1,
func(2.0) + np.log(2))
else:
s = np.log(2) + cgf(i - 1) + i * np.log(q) + _log_comb(alpha, i)
log_a = _log_add(log_a, s)
return log_a
def _effective_gaussian_noise_multiplier(event: dp_event.DpEvent):
"""Determines the effective noise multiplier of nested structure of Gaussians.
A series of Gaussian queries on the same data can be reexpressed as a single
query with pre- and post- processing. For details, see section 3 of
https://arxiv.org/pdf/1812.06210.pdf.
Args:
event: A `dp_event.DpEvent`. In order for conversion to be successful it
must consist of a single `dp_event.GaussianDpEvent`, or a nested structure
of `dp_event.ComposedDpEvent` and/or `dp_event.SelfComposedDpEvent`
bottoming out in `dp_event.GaussianDpEvent`s.
Returns:
The noise multiplier of the equivalent `dp_event.GaussianDpEvent`, or None
if the input event was not a `dp_event.GaussianDpEvent` or a nested
structure of `dp_event.ComposedDpEvent` and/or
`dp_event.SelfComposedDpEvent` bottoming out in `dp_event.GaussianDpEvent`s.
"""
if isinstance(event, dp_event.GaussianDpEvent):
return event.noise_multiplier
elif isinstance(event, dp_event.ComposedDpEvent):
sum_sigma_inv_sq = 0
for e in event.events:
sigma = _effective_gaussian_noise_multiplier(e)
if sigma is None:
return None
sum_sigma_inv_sq += sigma**-2
return sum_sigma_inv_sq**-0.5
elif isinstance(event, dp_event.SelfComposedDpEvent):
sigma = _effective_gaussian_noise_multiplier(event.event)
return None if sigma is None else (event.count * sigma**-2)**-0.5
else:
return None
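A hedged sketch of the reduction this helper performs: one Gaussian with noise multiplier 1.0 composed with three applications of a Gaussian with noise multiplier 2.0 behaves like a single Gaussian with multiplier `(1**-2 + 3 * 2**-2)**-0.5`. The values are illustrative only.
```
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant as rpa

event = dp_event.ComposedDpEvent([
    dp_event.GaussianDpEvent(1.0),
    dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(2.0), 3),
])
sigma = rpa._effective_gaussian_noise_multiplier(event)
# (1.0**-2 + 3 * 2.0**-2)**-0.5 == 1.75**-0.5 ~= 0.756
```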
class RdpAccountant(privacy_accountant.PrivacyAccountant):
"""Privacy accountant that uses Renyi differential privacy."""
def __init__(
self,
orders: Optional[Collection[float]] = None,
neighboring_relation: NeighborRel = NeighborRel.ADD_OR_REMOVE_ONE,
):
super(RdpAccountant, self).__init__(neighboring_relation)
if orders is None:
# Default orders chosen to give good coverage for Gaussian mechanism in
# the privacy regime of interest. In the future, more orders might be
# added, in particular, fractional orders between 1.0 and 10.0 or so.
orders = [
2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 24, 28, 32, 48, 64, 128,
256, 512, 1024
]
self._orders = np.array(orders)
self._rdp = np.zeros_like(orders, dtype=np.float64)
def supports(self, event: dp_event.DpEvent) -> bool:
return self._maybe_compose(event, 0, False)
def _compose(self, event: dp_event.DpEvent, count: int = 1):
self._maybe_compose(event, count, True)
def _maybe_compose(self, event: dp_event.DpEvent, count: int,
do_compose: bool) -> bool:
"""Traverses `event` and performs composition if `do_compose` is True.
If `do_compose` is False, can be used to check whether composition is
supported.
Args:
event: A `DpEvent` to process.
count: The number of times to compose the event.
do_compose: Whether to actually perform the composition.
Returns:
True if event is supported, otherwise False.
"""
if isinstance(event, dp_event.NoOpDpEvent):
return True
elif isinstance(event, dp_event.NonPrivateDpEvent):
if do_compose:
self._rdp += np.inf
return True
elif isinstance(event, dp_event.SelfComposedDpEvent):
return self._maybe_compose(event.event, event.count * count, do_compose)
elif isinstance(event, dp_event.ComposedDpEvent):
return all(
self._maybe_compose(e, count, do_compose) for e in event.events)
elif isinstance(event, dp_event.GaussianDpEvent):
if do_compose:
self._rdp += count * _compute_rdp_poisson_subsampled_gaussian(
q=1.0, noise_multiplier=event.noise_multiplier, orders=self._orders)
return True
elif isinstance(event, dp_event.PoissonSampledDpEvent):
if self._neighboring_relation is not NeighborRel.ADD_OR_REMOVE_ONE:
return False
gaussian_noise_multiplier = _effective_gaussian_noise_multiplier(
event.event)
if gaussian_noise_multiplier is None:
return False
if do_compose:
self._rdp += count * _compute_rdp_poisson_subsampled_gaussian(
q=event.sampling_probability,
noise_multiplier=gaussian_noise_multiplier,
orders=self._orders)
return True
elif isinstance(event, dp_event.SampledWithoutReplacementDpEvent):
if self._neighboring_relation is not NeighborRel.REPLACE_ONE:
return False
gaussian_noise_multiplier = _effective_gaussian_noise_multiplier(
event.event)
if gaussian_noise_multiplier is None:
return False
if do_compose:
self._rdp += count * _compute_rdp_sample_wor_gaussian(
q=event.sample_size / event.source_dataset_size,
noise_multiplier=gaussian_noise_multiplier,
orders=self._orders)
return True
else:
# Unsupported event (including `UnsupportedDpEvent`).
return False
def get_epsilon(self, target_delta: float) -> float:
return _compute_epsilon(self._orders, self._rdp, target_delta)
def get_delta(self, target_epsilon: float) -> float:
return _compute_delta(self._orders, self._rdp, target_epsilon)
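For completeness, a hedged end-to-end sketch of using the accountant for DP-SGD-style events; the event structure mirrors the tests below and the numbers are illustrative.
```
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant

# One DP-SGD step: Poisson sampling at rate q followed by Gaussian noise.
step = dp_event.PoissonSampledDpEvent(0.01, dp_event.GaussianDpEvent(1.1))

accountant = rdp_privacy_accountant.RdpAccountant()
accountant.compose(step, 10000)              # 10000 such steps.
print(accountant.get_epsilon(target_delta=1e-5))
print(accountant.get_delta(target_epsilon=1.0))
```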

View file

@@ -0,0 +1,355 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_privacy_accountant."""
import math
import sys
from absl.testing import absltest
from absl.testing import parameterized
import mpmath
import numpy as np
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.analysis import privacy_accountant
from tensorflow_privacy.privacy.analysis import privacy_accountant_test
from tensorflow_privacy.privacy.analysis import rdp_privacy_accountant
def _get_test_rdp(event, count=1):
accountant = rdp_privacy_accountant.RdpAccountant(orders=[2.71828])
accountant.compose(event, count)
return accountant._rdp[0]
def _log_float_mp(x):
# Convert multi-precision input to float log space.
if x >= sys.float_info.min:
return float(mpmath.log(x))
else:
return -np.inf
def _compute_a_mp(sigma, q, alpha):
"""Compute A_alpha for arbitrary alpha by numerical integration."""
def mu0(x):
return mpmath.npdf(x, mu=0, sigma=sigma)
def _mu_over_mu0(x, q, sigma):
return (1 - q) + q * mpmath.exp((2 * x - 1) / (2 * sigma**2))
def a_alpha_fn(z):
return mu0(z) * _mu_over_mu0(z, q, sigma)**alpha
bounds = (-mpmath.inf, mpmath.inf)
a_alpha, _ = mpmath.quad(a_alpha_fn, bounds, error=True, maxdegree=8)
return a_alpha
class RdpPrivacyAccountantTest(privacy_accountant_test.PrivacyAccountantTest,
parameterized.TestCase):
def _make_test_accountants(self):
return [
rdp_privacy_accountant.RdpAccountant(
[2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE),
rdp_privacy_accountant.RdpAccountant(
[2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE)
]
def test_supports(self):
aor_accountant = rdp_privacy_accountant.RdpAccountant(
[2.0], privacy_accountant.NeighboringRelation.ADD_OR_REMOVE_ONE)
ro_accountant = rdp_privacy_accountant.RdpAccountant(
[2.0], privacy_accountant.NeighboringRelation.REPLACE_ONE)
event = dp_event.GaussianDpEvent(1.0)
self.assertTrue(aor_accountant.supports(event))
self.assertTrue(ro_accountant.supports(event))
event = dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(1.0), 6)
self.assertTrue(aor_accountant.supports(event))
self.assertTrue(ro_accountant.supports(event))
event = dp_event.ComposedDpEvent(
[dp_event.GaussianDpEvent(1.0),
dp_event.GaussianDpEvent(2.0)])
self.assertTrue(aor_accountant.supports(event))
self.assertTrue(ro_accountant.supports(event))
event = dp_event.PoissonSampledDpEvent(0.1, dp_event.GaussianDpEvent(1.0))
self.assertTrue(aor_accountant.supports(event))
self.assertFalse(ro_accountant.supports(event))
composed_gaussian = dp_event.ComposedDpEvent(
[dp_event.GaussianDpEvent(1.0),
dp_event.GaussianDpEvent(2.0)])
event = dp_event.PoissonSampledDpEvent(0.1, composed_gaussian)
self.assertTrue(aor_accountant.supports(event))
self.assertFalse(ro_accountant.supports(event))
event = dp_event.SampledWithoutReplacementDpEvent(
1000, 10, dp_event.GaussianDpEvent(1.0))
self.assertFalse(aor_accountant.supports(event))
self.assertTrue(ro_accountant.supports(event))
event = dp_event.SampledWithoutReplacementDpEvent(1000, 10,
composed_gaussian)
self.assertFalse(aor_accountant.supports(event))
self.assertTrue(ro_accountant.supports(event))
event = dp_event.SampledWithReplacementDpEvent(
1000, 10, dp_event.GaussianDpEvent(1.0))
self.assertFalse(aor_accountant.supports(event))
self.assertFalse(ro_accountant.supports(event))
def test_rdp_composition(self):
base_event = dp_event.GaussianDpEvent(3.14159)
base_rdp = _get_test_rdp(base_event)
rdp_with_count = _get_test_rdp(base_event, count=6)
self.assertAlmostEqual(rdp_with_count, base_rdp * 6)
rdp_with_self_compose = _get_test_rdp(
dp_event.SelfComposedDpEvent(base_event, 6))
self.assertAlmostEqual(rdp_with_self_compose, base_rdp * 6)
rdp_with_self_compose_and_count = _get_test_rdp(
dp_event.SelfComposedDpEvent(base_event, 2), count=3)
self.assertAlmostEqual(rdp_with_self_compose_and_count, base_rdp * 6)
rdp_with_compose = _get_test_rdp(dp_event.ComposedDpEvent([base_event] * 6))
self.assertAlmostEqual(rdp_with_compose, base_rdp * 6)
rdp_with_compose_and_self_compose = _get_test_rdp(
dp_event.ComposedDpEvent([
dp_event.SelfComposedDpEvent(base_event, 1),
dp_event.SelfComposedDpEvent(base_event, 2),
dp_event.SelfComposedDpEvent(base_event, 3)
]))
self.assertAlmostEqual(rdp_with_compose_and_self_compose, base_rdp * 6)
base_event_2 = dp_event.GaussianDpEvent(1.61803)
base_rdp_2 = _get_test_rdp(base_event_2)
rdp_with_heterogeneous_compose = _get_test_rdp(
dp_event.ComposedDpEvent([base_event, base_event_2]))
self.assertAlmostEqual(rdp_with_heterogeneous_compose,
base_rdp + base_rdp_2)
def test_zero_poisson_sample(self):
accountant = rdp_privacy_accountant.RdpAccountant([3.14159])
accountant.compose(
dp_event.PoissonSampledDpEvent(0, dp_event.GaussianDpEvent(1.0)))
self.assertEqual(accountant.get_epsilon(1e-10), 0)
self.assertEqual(accountant.get_delta(1e-10), 0)
def test_zero_fixed_batch_sample(self):
accountant = rdp_privacy_accountant.RdpAccountant(
[3.14159], privacy_accountant.NeighboringRelation.REPLACE_ONE)
accountant.compose(
dp_event.SampledWithoutReplacementDpEvent(
1000, 0, dp_event.GaussianDpEvent(1.0)))
self.assertEqual(accountant.get_epsilon(1e-10), 0)
self.assertEqual(accountant.get_delta(1e-10), 0)
def test_epsilon_non_private_gaussian(self):
accountant = rdp_privacy_accountant.RdpAccountant([3.14159])
accountant.compose(dp_event.GaussianDpEvent(0))
self.assertEqual(accountant.get_epsilon(1e-1), np.inf)
def test_compute_rdp_gaussian(self):
alpha = 3.14159
sigma = 2.71828
event = dp_event.GaussianDpEvent(sigma)
accountant = rdp_privacy_accountant.RdpAccountant(orders=[alpha])
accountant.compose(event)
self.assertAlmostEqual(accountant._rdp[0], alpha / (2 * sigma**2))
def test_compute_rdp_multi_gaussian(self):
alpha = 3.14159
sigma1, sigma2 = 2.71828, 6.28319
rdp1 = alpha / (2 * sigma1**2)
rdp2 = alpha / (2 * sigma2**2)
rdp = rdp1 + rdp2
accountant = rdp_privacy_accountant.RdpAccountant(orders=[alpha])
accountant.compose(
dp_event.PoissonSampledDpEvent(
1.0,
dp_event.ComposedDpEvent([
dp_event.GaussianDpEvent(sigma1),
dp_event.GaussianDpEvent(sigma2)
])))
self.assertAlmostEqual(accountant._rdp[0], rdp)
def test_effective_gaussian_noise_multiplier(self):
np.random.seed(0xBAD5EED)
sigmas = np.random.uniform(size=(4,))
event = dp_event.ComposedDpEvent([
dp_event.GaussianDpEvent(sigmas[0]),
dp_event.SelfComposedDpEvent(dp_event.GaussianDpEvent(sigmas[1]), 3),
dp_event.ComposedDpEvent([
dp_event.GaussianDpEvent(sigmas[2]),
dp_event.GaussianDpEvent(sigmas[3])
])
])
sigma = rdp_privacy_accountant._effective_gaussian_noise_multiplier(event)
multi_sigmas = list(sigmas) + [sigmas[1]] * 2
expected = sum(s**-2 for s in multi_sigmas)**-0.5
self.assertAlmostEqual(sigma, expected)
def test_compute_rdp_poisson_sampled_gaussian(self):
orders = [1.5, 2.5, 5, 50, 100, np.inf]
noise_multiplier = 2.5
sampling_probability = 0.01
count = 50
event = dp_event.SelfComposedDpEvent(
dp_event.PoissonSampledDpEvent(
sampling_probability, dp_event.GaussianDpEvent(noise_multiplier)),
count)
accountant = rdp_privacy_accountant.RdpAccountant(orders=orders)
accountant.compose(event)
self.assertTrue(
np.allclose(
accountant._rdp, [
6.5007e-04, 1.0854e-03, 2.1808e-03, 2.3846e-02, 1.6742e+02,
np.inf
],
rtol=1e-4))
def test_compute_epsilon_delta_pure_dp(self):
orders = range(2, 33)
rdp = [1.1 for o in orders] # Constant corresponds to pure DP.
epsilon = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-5)
# Compare with epsilon computed by hand.
self.assertAlmostEqual(epsilon, 1.32783806176)
delta = rdp_privacy_accountant._compute_delta(
orders, rdp, epsilon=1.32783806176)
self.assertAlmostEqual(delta, 1e-5)
def test_compute_epsilon_delta_gaussian(self):
orders = [0.001 * i for i in range(1000, 100000)]
# noise multiplier is chosen to obtain exactly (1,1e-6)-DP.
rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian(
1, 4.530877117, orders)
eps = rdp_privacy_accountant._compute_epsilon(orders, rdp, delta=1e-6)
self.assertAlmostEqual(eps, 1)
delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=1)
self.assertAlmostEqual(delta, 1e-6)
params = ({
'q': 1e-7,
'sigma': .1,
'order': 1.01
}, {
'q': 1e-6,
'sigma': .1,
'order': 256
}, {
'q': 1e-5,
'sigma': .1,
'order': 256.1
}, {
'q': 1e-6,
'sigma': 1,
'order': 27
}, {
'q': 1e-4,
'sigma': 1.,
'order': 1.5
}, {
'q': 1e-3,
'sigma': 1.,
'order': 2
}, {
'q': .01,
'sigma': 10,
'order': 20
}, {
'q': .1,
'sigma': 100,
'order': 20.5
}, {
'q': .99,
'sigma': .1,
'order': 256
}, {
'q': .999,
'sigma': 100,
'order': 256.1
})
# pylint:disable=undefined-variable
@parameterized.parameters(p for p in params)
def test_compute_log_a_equals_mp(self, q, sigma, order):
# Compare the cheap computation of log(A) with an expensive, multi-precision
# computation.
log_a = rdp_privacy_accountant._compute_log_a(q, sigma, order)
log_a_mp = _log_float_mp(_compute_a_mp(sigma, q, order))
np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4)
def test_delta_bounds_gaussian(self):
# Compare the optimal bound for Gaussian with the one derived from RDP.
# Also compare the RDP upper bound with the "standard" upper bound.
orders = [0.1 * x for x in range(10, 505)]
eps_vec = [0.1 * x for x in range(500)]
rdp = rdp_privacy_accountant._compute_rdp_poisson_subsampled_gaussian(
1, 1, orders)
for eps in eps_vec:
delta = rdp_privacy_accountant._compute_delta(orders, rdp, epsilon=eps)
# For comparison, we compute the optimal guarantee for Gaussian
# using https://arxiv.org/abs/1805.06530 Theorem 8 (in v2).
delta0 = math.erfc((eps - .5) / math.sqrt(2)) / 2
delta0 = delta0 - math.exp(eps) * math.erfc((eps + .5) / math.sqrt(2)) / 2
self.assertLessEqual(delta0, delta + 1e-300) # need tolerance 10^-300
# Compute the "standard" upper bound, which should be an upper bound.
# Note, if orders is too sparse, this will NOT be an upper bound.
if eps >= 0.5:
delta1 = math.exp(-0.5 * (eps - 0.5)**2)
else:
delta1 = 1
self.assertLessEqual(delta, delta1 + 1e-300)
def test_epsilon_delta_consistency(self):
orders = range(2, 50) # Large range of orders (helps test for overflows).
for q in [0, 0.01, 0.1, 0.8, 1.]:
for multiplier in [0.0, 0.1, 1., 10., 100.]:
event = dp_event.PoissonSampledDpEvent(
q, dp_event.GaussianDpEvent(multiplier))
accountant = rdp_privacy_accountant.RdpAccountant(orders)
accountant.compose(event)
for delta in [.99, .9, .1, .01, 1e-3, 1e-5, 1e-9, 1e-12]:
epsilon = accountant.get_epsilon(delta)
delta2 = accountant.get_delta(epsilon)
if np.isposinf(epsilon):
self.assertEqual(delta2, 1.0)
elif epsilon == 0:
self.assertLessEqual(delta2, delta)
else:
self.assertAlmostEqual(delta, delta2)
if __name__ == '__main__':
absltest.main()

View file

@@ -0,0 +1,366 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DP analysis of tree aggregation.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Functionality for computing the differential privacy of tree aggregation of the
Gaussian mechanism. Its public interface consists of the following methods:
compute_rdp_tree_restart(
noise_multiplier: float, steps_list: Union[int, Collection[int]],
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
computes RDP for DP-FTRL-TreeRestart.
compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
computes RDP for DP-FTRL-NoTreeRestart.
For RDP to (epsilon, delta)-DP conversion, use the following public function
described in `rdp_accountant.py`:
get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
(or eps) given RDP at multiple orders and
a target value for eps (or delta).
Example use:
(1) DP-FTRL-TreeRestart RDP:
Suppose we use a Gaussian mechanism with `noise_multiplier`; a sample may appear
at most once per epoch and the tree is restarted every epoch; the number of
leaf nodes for each epoch is tracked in `steps_list`. For `target_delta`, the
estimated epsilon is:
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
(2) DP-FTRL-NoTreeRestart RDP:
Suppose we use a Gaussian mechanism with `noise_multiplier`; a sample may appear
at most `max_participation` times among a total of `total_steps` leaf nodes in a
single tree; there are at least `min_separation` leaf nodes between two
appearances of the same sample. For `target_delta`, the estimated epsilon is:
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = compute_rdp_single_tree(noise_multiplier, total_steps,
max_participation, min_separation, orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
from typing import Collection, Union
import numpy as np
def _compute_rdp_tree_restart(sigma, steps_list, alpha):
"""Computes RDP of the Tree Aggregation Protocol at order alpha."""
if np.isinf(alpha):
return np.inf
tree_depths = [
math.floor(math.log2(float(steps))) + 1
for steps in steps_list
if steps > 0
]
return _compute_gaussian_rdp(
alpha=alpha, sum_sensitivity_square=sum(tree_depths), sigma=sigma)
def compute_rdp_tree_restart(
noise_multiplier: float, steps_list: Union[int, Collection[int]],
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for Gaussian Mechanism.
This function implements the accounting when the tree is restarted at every
epoch. See appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
steps_list: A scalar or a list of non-negative integers representing the
number of steps per epoch (between two restarts).
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
if not steps_list:
raise ValueError(
"steps_list must be a non-empty list, or a non-zero scalar, got "
f"{steps_list}.")
if np.isscalar(steps_list):
steps_list = [steps_list]
for steps in steps_list:
if steps < 0:
raise ValueError(f"Steps must be non-negative, got {steps_list}")
if np.isscalar(orders):
rdp = _compute_rdp_tree_restart(noise_multiplier, steps_list, orders)
else:
rdp = np.array([
_compute_rdp_tree_restart(noise_multiplier, steps_list, alpha)
for alpha in orders
])
return rdp
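A hedged usage sketch of the function above together with the RDP-to-(epsilon, delta) conversion from `rdp_accountant.py`; the module path `tree_aggregation_accountant` and all numbers below are assumptions for illustration.
```
from tensorflow_privacy.privacy.analysis import rdp_accountant
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
# E.g. 5 epochs of 1000 steps each, restarting the tree at every epoch.
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
    noise_multiplier=2.0, steps_list=[1000] * 5, orders=orders)
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=1e-6)[0]
```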
def _check_nonnegative(value: Union[int, float], name: str):
if value < 0:
raise ValueError(f"Provided {name} must be non-negative, got {value}")
def _check_possible_tree_participation(num_participation: int,
min_separation: int, start: int,
end: int, steps: int) -> bool:
"""Check if participation is possible with `min_separation` in `steps`.
This function checks if it is possible for a sample to appear
`num_participation` times in `steps` steps, assuming there are at least
`min_separation` nodes between appearances of the same sample in the streaming
data (leaf nodes in tree aggregation). The first appearance of the sample is
after `start` steps, and the sample won't appear in the `end` steps after the
given `steps`.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive x, y steps in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
end: The sample won't appear in the `end` steps after the given `steps`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
True if a sample can appear `num_participation` times under the given conditions.
"""
return start + (min_separation + 1) * num_participation <= steps + end
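A hedged worked example of the inequality above, assuming the module is importable as `tree_aggregation_accountant`: with `min_separation=1`, `start=0`, `end=1` and `steps=4` leaves, a sample can participate at most twice (e.g. at leaves 0 and 2).
```
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant as taa

assert taa._check_possible_tree_participation(2, 1, 0, 1, 4)      # 0 + 2*2 <= 4 + 1
assert not taa._check_possible_tree_participation(3, 1, 0, 1, 4)  # 0 + 2*3 >  4 + 1
```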
@functools.lru_cache(maxsize=None)
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
start: int, end: int, size: int) -> float:
"""Compute the worst-case sum of sensitivtiy square for `num_participation`.
This is the key algorithm for DP accounting for DP-FTRL tree aggregation
without restart, which recurrently counts the worst-case occurence of a sample
in all the nodes in a tree. This implements a dynamic programming algorithm
that exhausts the possible `num_participation` appearance of a sample in
`size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
num_participation: The number of times a sample will appear.
min_separation: The minimum number of nodes between two appearances of a
sample. If a sample appears in consecutive steps x, y in a streaming
setting, then `min_separation=y-x-1`.
start: The first appearance of the sample is after `start` steps.
    end: The sample won't appear in the `end` steps after the given `size` steps.
size: Total number of steps (leaf nodes in tree aggregation).
Returns:
    The worst-case sum of squared sensitivity for the given input.
"""
if not _check_possible_tree_participation(num_participation, min_separation,
start, end, size):
sum_value = -np.inf
elif num_participation == 0:
sum_value = 0.
elif num_participation == 1 and size == 1:
sum_value = 1.
else:
size_log2 = math.log2(size)
max_2power = math.floor(size_log2)
if max_2power == size_log2:
sum_value = num_participation**2
max_2power -= 1
else:
sum_value = 0.
candidate_sum = []
# i is the `num_participation` in the right subtree
for i in range(num_participation + 1):
# j is the `start` in the right subtree
for j in range(min_separation + 1):
left_sum = _tree_sensitivity_square_sum(
num_participation=num_participation - i,
min_separation=min_separation,
start=start,
end=j,
size=2**max_2power)
if np.isinf(left_sum):
candidate_sum.append(-np.inf)
continue # Early pruning for dynamic programming
right_sum = _tree_sensitivity_square_sum(
num_participation=i,
min_separation=min_separation,
start=j,
end=end,
size=size - 2**max_2power)
candidate_sum.append(left_sum + right_sum)
sum_value += max(candidate_sum)
return sum_value
def _max_tree_sensitivity_square_sum(max_participation: int,
min_separation: int, steps: int) -> float:
"""Compute the worst-case sum of sensitivity square in tree aggregation.
See Appendix D.2 of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
max_participation: The maximum number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears at consecutive steps x and y in a streaming
      setting, then `min_separation=y-x-1`.
steps: Total number of steps (leaf nodes in tree aggregation).
Returns:
    The worst-case sum of squared sensitivity for the given input.
"""
num_participation = max_participation
while not _check_possible_tree_participation(
num_participation, min_separation, 0, min_separation, steps):
num_participation -= 1
candidate_sum = []
for num_part in range(1, num_participation + 1):
candidate_sum.append(
_tree_sensitivity_square_sum(num_part, min_separation, 0,
min_separation, steps))
return max(candidate_sum)
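# Example values for the function above, taken from the accompanying unit
# tests: `max_participation=1, min_separation=7, steps=8` gives 4, and
# `max_participation=3, min_separation=2, steps=7` gives 9.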
def _compute_gaussian_rdp(sigma: float, sum_sensitivity_square: float,
alpha: float) -> float:
"""Computes RDP of Gaussian mechanism."""
if np.isinf(alpha):
return np.inf
return alpha * sum_sensitivity_square / (2 * sigma**2)
def compute_rdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int,
orders: Union[float, Collection[float]]) -> Union[float, Collection[float]]:
"""Computes RDP of the Tree Aggregation Protocol for a single tree.
  The accounting assumes a single tree is constructed for `total_steps` leaf
  nodes, where the same sample appears at most `max_participation` times, and
  there are at least `min_separation` nodes between any two appearances. The
  key idea is to (recursively) count the worst-case occurrences of a sample in
  all the nodes of a tree, using a dynamic programming algorithm that exhausts
  the possible `num_participation` appearances of a sample in `steps` leaf
  nodes.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
total_steps: Total number of steps (leaf nodes in tree aggregation).
max_participation: The maximum number of times a sample can appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears at consecutive steps x and y in a streaming
      setting, then `min_separation=y-x-1`.
orders: An array (or a scalar) of RDP orders.
Returns:
The RDPs at all orders. Can be `np.inf`.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
_check_nonnegative(total_steps, "total_steps")
_check_nonnegative(max_participation, "max_participation")
_check_nonnegative(min_separation, "min_separation")
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
max_participation, min_separation, total_steps)
if np.isscalar(orders):
rdp = _compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square,
orders)
else:
rdp = np.array([
_compute_gaussian_rdp(noise_multiplier, sum_sensitivity_square, alpha)
for alpha in orders
])
return rdp
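# Illustrative usage sketch (assumed example parameter values), not public API.
def _example_single_tree_accounting():
  """Sketch only: RDP of one tree where a sample appears at most twice."""
  # A single tree over 1024 leaf nodes, where the same sample appears at most
  # twice with at least 4 other leaf nodes between its two appearances.
  return compute_rdp_single_tree(
      noise_multiplier=1.0,
      total_steps=1024,
      max_participation=2,
      min_separation=4,
      orders=[2.0, 4.0, 8.0])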
def _compute_gaussian_zcdp(sigma: float,
sum_sensitivity_square: float) -> float:
"""Computes zCDP of Gaussian mechanism."""
return sum_sensitivity_square / (2 * sigma**2)
def compute_zcdp_single_tree(
noise_multiplier: float, total_steps: int, max_participation: int,
min_separation: int) -> Union[float, Collection[float]]:
"""Computes zCDP of the Tree Aggregation Protocol for a single tree.
  The accounting assumes a single tree is constructed for `total_steps` leaf
  nodes, where the same sample appears at most `max_participation` times, and
  there are at least `min_separation` nodes between any two appearances. The
  key idea is to (recursively) count the worst-case occurrences of a sample in
  all the nodes of a tree, using a dynamic programming algorithm that exhausts
  the possible `num_participation` appearances of a sample in `steps` leaf
  nodes.
See Appendix D of
"Practical and Private (Deep) Learning without Sampling or Shuffling"
https://arxiv.org/abs/2103.00039.
The Zero-Concentrated Differential Privacy (zCDP) definition is described in
"Concentrated Differential Privacy: Simplifications, Extensions,
and Lower Bounds" https://arxiv.org/abs/1605.02065
Args:
noise_multiplier: A non-negative float representing the ratio of the
standard deviation of the Gaussian noise to the l2-sensitivity of the
function to which it is added.
total_steps: Total number of steps (leaf nodes in tree aggregation).
max_participation: The maximum number of times a sample can appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears at consecutive steps x and y in a streaming
      setting, then `min_separation=y-x-1`.
Returns:
The zCDP.
"""
_check_nonnegative(noise_multiplier, "noise_multiplier")
if noise_multiplier == 0:
return np.inf
_check_nonnegative(total_steps, "total_steps")
_check_nonnegative(max_participation, "max_participation")
_check_nonnegative(min_separation, "min_separation")
sum_sensitivity_square = _max_tree_sensitivity_square_sum(
max_participation, min_separation, total_steps)
return _compute_gaussian_zcdp(noise_multiplier, sum_sensitivity_square)
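# Illustrative usage sketch (assumed example values), not public API. The zCDP
# value can be converted to an (epsilon, delta) statement with the standard
# bound epsilon <= rho + 2 * sqrt(rho * log(1 / delta)) from the zCDP paper
# cited above; that conversion is not implemented in this module.
def _example_single_tree_zcdp():
  """Sketch only: zCDP of one tree and a loose (epsilon, delta) bound."""
  rho = compute_zcdp_single_tree(
      noise_multiplier=1.0,
      total_steps=1024,
      max_participation=2,
      min_separation=4)
  delta = 1e-6
  return rho + 2 * math.sqrt(rho * math.log(1 / delta))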

View file

@ -0,0 +1,195 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rdp_accountant.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import rdp_accountant
from tensorflow_privacy.privacy.analysis import tree_aggregation_accountant
class TreeAggregationTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('eps20', 1.13, 19.74), ('eps2', 8.83, 2.04))
def test_compute_eps_tree(self, noise_multiplier, eps):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
    # This test is based on the StackOverflow setting in "Practical and
    # Private (Deep) Learning without Sampling or Shuffling". The calculated
    # epsilon may become tighter as the methods in this package improve.
steps_list, target_delta = 1600, 1e-6
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
new_eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(new_eps, eps)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compose_tree_rdp(self, steps_list):
noise_multiplier, orders = 0.1, 1
rdp_list = [
tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps, orders) for steps in steps_list
]
rdp_composed = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
self.assertAllClose(rdp_composed, sum(rdp_list), rtol=1e-12)
@parameterized.named_parameters(
('restart4', [400] * 4),
('restart2', [800] * 2),
('adaptive', [10, 400, 400, 400, 390]),
)
def test_compute_eps_tree_decreasing(self, steps_list):
# Test privacy epsilon decreases with noise multiplier increasing when
# keeping other parameters the same.
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
target_delta = 1e-6
prev_eps = tree_aggregation_accountant.compute_rdp_tree_restart(
0, steps_list, orders)
for noise_multiplier in [0.1 * x for x in range(1, 100, 5)]:
rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
eps = rdp_accountant.get_privacy_spent(
orders, rdp, target_delta=target_delta)[0]
self.assertLess(eps, prev_eps)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1),
('empty_steps', 1, [], 1),
('negative_steps', 1, -3, 1),
)
def test_compute_rdp_tree_restart_raise(self, noise_multiplier, steps_list,
orders):
with self.assertRaisesRegex(ValueError, 'must be'):
tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, steps_list, orders)
@parameterized.named_parameters(
('t100n0.1', 100, 0.1),
('t1000n0.01', 1000, 0.01),
)
def test_no_tree_no_sampling(self, total_steps, noise_multiplier):
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
tree_rdp = tree_aggregation_accountant.compute_rdp_tree_restart(
noise_multiplier, [1] * total_steps, orders)
rdp = rdp_accountant.compute_rdp(1., noise_multiplier, total_steps, orders)
self.assertAllClose(tree_rdp, rdp, rtol=1e-12)
@parameterized.named_parameters(
('negative_noise', -1, 3, 1, 1),
('negative_steps', 0.1, -3, 1, 1),
('negative_part', 0.1, 3, -1, 1),
('negative_sep', 0.1, 3, 1, -1),
)
def test_compute_rdp_single_tree_raise(self, noise_multiplier, total_steps,
max_participation, min_separation):
orders = 1
with self.assertRaisesRegex(ValueError, 'must be'):
tree_aggregation_accountant.compute_rdp_single_tree(
noise_multiplier, total_steps, max_participation, min_separation,
orders)
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step(self, steps):
max_participation, min_separation = steps, 0
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
steps_bin = bin(steps)[2:]
depth = [
len(steps_bin) - 1 - i for i, v in enumerate(steps_bin) if v == '1'
]
expected = sum([2**d * (2**(d + 1) - 1) for d in depth])
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
@parameterized.named_parameters(
('11', 11),
('19', 19),
('200', 200),
)
def test_max_tree_sensitivity_square_sum_every_step_part(self, max_part):
steps, min_separation = 8, 0
assert max_part > steps
# If a sample will appear in every leaf node, we can infer the total
# sensitivity by adding all the nodes.
expected = 120
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_part, min_separation, steps))
@parameterized.named_parameters(
('3', 3),
('8', 8),
('11', 11),
('19', 19),
)
def test_max_tree_sensitivity_square_sum_every_step_part2(self, steps):
max_participation, min_separation = 2, 0
# If a sample will appear twice, the worst case is to put the two nodes at
# consecutive nodes of the deepest subtree.
steps_bin = bin(steps)[2:]
depth = len(steps_bin) - 1
expected = 2 + 4 * depth
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
@parameterized.named_parameters(
('test1', 1, 7, 8, 4),
('test2', 3, 3, 9, 11),
('test3', 3, 2, 7, 9),
# This is an example showing worst-case sensitivity is larger than greedy
# in "Practical and Private (Deep) Learning without Sampling or Shuffling"
# https://arxiv.org/abs/2103.00039.
('test4', 8, 2, 24, 88),
)
def test_max_tree_sensitivity_square_sum_toy(self, max_participation,
min_separation, steps, expected):
self.assertEqual(
expected,
tree_aggregation_accountant._max_tree_sensitivity_square_sum(
max_participation, min_separation, steps))
def test_compute_gaussian_zcdp(self):
for sigma in tf.random.uniform([5], minval=0.01, maxval=100).numpy():
for sum_sensitivity_square in tf.random.uniform([5],
minval=0.01,
maxval=1000).numpy():
self.assertEqual(
tree_aggregation_accountant._compute_gaussian_rdp(
sigma, sum_sensitivity_square, alpha=1),
tree_aggregation_accountant._compute_gaussian_zcdp(
sigma, sum_sensitivity_square))
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,87 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for discrete Gaussian mechanism."""
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
class DiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for discrete Gaussian sum queries.
  For each local record, we check the L2 norm bound; discrete Gaussian noise is
  added to the aggregated sum. In particular, this DPQuery does not perform L2
  norm clipping and the norms of the input records are expected to be bounded.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState',
['l2_norm_bound', 'stddev'])
# pylint: disable=invalid-name
_SampleParams = collections.namedtuple('_SampleParams',
['l2_norm_bound', 'stddev'])
def __init__(self, l2_norm_bound, stddev):
"""Initializes the DiscreteGaussianSumQuery.
Args:
l2_norm_bound: The L2 norm bound to verify for each record.
stddev: The stddev of the discrete Gaussian noise added to the sum.
"""
self._l2_norm_bound = l2_norm_bound
self._stddev = stddev
def initial_global_state(self):
return self._GlobalState(
tf.cast(self._l2_norm_bound, tf.float32),
tf.cast(self._stddev, tf.float32))
def derive_sample_params(self, global_state):
return self._SampleParams(global_state.l2_norm_bound, global_state.stddev)
def preprocess_record(self, params, record):
"""Check record norm and add noise to the record."""
record_as_list = tf.nest.flatten(record)
record_as_float_list = [tf.cast(x, tf.float32) for x in record_as_list]
tf.nest.map_structure(lambda x: tf.compat.v1.assert_type(x, tf.int32),
record_as_list)
dependencies = [
tf.compat.v1.assert_less_equal(
tf.linalg.global_norm(record_as_float_list),
params.l2_norm_bound,
message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
]
with tf.control_dependencies(dependencies):
return tf.nest.map_structure(tf.identity, record)
def get_noised_result(self, sample_state, global_state):
"""Adds discrete Gaussian noise to the aggregate."""
# Round up the noise as the TF discrete Gaussian sampler only takes
# integer noise stddevs for now.
ceil_stddev = tf.cast(tf.math.ceil(global_state.stddev), tf.int32)
def add_noise(v):
noised_v = v + discrete_gaussian_utils.sample_discrete_gaussian(
scale=ceil_stddev, shape=tf.shape(v), dtype=v.dtype)
# Ensure shape as TF shape inference may fail due to custom noise sampler.
return tf.ensure_shape(noised_v, v.shape)
result = tf.nest.map_structure(add_noise, sample_state)
event = dp_event.UnsupportedDpEvent()
return result, global_state, event
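# Illustrative usage sketch, not part of the library: runs the generic DPQuery
# protocol (initial state -> sample params -> preprocess -> accumulate ->
# noised result) on two toy int32 records with assumed example parameters.
def _example_discrete_gaussian_sum_query():
  """Sketch only: sums two int32 records and noises the aggregate."""
  records = [tf.constant([2, 0], dtype=tf.int32),
             tf.constant([-1, 1], dtype=tf.int32)]
  query = DiscreteGaussianSumQuery(l2_norm_bound=10.0, stddev=1.0)
  global_state = query.initial_global_state()
  params = query.derive_sample_params(global_state)
  sample_state = query.initial_sample_state(records[0])
  for record in records:
    sample_state = query.accumulate_record(params, sample_state, record)
  result, _, _ = query.get_noised_result(sample_state, global_state)
  return result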

View file

@ -0,0 +1,148 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for DiscreteGaussianSumQuery."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import test_utils
dg_sum_query = discrete_gaussian_query.DiscreteGaussianSumQuery
def silence_tf_error_messages(func):
"""Decorator that temporarily changes the TF logging levels."""
def wrapper(*args, **kwargs):
cur_verbosity = tf.compat.v1.logging.get_verbosity()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
func(*args, **kwargs)
tf.compat.v1.logging.set_verbosity(cur_verbosity) # Reset verbosity.
return wrapper
class DiscreteGaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
def test_sum_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.int32)
record2 = tf.constant([-1, 1], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1, 1]
self.assertAllEqual(result, expected)
@parameterized.product(sample_size=[1, 3])
def test_sum_multiple_shapes(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([2, 0], dtype=tf.int32)
t2 = tf.constant([-1, 1, 3], dtype=tf.int32)
t3 = tf.constant([-2], dtype=tf.int32)
record = [t1, t2, t3]
sample = [record] * sample_size
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
expected = [sample_size * t1, sample_size * t2, sample_size * t3]
result, expected = sess.run([query_result, expected])
# Use `assertAllClose` for nested structures equality (with tolerance=0).
self.assertAllClose(result, expected, atol=0)
@parameterized.product(sample_size=[1, 3])
def test_sum_nested_record_structure(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([1, 0], dtype=tf.int32)
t2 = tf.constant([1, 1, 1], dtype=tf.int32)
t3 = tf.constant([1], dtype=tf.int32)
t4 = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
record = [t1, dict(a=t2, b=[t3, (t4, t1)])]
sample = [record] * sample_size
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
result = sess.run(query_result)
s = sample_size
expected = [t1 * s, dict(a=t2 * s, b=[t3 * s, (t4 * s, t1 * s)])]
# Use `assertAllClose` for nested structures equality (with tolerance=0)
self.assertAllClose(result, expected, atol=0)
def test_sum_raise_on_float_inputs(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.float32)
record2 = tf.constant([-1, 1], dtype=tf.float32)
query = dg_sum_query(l2_norm_bound=10, stddev=0.0)
with self.assertRaises(TypeError):
query_result, _ = test_utils.run_query(query, [record1, record2])
sess.run(query_result)
@parameterized.product(l2_norm_bound=[0, 3, 10, 14.1])
@silence_tf_error_messages
def test_sum_raise_on_l2_norm_excess(self, l2_norm_bound):
with self.cached_session() as sess:
record = tf.constant([10, 10], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0)
with self.assertRaises(tf.errors.InvalidArgumentError):
query_result, _ = test_utils.run_query(query, [record])
sess.run(query_result)
def test_sum_float_norm_not_rounded(self):
"""Test that the float L2 norm bound doesn't get rounded/casted to integers."""
with self.cached_session() as sess:
      # A cast or rounded norm bound would be insufficient.
l2_norm_bound = 14.2
record = tf.constant([10, 10], dtype=tf.int32)
query = dg_sum_query(l2_norm_bound=l2_norm_bound, stddev=0.0)
query_result, _ = test_utils.run_query(query, [record])
result = sess.run(query_result)
expected = [10, 10]
self.assertAllEqual(result, expected)
@parameterized.product(stddev=[10, 100, 1000])
def test_noisy_sum(self, stddev):
num_trials = 1000
record_1 = tf.zeros([num_trials], dtype=tf.int32)
record_2 = tf.ones([num_trials], dtype=tf.int32)
sample = [record_1, record_2]
query = dg_sum_query(l2_norm_bound=num_trials, stddev=stddev)
result, _ = test_utils.run_query(query, sample)
sampled_noise = discrete_gaussian_utils.sample_discrete_gaussian(
scale=tf.cast(stddev, tf.int32), shape=[num_trials], dtype=tf.int32)
result, sampled_noise = self.evaluate([result, sampled_noise])
# The standard error of the stddev should be roughly sigma / sqrt(2N - 2),
# (https://stats.stackexchange.com/questions/156518) so set a rtol to give
# < 0.01% of failure (within ~4 standard errors).
rtol = 4 / np.sqrt(2 * num_trials - 2)
self.assertAllClose(np.std(result), stddev, rtol=rtol)
# Use standard error of the mean to compare percentiles.
stderr = stddev / np.sqrt(num_trials)
self.assertAllClose(
np.percentile(result, [25, 50, 75]),
np.percentile(sampled_noise, [25, 50, 75]),
atol=4 * stderr)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,142 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions for drawing discrete Gaussian samples.
The following functions implement a vectorized TF version of the sampling
algorithm described in the paper:
The Discrete Gaussian for Differential Privacy
https://arxiv.org/pdf/2004.00010.pdf
Note that the exact sampling implementation should use integer and fractional
parameters only. Here, we relax this constraint a bit and use vectorized
implementations of Bernoulli and discrete Laplace sampling that can take float
parameters.
"""
import tensorflow as tf
import tensorflow_probability as tf_prob
def _sample_discrete_laplace(t, shape):
"""Sample from discrete Laplace with scale t.
This method is based on the observation that sampling from Z ~ Lap(t) is
  equivalent to sampling X, Y independently from Geo(1 - exp(-1/t)) and taking
  Z = X - Y.
  Note also that tensorflow_probability's geometric sampler is based on
  floating-point operations and may possibly be inexact.
Args:
t: The scale of the discrete Laplace distribution.
shape: The tensor shape of the tensors drawn.
Returns:
A tensor of the specified shape filled with random values.
"""
geometric_probs = 1.0 - tf.exp(-1.0 / tf.cast(t, tf.float64))
sampler = tf_prob.distributions.Geometric(probs=geometric_probs)
return tf.cast(sampler.sample(shape) - sampler.sample(shape), tf.int64)
def _sample_bernoulli(p):
"""Sample from Bernoulli(p)."""
return tf_prob.distributions.Bernoulli(probs=p, dtype=tf.int64).sample()
def _check_input_args(scale, shape, dtype):
"""Checks the input args to the discrete Gaussian sampler."""
if tf.as_dtype(dtype) not in (tf.int32, tf.int64):
raise ValueError(
f'Only tf.int32 and tf.int64 are supported. Found dtype `{dtype}`.')
checks = [
tf.compat.v1.assert_non_negative(scale),
tf.compat.v1.assert_integer(scale)
]
with tf.control_dependencies(checks):
return tf.identity(scale), shape, dtype
def _int_square(value):
"""Avoids the TF op `Square(T=...)` for ints as sampling can happen on clients."""
return (value - 1) * (value + 1) + 1
@tf.function
def _sample_discrete_gaussian_helper(scale, shape, dtype):
"""Draw samples from discrete Gaussian, assuming scale >= 0."""
scale = tf.cast(scale, tf.int64)
sq_scale = _int_square(scale)
# Scale for discrete Laplace. The sampling algorithm should be correct
# for any discrete Laplace scale, and the original paper uses
# `dlap_scale = floor(scale) + 1`. Here we use `dlap_scale = scale` (where
# input `scale` is restricted to integers >= 1) to simplify the fraction
# below. It turns out that for integer scales >= 1, `dlap_scale = scale` gives
# a good minimum success rate of ~70%, allowing a small oversampling factor.
dlap_scale = scale
oversample_factor = 1.5
# Draw at least some samples in case we got unlucky with small input shape.
min_n = 1000
target_n = tf.reduce_prod(tf.cast(shape, tf.int64))
oversample_n = oversample_factor * tf.cast(target_n, tf.float32)
draw_n = tf.maximum(min_n, tf.cast(oversample_n, tf.int32))
accepted_n = tf.constant(0, dtype=target_n.dtype)
result = tf.zeros((0,), dtype=tf.int64)
while accepted_n < target_n:
# Since the number of samples could be different in every retry, we need to
# manually specify the shape info for TF.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(result, tf.TensorShape([None]))])
# Draw samples.
samples = _sample_discrete_laplace(dlap_scale, shape=(draw_n,))
z_numer = _int_square(tf.abs(samples) - scale)
z_denom = 2 * sq_scale
bern_probs = tf.exp(-1.0 * tf.divide(z_numer, z_denom))
accept = _sample_bernoulli(bern_probs)
# Keep successful samples and increment counter.
accepted_samples = samples[tf.equal(accept, 1)]
accepted_n += tf.cast(tf.size(accepted_samples), accepted_n.dtype)
result = tf.concat([result, accepted_samples], axis=0)
# Reduce the number of draws for any retries.
draw_n = tf.cast(target_n - accepted_n, tf.float32) * oversample_factor
draw_n = tf.maximum(min_n, tf.cast(draw_n, tf.int32))
return tf.cast(tf.reshape(result[:target_n], shape), dtype)
def sample_discrete_gaussian(scale, shape, dtype=tf.int32):
"""Draws (possibly inexact) samples from the discrete Gaussian distribution.
We relax some integer constraints to use vectorized implementations of
Bernoulli and discrete Laplace sampling. Integer operations are done in
tf.int64 as TF does not have direct support for fractions.
Args:
scale: The scale of the discrete Gaussian distribution.
shape: The shape of the output tensor.
dtype: The type of the output.
Returns:
A tensor of the specified shape filled with random values.
"""
scale, shape, dtype = _check_input_args(scale, shape, dtype)
return tf.cond(
tf.equal(scale, 0), lambda: tf.zeros(shape, dtype),
lambda: _sample_discrete_gaussian_helper(scale, shape, dtype))
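# Illustrative usage sketch (assumed example values): draw a batch of samples
# and inspect their empirical standard deviation, which should be close to
# `scale` for large sample counts.
def _example_sample_discrete_gaussian():
  """Sketch only: draws 10000 discrete Gaussian samples at scale 25."""
  samples = sample_discrete_gaussian(scale=25, shape=(10000,), dtype=tf.int64)
  empirical_std = tf.math.reduce_std(tf.cast(samples, tf.float32))
  return samples, empirical_std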

View file

@ -0,0 +1,275 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for discrete_gaussian_utils."""
import collections
import fractions
import math
import random
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
EXACT_SAMPLER_SEED = 4242
class DiscreteGaussianUtilsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(dtype=[tf.bool, tf.float32, tf.float64])
def test_raise_on_bad_dtype(self, dtype):
with self.assertRaises(ValueError):
_ = discrete_gaussian_utils.sample_discrete_gaussian(1, (1,), dtype)
def test_raise_on_negative_scale(self):
with self.assertRaises(tf.errors.InvalidArgumentError):
_ = discrete_gaussian_utils.sample_discrete_gaussian(-10, (1,))
def test_raise_on_float_scale(self):
with self.assertRaises(TypeError):
_ = discrete_gaussian_utils.sample_discrete_gaussian(3.14, (1,))
@parameterized.product(shape=[(), (1,), (100,), (2, 2), (3, 3, 3),
(4, 1, 1, 1)])
def test_shapes(self, shape):
samples = discrete_gaussian_utils.sample_discrete_gaussian(100, shape)
samples = self.evaluate(samples)
self.assertAllEqual(samples.shape, shape)
@parameterized.product(dtype=[tf.int32, tf.int64])
def test_dtypes(self, dtype):
samples = discrete_gaussian_utils.sample_discrete_gaussian(1, (10,), dtype)
samples = self.evaluate(samples)
# Convert output np dtypes to tf dtypes.
self.assertEqual(tf.as_dtype(samples.dtype), dtype)
def test_zero_noise(self):
scale = 0
shape = (100,)
dtype = tf.int32
samples = discrete_gaussian_utils.sample_discrete_gaussian(
scale, shape, dtype=dtype)
samples = self.evaluate(samples)
self.assertAllEqual(samples, tf.zeros(shape, dtype=dtype))
@parameterized.named_parameters([('small_scale_small_n', 10, 2000, 1, 2),
('small_scale_large_n', 10, 5000, 1, 1),
('large_scale_small_n', 50, 2000, 2, 5),
('large_scale_large_n', 50, 5000, 2, 3)])
def test_match_exact_sampler(self, scale, num_samples, mean_std_atol,
percentile_atol):
true_samples = exact_sampler(scale, num_samples)
drawn_samples = discrete_gaussian_utils.sample_discrete_gaussian(
scale=scale, shape=(num_samples,))
drawn_samples = self.evaluate(drawn_samples)
# Check mean, std, and percentiles.
self.assertAllClose(
np.mean(true_samples), np.mean(drawn_samples), atol=mean_std_atol)
self.assertAllClose(
np.std(true_samples), np.std(drawn_samples), atol=mean_std_atol)
self.assertAllClose(
np.percentile(true_samples, [10, 30, 50, 70, 90]),
np.percentile(drawn_samples, [10, 30, 50, 70, 90]),
atol=percentile_atol)
@parameterized.named_parameters([('n_1000', 1000, 5e-2),
('n_10000', 10000, 5e-3)])
def test_kl_divergence(self, num_samples, kl_tolerance):
"""Compute KL divergence betwen empirical & true distribution."""
scale = 10
sq_sigma = scale * scale
drawn_samples = discrete_gaussian_utils.sample_discrete_gaussian(
scale=scale, shape=(num_samples,))
drawn_samples = self.evaluate(drawn_samples)
value_counts = collections.Counter(drawn_samples)
kl = 0
norm_const = dgauss_normalizing_constant(sq_sigma)
for value, count in value_counts.items():
kl += count * (
math.log(count * norm_const / num_samples) + value * value /
(2.0 * sq_sigma))
kl /= num_samples
self.assertLess(kl, kl_tolerance)
def exact_sampler(scale, num_samples, seed=EXACT_SAMPLER_SEED):
"""Implementation of the exact discrete gaussian distribution sampler.
Source: https://arxiv.org/pdf/2004.00010.pdf.
Args:
scale: The scale of the discrete Gaussian.
num_samples: The number of samples to generate.
seed: The seed for the random number generator to reproduce samples.
Returns:
A numpy array of discrete Gaussian samples.
"""
def randrange(a, rng):
return rng.randrange(a)
def bern_em1(rng):
"""Sample from Bernoulli(exp(-1))."""
k = 2
while True:
if randrange(k, rng) == 0: # if Bernoulli(1/k)==1
k = k + 1
else:
return k % 2
def bern_emab1(a, b, rng):
"""Sample from Bernoulli(exp(-a/b)), assuming 0 <= a <= b."""
assert isinstance(a, int)
assert isinstance(b, int)
assert 0 <= a <= b
k = 1
while True:
if randrange(b, rng) < a and randrange(k, rng) == 0: # if Bern(a/b/k)==1
k = k + 1
else:
return k % 2
def bern_emab(a, b, rng):
"""Sample from Bernoulli(exp(-a/b)), allowing a > b."""
while a > b:
if bern_em1(rng) == 0:
return 0
a = a - b
return bern_emab1(a, b, rng)
def geometric(t, rng):
"""Sample from geometric(1-exp(-1/t))."""
assert isinstance(t, int)
assert t > 0
while True:
u = randrange(t, rng)
if bern_emab1(u, t, rng) == 1:
while bern_em1(rng) == 1:
u = u + t
return u
def dlap(t, rng):
"""Sample from discrete Laplace with scale t.
Pr[x] = exp(-|x|/t) * (exp(1/t)-1)/(exp(1/t)+1). Supported on integers.
Args:
t: The scale.
rng: The random number generator.
Returns:
A discrete Laplace sample.
"""
assert isinstance(t, int)
assert t > 0
while True:
u = geometric(t, rng)
b = randrange(2, rng)
if b == 1:
return u
elif u > 0:
return -u
def floorsqrt(x):
"""Compute floor(sqrt(x)) exactly."""
assert x >= 0
a = 0 # maintain a^2<=x.
b = 1 # maintain b^2>x.
while b * b <= x:
b = 2 * b
# Do binary search.
while a + 1 < b:
c = (a + b) // 2
if c * c <= x:
a = c
else:
b = c
return a
def dgauss(ss, num, rng):
"""Sample from discrete Gaussian.
Args:
ss: Variance proxy, squared scale, sigma^2.
num: The number of samples to generate.
rng: The random number generator.
Returns:
A list of discrete Gaussian samples.
"""
ss = fractions.Fraction(ss) # cast to rational for exact arithmetic
assert ss > 0
t = floorsqrt(ss) + 1
results = []
trials = 0
while len(results) < num:
trials = trials + 1
y = dlap(t, rng)
z = (abs(y) - ss / t)**2 / (2 * ss)
if bern_emab(z.numerator, z.denominator, rng) == 1:
results.append(y)
return results, t, trials
rng = random.Random(seed)
return np.array(dgauss(scale * scale, num_samples, rng)[0])
def dgauss_normalizing_constant(sigma_sq):
"""Compute the normalizing constant of the discrete Gaussian.
Source: https://arxiv.org/pdf/2004.00010.pdf.
Args:
sigma_sq: Variance proxy, squared scale, sigma^2.
Returns:
The normalizing constant.
"""
original = None
poisson = None
if sigma_sq <= 1:
original = 0
x = 1000
while x > 0:
original = original + math.exp(-x * x / (2.0 * sigma_sq))
x = x - 1
original = 2 * original + 1
if sigma_sq * 100 >= 1:
poisson = 0
y = 1000
while y > 0:
poisson = poisson + math.exp(-math.pi * math.pi * sigma_sq * 2 * y * y)
y = y - 1
poisson = math.sqrt(2 * math.pi * sigma_sq) * (1 + 2 * poisson)
if poisson is None:
return original
if original is None:
return poisson
scale = max(1, math.sqrt(2 * math.pi * sigma_sq))
precision = 1e-15
assert -precision * scale <= original - poisson <= precision * scale
return (original + poisson) / 2
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,111 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for distributed discrete Gaussian mechanism."""
import collections
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import dp_query
class DistributedDiscreteGaussianSumQuery(dp_query.SumAggregationDPQuery):
"""Implements DPQuery for discrete distributed Gaussian sum queries.
For each local record, we check the L2 norm bound and add discrete Gaussian
noise. In particular, this DPQuery does not perform L2 norm clipping and the
norms of the input records are expected to be bounded.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState',
['l2_norm_bound', 'local_stddev'])
# pylint: disable=invalid-name
_SampleParams = collections.namedtuple('_SampleParams',
['l2_norm_bound', 'local_stddev'])
def __init__(self, l2_norm_bound, local_stddev):
"""Initializes the DistributedDiscreteGaussianSumQuery.
Args:
l2_norm_bound: The L2 norm bound to verify for each record.
local_stddev: The stddev of the local discrete Gaussian noise.
"""
self._l2_norm_bound = l2_norm_bound
self._local_stddev = local_stddev
def initial_global_state(self):
return self._GlobalState(
tf.cast(self._l2_norm_bound, tf.float32),
tf.cast(self._local_stddev, tf.float32))
def derive_sample_params(self, global_state):
return self._SampleParams(global_state.l2_norm_bound,
global_state.local_stddev)
def _add_local_noise(self, record, local_stddev, shares=1):
"""Add local discrete Gaussian noise to the record.
Args:
record: The record to which we generate and add local noise.
local_stddev: The stddev of the local discrete Gaussian noise.
shares: Number of shares of local noise to generate. Should be 1 for each
record. This can be useful when we want to generate multiple noise
shares at once.
Returns:
The record with local noise added.
"""
# Round up the noise as the TF discrete Gaussian sampler only takes
# integer noise stddevs for now.
ceil_local_stddev = tf.cast(tf.math.ceil(local_stddev), tf.int32)
def add_noise(v):
# Adds an extra dimension for `shares` number of draws.
shape = tf.concat([[shares], tf.shape(v)], axis=0)
dgauss_noise = discrete_gaussian_utils.sample_discrete_gaussian(
scale=ceil_local_stddev, shape=shape, dtype=v.dtype)
# Sum across the number of noise shares and add it.
noised_v = v + tf.reduce_sum(dgauss_noise, axis=0)
# Set shape as TF shape inference may fail due to custom noise sampler.
noised_v.set_shape(v.shape.as_list())
return noised_v
return tf.nest.map_structure(add_noise, record)
def preprocess_record(self, params, record):
"""Check record norm and add noise to the record."""
record_as_list = tf.nest.flatten(record)
record_as_float_list = [tf.cast(x, tf.float32) for x in record_as_list]
tf.nest.map_structure(lambda x: tf.compat.v1.assert_type(x, tf.int32),
record_as_list)
dependencies = [
tf.compat.v1.assert_less_equal(
tf.linalg.global_norm(record_as_float_list),
params.l2_norm_bound,
message=f'Global L2 norm exceeds {params.l2_norm_bound}.')
]
with tf.control_dependencies(dependencies):
result = tf.cond(
tf.equal(params.local_stddev, 0), lambda: record,
lambda: self._add_local_noise(record, params.local_stddev))
return result
def get_noised_result(self, sample_state, global_state):
# Note that by directly returning the aggregate, this assumes that there
# will not be missing local noise shares during execution.
event = dp_event.UnsupportedDpEvent()
return sample_state, global_state, event
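# Illustrative usage sketch, not part of the library: unlike the central
# discrete Gaussian query, noise here is added per record during preprocessing
# (i.e., on the client). Parameter values below are assumed examples.
def _example_local_noising():
  """Sketch only: locally noise a single int32 client record."""
  query = DistributedDiscreteGaussianSumQuery(
      l2_norm_bound=10.0, local_stddev=2.0)
  params = query.derive_sample_params(query.initial_global_state())
  record = tf.constant([2, 0], dtype=tf.int32)
  return query.preprocess_record(params, record)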

View file

@ -0,0 +1,165 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for DistributedDiscreteGaussianSumQuery."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import discrete_gaussian_utils
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import test_utils
ddg_sum_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery
def silence_tf_error_messages(func):
"""Decorator that temporarily changes the TF logging levels."""
def wrapper(*args, **kwargs):
cur_verbosity = tf.compat.v1.logging.get_verbosity()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
func(*args, **kwargs)
tf.compat.v1.logging.set_verbosity(cur_verbosity) # Reset verbosity.
return wrapper
class DistributedDiscreteGaussianQueryTest(tf.test.TestCase,
parameterized.TestCase):
def test_sum_no_noise(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.int32)
record2 = tf.constant([-1, 1], dtype=tf.int32)
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
query_result, _ = test_utils.run_query(query, [record1, record2])
result = sess.run(query_result)
expected = [1, 1]
self.assertAllEqual(result, expected)
@parameterized.product(sample_size=[1, 3])
def test_sum_multiple_shapes(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([2, 0], dtype=tf.int32)
t2 = tf.constant([-1, 1, 3], dtype=tf.int32)
t3 = tf.constant([-2], dtype=tf.int32)
record = [t1, t2, t3]
sample = [record] * sample_size
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
expected = [sample_size * t1, sample_size * t2, sample_size * t3]
result, expected = sess.run([query_result, expected])
# Use `assertAllClose` for nested structures equality (with tolerance=0).
self.assertAllClose(result, expected, atol=0)
@parameterized.product(sample_size=[1, 3])
def test_sum_nested_record_structure(self, sample_size):
with self.cached_session() as sess:
t1 = tf.constant([1, 0], dtype=tf.int32)
t2 = tf.constant([1, 1, 1], dtype=tf.int32)
t3 = tf.constant([1], dtype=tf.int32)
t4 = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
record = [t1, dict(a=t2, b=[t3, (t4, t1)])]
sample = [record] * sample_size
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
query_result, _ = test_utils.run_query(query, sample)
result = sess.run(query_result)
s = sample_size
expected = [t1 * s, dict(a=t2 * s, b=[t3 * s, (t4 * s, t1 * s)])]
# Use `assertAllClose` for nested structures equality (with tolerance=0)
self.assertAllClose(result, expected, atol=0)
def test_sum_raise_on_float_inputs(self):
with self.cached_session() as sess:
record1 = tf.constant([2, 0], dtype=tf.float32)
record2 = tf.constant([-1, 1], dtype=tf.float32)
query = ddg_sum_query(l2_norm_bound=10, local_stddev=0.0)
with self.assertRaises(TypeError):
query_result, _ = test_utils.run_query(query, [record1, record2])
sess.run(query_result)
@parameterized.product(l2_norm_bound=[0, 3, 10, 14.1])
@silence_tf_error_messages
def test_sum_raise_on_l2_norm_excess(self, l2_norm_bound):
with self.cached_session() as sess:
record = tf.constant([10, 10], dtype=tf.int32)
query = ddg_sum_query(l2_norm_bound=l2_norm_bound, local_stddev=0.0)
with self.assertRaises(tf.errors.InvalidArgumentError):
query_result, _ = test_utils.run_query(query, [record])
sess.run(query_result)
def test_sum_float_norm_not_rounded(self):
"""Test that the float L2 norm bound doesn't get rounded/casted to integers."""
with self.cached_session() as sess:
      # A cast or rounded norm bound would be insufficient.
l2_norm_bound = 14.2
record = tf.constant([10, 10], dtype=tf.int32)
query = ddg_sum_query(l2_norm_bound=l2_norm_bound, local_stddev=0.0)
query_result, _ = test_utils.run_query(query, [record])
result = sess.run(query_result)
expected = [10, 10]
self.assertAllEqual(result, expected)
@parameterized.named_parameters([('2_local_stddev_1_record', 2, 1),
('10_local_stddev_4_records', 10, 4),
('1000_local_stddev_1_record', 1000, 1),
('1000_local_stddev_25_records', 1000, 25)])
def test_sum_local_noise_shares(self, local_stddev, num_records):
"""Test the noise level of the sum of discrete Gaussians applied locally.
    The sum of discrete Gaussians is not a discrete Gaussian, but it will be
    extremely close for sigma >= 2. We thus compare the aggregated noise to
    central discrete Gaussian noise with appropriately scaled stddev, using
    some reasonable tolerance.
Args:
local_stddev: The stddev of the local discrete Gaussian noise.
num_records: The number of records to be aggregated.
"""
# Aggregated local noises.
num_trials = 1000
record = tf.zeros([num_trials], dtype=tf.int32)
sample = [record] * num_records
query = ddg_sum_query(l2_norm_bound=10.0, local_stddev=local_stddev)
query_result, _ = test_utils.run_query(query, sample)
# Central discrete Gaussian noise.
central_stddev = np.sqrt(num_records) * local_stddev
central_noise = discrete_gaussian_utils.sample_discrete_gaussian(
scale=tf.cast(tf.round(central_stddev), record.dtype),
shape=tf.shape(record),
dtype=record.dtype)
agg_noise, central_noise = self.evaluate([query_result, central_noise])
mean_stddev = central_stddev * np.sqrt(num_trials) / num_trials
atol = 3.5 * mean_stddev
# Use the atol for mean as a rough default atol for stddev/percentile.
self.assertAllClose(np.mean(agg_noise), np.mean(central_noise), atol=atol)
self.assertAllClose(np.std(agg_noise), np.std(central_noise), atol=atol)
self.assertAllClose(
np.percentile(agg_noise, [25, 50, 75]),
np.percentile(central_noise, [25, 50, 75]),
atol=atol)
if __name__ == '__main__':
tf.test.main()

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An interface for differentially private query mechanisms.
The DPQuery class abstracts the differential privacy mechanism needed by DP-SGD.
@ -100,18 +99,6 @@ class DPQuery(object):
__metaclass__ = abc.ABCMeta
def set_ledger(self, ledger):
"""Supplies privacy ledger to which the query can record privacy events.
The ledger should be updated with each call to get_noised_result.
Args:
ledger: A `PrivacyLedger`.
"""
del ledger
raise TypeError(
'DPQuery type %s does not support set_ledger.' % type(self).__name__)
def initial_global_state(self):
"""Returns the initial global state for the DPQuery.
@ -155,7 +142,6 @@ class DPQuery(object):
as a template to create the initial sample state. It is assumed that the
leaves of the structure are python scalars or some type that has
properties `shape` and `dtype`.
Returns: An initial sample state.
"""
pass
@ -171,12 +157,12 @@ class DPQuery(object):
variables that are stored in self.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
params: The parameters for the sample. In standard DP-SGD training, the
clipping norm for the sample's microbatch gradients (i.e., a maximum
norm magnitude to which each gradient is clipped)
record: The record to be processed. In standard DP-SGD training, the
gradient computed for the examples in one microbatch, which may be the
gradient for just one example (for size 1 microbatches).
Returns:
A structure of tensors to be aggregated.
@ -185,8 +171,7 @@ class DPQuery(object):
return record
@abc.abstractmethod
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Accumulates a single preprocessed record into the sample state.
This method is intended to only do simple aggregation, typically just a sum.
@ -194,8 +179,8 @@ class DPQuery(object):
declaratively specify the type of aggregation required.
Args:
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
sample_state: The current sample state. In standard DP-SGD training, the
accumulated sum of previous clipped microbatch gradients.
preprocessed_record: The preprocessed record to accumulate.
Returns:
@ -211,22 +196,22 @@ class DPQuery(object):
functions run on a single device. Typically this will be a simple sum.
Args:
params: The parameters for the sample. In standard DP-SGD training,
the clipping norm for the sample's microbatch gradients (i.e.,
a maximum norm magnitude to which each gradient is clipped)
sample_state: The current sample state. In standard DP-SGD training,
the accumulated sum of previous clipped microbatch gradients.
record: The record to accumulate. In standard DP-SGD training,
the gradient computed for the examples in one microbatch, which
may be the gradient for just one example (for size 1 microbatches).
params: The parameters for the sample. In standard DP-SGD training, the
clipping norm for the sample's microbatch gradients (i.e., a maximum
norm magnitude to which each gradient is clipped)
sample_state: The current sample state. In standard DP-SGD training, the
accumulated sum of previous clipped microbatch gradients.
record: The record to accumulate. In standard DP-SGD training, the
gradient computed for the examples in one microbatch, which may be the
gradient for just one example (for size 1 microbatches).
Returns:
The updated sample state. In standard DP-SGD training, the set of
previous microbatch gradients with the addition of the record argument.
"""
preprocessed_record = self.preprocess_record(params, record)
return self.accumulate_preprocessed_record(
sample_state, preprocessed_record)
return self.accumulate_preprocessed_record(sample_state,
preprocessed_record)
@abc.abstractmethod
def merge_sample_states(self, sample_state_1, sample_state_2):
@ -261,11 +246,14 @@ class DPQuery(object):
global_state: The global state, storing long-term privacy bookkeeping.
Returns:
A tuple (result, new_global_state) where "result" is the result of the
query and "new_global_state" is the updated global state. In standard
DP-SGD training, the result is a gradient update comprising a noised
average of the clipped gradients in the sample state---with the noise and
averaging performed in a manner that guarantees differential privacy.
A tuple `(result, new_global_state, event)` where:
* `result` is the result of the query,
* `new_global_state` is the updated global state, and
* `event` is the `DpEvent` that occurred.
In standard DP-SGD training, the result is a gradient update comprising a
noised average of the clipped gradients in the sample state---with the
noise and averaging performed in a manner that guarantees differential
privacy.
"""
pass
@ -312,7 +300,3 @@ class SumAggregationDPQuery(DPQuery):
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
return sample_state, global_state

View file

@ -22,6 +22,7 @@ import distutils
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
@ -45,11 +46,6 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
"""
self._l2_norm_clip = l2_norm_clip
self._stddev = stddev
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._ledger = ledger
def make_global_state(self, l2_norm_clip, stddev):
"""Creates a global state from the given parameters."""
@ -100,12 +96,8 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
if self._ledger:
dependencies = [
self._ledger.record_sum_query(global_state.l2_norm_clip,
global_state.stddev)
]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return tf.nest.map_structure(add_noise, sample_state), global_state
result = tf.nest.map_structure(add_noise, sample_state)
noise_multiplier = global_state.stddev / global_state.l2_norm_clip
event = dp_event.GaussianDpEvent(noise_multiplier)
return result, global_state, event

View file

@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for queries over nested structures.
"""
"""Implements DPQuery interface for queries over nested structures."""
from __future__ import absolute_import
from __future__ import division
@ -22,6 +20,8 @@ from __future__ import print_function
import collections
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
import tree
@ -60,16 +60,13 @@ class NestedQuery(dp_query.DPQuery):
def _map_to_queries(self, fn, *inputs, **kwargs):
"""Maps DPQuery methods to the subqueries."""
def caller(query, *args):
return getattr(query, fn)(*args, **kwargs)
return tree.map_structure_up_to(self._queries, caller, self._queries,
*inputs)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._map_to_queries('set_ledger', ledger=ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._map_to_queries('initial_global_state')
@ -89,28 +86,27 @@ class NestedQuery(dp_query.DPQuery):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._map_to_queries('preprocess_record', params, record)
def accumulate_preprocessed_record(
self, sample_state, preprocessed_record):
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
"""Implements `tensorflow_privacy.DPQuery.accumulate_preprocessed_record`."""
return self._map_to_queries(
'accumulate_preprocessed_record',
sample_state,
preprocessed_record)
return self._map_to_queries('accumulate_preprocessed_record', sample_state,
preprocessed_record)
def merge_sample_states(self, sample_state_1, sample_state_2):
"""Implements `tensorflow_privacy.DPQuery.merge_sample_states`."""
return self._map_to_queries(
'merge_sample_states', sample_state_1, sample_state_2)
return self._map_to_queries('merge_sample_states', sample_state_1,
sample_state_2)
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
estimates_and_new_global_states = self._map_to_queries(
'get_noised_result', sample_state, global_state)
mapped_query_results = self._map_to_queries('get_noised_result',
sample_state, global_state)
flat_estimates, flat_new_global_states, flat_events = zip(
*tree.flatten_up_to(self._queries, mapped_query_results))
flat_estimates, flat_new_global_states = zip(
*tree.flatten_up_to(self._queries, estimates_and_new_global_states))
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
dp_event.ComposedDpEvent(events=flat_events))
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
@ -118,12 +114,12 @@ class NestedQuery(dp_query.DPQuery):
def add_metrics(tuple_path, subquery, subquery_global_state):
metrics.update({
'/'.join(str(s) for s in tuple_path + (name,)): metric
for name, metric
in subquery.derive_metrics(subquery_global_state).items()})
'/'.join(str(s) for s in tuple_path + (name,)): metric for name,
metric in subquery.derive_metrics(subquery_global_state).items()
})
tree.map_structure_with_path_up_to(
self._queries, add_metrics, self._queries, global_state)
tree.map_structure_with_path_up_to(self._queries, add_metrics,
self._queries, global_state)
return metrics
@ -137,12 +133,13 @@ class NestedSumQuery(NestedQuery, dp_query.SumAggregationDPQuery):
Args:
queries: A nested structure of queries that must all be
SumAggregationDPQueries.
Raises: TypeError if any of the subqueries are not SumAggregationDPQueries.
"""
def check(query):
if not isinstance(query, dp_query.SumAggregationDPQuery):
raise TypeError('All subqueries must be SumAggregationDPQueries.')
tree.map_structure(check, queries)
super(NestedSumQuery, self).__init__(queries)

View file

@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
@ -30,28 +29,9 @@ class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
Accumulates vectors without clipping or adding noise.
"""
def __init__(self):
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
warnings.warn(
'Attempt to use NoPrivacySumQuery with privacy ledger. Privacy '
'guarantees will be vacuous.')
self._ledger = ledger
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
if self._ledger:
dependencies = [
self._ledger.record_sum_query(float('inf'), 0.0)
]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return sample_state, global_state
return sample_state, global_state, dp_event.NonPrivateDpEvent()
class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
@ -67,21 +47,10 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
privatized.
"""
def __init__(self):
"""Initializes the NoPrivacyAverageQuery."""
self._ledger = None
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
warnings.warn(
'Attempt to use NoPrivacyAverageQuery with privacy ledger. Privacy '
'guarantees will be vacuous.')
self._ledger = ledger
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return (super(NoPrivacyAverageQuery, self).initial_sample_state(template),
tf.constant(0.0))
return (super(NoPrivacyAverageQuery,
self).initial_sample_state(template), tf.constant(0.0))
def preprocess_record(self, params, record, weight=1):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
@ -121,13 +90,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
sum_state, denominator = sample_state
if self._ledger:
dependencies = [
self._ledger.record_sum_query(float('inf'), 0.0)
]
else:
dependencies = []
with tf.control_dependencies(dependencies):
return (tf.nest.map_structure(lambda t: t / denominator,
sum_state), global_state)
result = tf.nest.map_structure(lambda t: t / denominator, sum_state)
return result, global_state, dp_event.NonPrivateDpEvent()
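
With the ledger removed, the no-privacy queries now simply tag their output with a `NonPrivateDpEvent`. A small sketch of computing an exact average through the DPQuery interface, assuming eager execution:

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import no_privacy_query

query = no_privacy_query.NoPrivacyAverageQuery()
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
records = [tf.constant(2.0), tf.constant(4.0)]

sample_state = query.initial_sample_state(records[0])
for record in records:
  sample_state = query.accumulate_record(params, sample_state, record)
# Returns the exact mean (3.0) together with a NonPrivateDpEvent.
result, new_global_state, event = query.get_noised_result(
    sample_state, global_state)
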

View file

@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for normalized queries.
"""
"""Implements DPQuery interface for normalized queries."""
from __future__ import absolute_import
from __future__ import division
@ -38,8 +36,8 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['numerator_state', 'denominator'])
_GlobalState = collections.namedtuple('_GlobalState',
['numerator_state', 'denominator'])
def __init__(self, numerator_query, denominator):
"""Initializes the NormalizedQuery.
@ -55,15 +53,11 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
assert isinstance(self._numerator, dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._numerator.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
denominator = tf.cast(self._denominator, tf.float32)
return self._GlobalState(
self._numerator.initial_global_state(), denominator)
return self._GlobalState(self._numerator.initial_global_state(),
denominator)
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -80,13 +74,16 @@ class NormalizedQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_sum, new_sum_global_state = self._numerator.get_noised_result(
noised_sum, new_sum_global_state, event = self._numerator.get_noised_result(
sample_state, global_state.numerator_state)
def normalize(v):
return tf.truediv(v, global_state.denominator)
# The denominator is constant so the privacy cost comes from the numerator.
return (tf.nest.map_structure(normalize, noised_sum),
self._GlobalState(new_sum_global_state, global_state.denominator))
self._GlobalState(new_sum_global_state,
global_state.denominator), event)
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""

View file

@ -21,6 +21,7 @@ import collections
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import quantile_estimator_query
@ -91,11 +92,6 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
assert isinstance(self._quantile_estimator_query,
dp_query.SumAggregationDPQuery)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._sum_query.set_ledger(ledger)
self._quantile_estimator_query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(
@ -128,11 +124,11 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_vectors, sum_state = self._sum_query.get_noised_result(
noised_vectors, sum_state, sum_event = self._sum_query.get_noised_result(
sample_state.sum_state, global_state.sum_state)
del sum_state # To be set explicitly later when we know the new clip.
new_l2_norm_clip, new_quantile_estimator_state = (
new_l2_norm_clip, new_quantile_estimator_state, quantile_event = (
self._quantile_estimator_query.get_noised_result(
sample_state.quantile_estimator_state,
global_state.quantile_estimator_state))
@ -146,7 +142,8 @@ class QuantileAdaptiveClipSumQuery(dp_query.SumAggregationDPQuery):
new_sum_query_state,
new_quantile_estimator_state)
return noised_vectors, new_global_state
event = dp_event.ComposedDpEvent(events=[sum_event, quantile_event])
return noised_vectors, new_global_state, event
def derive_metrics(self, global_state):
"""Returns the current clipping norm as a metric."""

View file

@ -22,7 +22,6 @@ from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_sum_query
from tensorflow_privacy.privacy.dp_query import test_utils
@ -231,7 +230,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
('start_high_arithmetic', False, False),
('start_high_geometric', False, True))
def test_adaptation_linspace(self, start_low, geometric):
# 100 records equally spaced from 0 to 10 in 0.1 increments.
# `num_records` records equally spaced from 0 to 10.
# Test that we converge to the correct median value and bounce around it.
num_records = 21
records = [
@ -263,9 +262,10 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
('start_high_arithmetic', False, False),
('start_high_geometric', False, True))
def test_adaptation_all_equal(self, start_low, geometric):
# 20 equal records. Test that we converge to that record and bounce around
# it. Unlike the linspace test, the quantile-matching objective is very
# sharp at the optimum so a decaying learning rate is necessary.
# `num_records` equal records. Test that we converge to that record and
# bounce around it. Unlike the linspace test, the quantile-matching
# objective is very sharp at the optimum so a decaying learning rate is
# necessary.
num_records = 20
records = [tf.constant(5.0)] * num_records
@ -291,53 +291,6 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase,
if t > 40:
self.assertNear(actual_clip, 5.0, 0.5)
def test_ledger(self):
record1 = tf.constant([8.5])
record2 = tf.constant([-7.25])
population_size = tf.Variable(0)
selection_probability = tf.Variable(1.0)
query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
initial_l2_norm_clip=10.0,
noise_multiplier=1.0,
target_unclipped_quantile=0.0,
learning_rate=1.0,
clipped_count_stddev=0.0,
expected_num_records=2.0,
geometric_update=False)
query = privacy_ledger.QueryWithLedger(query, population_size,
selection_probability)
# First sample.
tf.assign(population_size, 10)
tf.assign(selection_probability, 0.1)
_, global_state = test_utils.run_query(query, [record1, record2])
expected_queries = [[10.0, 10.0], [0.5, 0.0]]
formatted = query.ledger.get_formatted_ledger_eager()
sample_1 = formatted[0]
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
# Second sample.
tf.assign(population_size, 20)
tf.assign(selection_probability, 0.2)
test_utils.run_query(query, [record1, record2], global_state)
formatted = query.ledger.get_formatted_ledger_eager()
sample_1, sample_2 = formatted
self.assertAllClose(sample_1.population_size, 10.0)
self.assertAllClose(sample_1.selection_probability, 0.1)
self.assertAllClose(sample_1.queries, expected_queries)
expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]]
self.assertAllClose(sample_2.population_size, 20.0)
self.assertAllClose(sample_2.selection_probability, 0.2)
self.assertAllClose(sample_2.queries, expected_queries_2)
if __name__ == '__main__':
tf.test.main()

View file

@ -24,6 +24,7 @@ from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.dp_query import no_privacy_query
from tensorflow_privacy.privacy.dp_query import normalized_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
@ -73,6 +74,15 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
updating is preferred for non-negative records like vector norms that
could potentially be very large or very close to zero.
"""
if target_quantile < 0 or target_quantile > 1:
raise ValueError(
f'`target_quantile` must be between 0 and 1, got {target_quantile}.')
if learning_rate < 0:
raise ValueError(
f'`learning_rate` must be non-negative, got {learning_rate}')
self._initial_estimate = initial_estimate
self._target_quantile = target_quantile
self._learning_rate = learning_rate
@ -100,10 +110,6 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
l2_norm_clip=0.5, stddev=below_estimate_stddev),
denominator=expected_num_records)
def set_ledger(self, ledger):
"""Implements `tensorflow_privacy.DPQuery.set_ledger`."""
self._below_estimate_query.set_ledger(ledger)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(
@ -138,7 +144,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
below_estimate_result, new_below_estimate_state = (
below_estimate_result, new_below_estimate_state, below_estimate_event = (
self._below_estimate_query.get_noised_result(
sample_state, global_state.below_estimate_state))
@ -162,7 +168,7 @@ class QuantileEstimatorQuery(dp_query.SumAggregationDPQuery):
current_estimate=new_estimate,
below_estimate_state=new_below_estimate_state)
return new_estimate, new_global_state
return new_estimate, new_global_state, below_estimate_event
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
@ -209,3 +215,37 @@ class NoPrivacyQuantileEstimatorQuery(QuantileEstimatorQuery):
del below_estimate_stddev
del expected_num_records
return no_privacy_query.NoPrivacyAverageQuery()
class TreeQuantileEstimatorQuery(QuantileEstimatorQuery):
"""Iterative process to estimate target quantile of a univariate distribution.
Unlike the base class, this uses a `TreeResidualSumQuery` to estimate the
fraction below estimate with an exact denominator. This assumes that the
below-estimate value is used in an SGD-like update and we want to privatize the
cumsum of the below-estimate values.
See "Practical and Private (Deep) Learning without Sampling or Shuffling"
(https://arxiv.org/abs/2103.00039) for tree aggregation and privacy
accounting, and "Differentially Private Learning with Adaptive Clipping"
(https://arxiv.org/abs/1905.03871) for how the below-estimate value is used in
an SGD-like algorithm.
"""
def _construct_below_estimate_query(self, below_estimate_stddev,
expected_num_records):
# See comments in `QuantileEstimatorQuery._construct_below_estimate_query`
# for why clip norm 0.5 is used for the query.
sum_query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
clip_norm=0.5,
noise_multiplier=2 * below_estimate_stddev,
record_specs=tf.TensorSpec([]))
return normalized_query.NormalizedQuery(
sum_query, denominator=expected_num_records)
def reset_state(self, noised_results, global_state):
new_numerator_state = self._below_estimate_query._numerator.reset_state( # pylint: disable=protected-access,line-too-long
noised_results, global_state.below_estimate_state.numerator_state)
new_below_estimate_state = global_state.below_estimate_state._replace(
numerator_state=new_numerator_state)
return global_state._replace(below_estimate_state=new_below_estimate_state)
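
A construction sketch mirroring the arguments used in the tests below; the values are illustrative, and `noised_estimate` stands for a previously released estimate:

from tensorflow_privacy.privacy.dp_query import quantile_estimator_query

# Median estimate whose below-estimate counts are noised via tree aggregation.
query = quantile_estimator_query.TreeQuantileEstimatorQuery(
    initial_estimate=1.0,
    target_quantile=0.5,
    learning_rate=0.5,
    below_estimate_stddev=0.1,
    expected_num_records=100,
    geometric_update=True)
global_state = query.initial_global_state()
# ...accumulate records and call get_noised_result as usual, then restart the
# tree noise while keeping the learned estimate:
# global_state = query.reset_state(noised_estimate, global_state)
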

View file

@ -29,22 +29,26 @@ from tensorflow_privacy.privacy.dp_query import test_utils
tf.enable_eager_execution()
def _make_quantile_estimator_query(
initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update):
def _make_quantile_estimator_query(initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update,
tree_aggregation=False):
if expected_num_records is not None:
return quantile_estimator_query.QuantileEstimatorQuery(
initial_estimate,
target_quantile,
learning_rate,
below_estimate_stddev,
expected_num_records,
geometric_update)
if tree_aggregation:
return quantile_estimator_query.TreeQuantileEstimatorQuery(
initial_estimate, target_quantile, learning_rate,
below_estimate_stddev, expected_num_records, geometric_update)
else:
return quantile_estimator_query.QuantileEstimatorQuery(
initial_estimate, target_quantile, learning_rate,
below_estimate_stddev, expected_num_records, geometric_update)
else:
if tree_aggregation:
raise ValueError(
'Cannot set expected_num_records to None for tree aggregation.')
return quantile_estimator_query.NoPrivacyQuantileEstimatorQuery(
initial_estimate,
target_quantile,
@ -54,8 +58,9 @@ def _make_quantile_estimator_query(
class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_zero(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_zero(self, exact, tree):
record1 = tf.constant(8.5)
record2 = tf.constant(7.25)
@ -65,7 +70,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=False)
geometric_update=False,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -84,18 +90,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_zero_geometric(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_zero_geometric(self, exact, tree):
record1 = tf.constant(5.0)
record2 = tf.constant(2.5)
query = _make_quantile_estimator_query(
initial_estimate=16.0,
target_quantile=0.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=True)
geometric_update=True,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -116,8 +124,9 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_one(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_one(self, exact, tree):
record1 = tf.constant(1.5)
record2 = tf.constant(2.75)
@ -127,7 +136,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=False)
geometric_update=False,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -146,18 +156,20 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(('exact', True), ('fixed', False))
def test_target_one_geometric(self, exact):
@parameterized.named_parameters(
('exact', True, False), ('fixed', False, False), ('tree', False, True))
def test_target_one_geometric(self, exact, tree):
record1 = tf.constant(1.5)
record2 = tf.constant(3.0)
query = _make_quantile_estimator_query(
initial_estimate=0.5,
target_quantile=1.0,
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
learning_rate=np.log(2.0), # Geometric steps in powers of 2.
below_estimate_stddev=0.0,
expected_num_records=(None if exact else 2.0),
geometric_update=True)
geometric_update=True,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -179,15 +191,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(actual_estimate.numpy(), expected_estimate)
@parameterized.named_parameters(
('start_low_geometric_exact', True, True, True),
('start_low_arithmetic_exact', True, True, False),
('start_high_geometric_exact', True, False, True),
('start_high_arithmetic_exact', True, False, False),
('start_low_geometric_noised', False, True, True),
('start_low_arithmetic_noised', False, True, False),
('start_high_geometric_noised', False, False, True),
('start_high_arithmetic_noised', False, False, False))
def test_linspace(self, exact, start_low, geometric):
('start_low_geometric_exact', True, True, True, False),
('start_low_arithmetic_exact', True, True, False, False),
('start_high_geometric_exact', True, False, True, False),
('start_high_arithmetic_exact', True, False, False, False),
('start_low_geometric_noised', False, True, True, False),
('start_low_arithmetic_noised', False, True, False, False),
('start_high_geometric_noised', False, False, True, False),
('start_high_arithmetic_noised', False, False, False, False),
('start_low_geometric_tree', False, True, True, True),
('start_low_arithmetic_tree', False, True, False, True),
('start_high_geometric_tree', False, False, True, True),
('start_high_arithmetic_tree', False, False, False, True))
def test_linspace(self, exact, start_low, geometric, tree):
# `num_records` records equally spaced from 0 to 10.
# Test that we converge to the correct median value and bounce around it.
num_records = 21
@ -200,7 +216,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=1.0,
below_estimate_stddev=(0.0 if exact else 1e-2),
expected_num_records=(None if exact else num_records),
geometric_update=geometric)
geometric_update=geometric,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -213,15 +230,19 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(actual_estimate, 5.0, 0.25)
@parameterized.named_parameters(
('start_low_geometric_exact', True, True, True),
('start_low_arithmetic_exact', True, True, False),
('start_high_geometric_exact', True, False, True),
('start_high_arithmetic_exact', True, False, False),
('start_low_geometric_noised', False, True, True),
('start_low_arithmetic_noised', False, True, False),
('start_high_geometric_noised', False, False, True),
('start_high_arithmetic_noised', False, False, False))
def test_all_equal(self, exact, start_low, geometric):
('start_low_geometric_exact', True, True, True, False),
('start_low_arithmetic_exact', True, True, False, False),
('start_high_geometric_exact', True, False, True, False),
('start_high_arithmetic_exact', True, False, False, False),
('start_low_geometric_noised', False, True, True, False),
('start_low_arithmetic_noised', False, True, False, False),
('start_high_geometric_noised', False, False, True, False),
('start_high_arithmetic_noised', False, False, False, False),
('start_low_geometric_tree', False, True, True, True),
('start_low_arithmetic_tree', False, True, False, True),
('start_high_geometric_tree', False, False, True, True),
('start_high_arithmetic_tree', False, False, False, True))
def test_all_equal(self, exact, start_low, geometric, tree):
# 20 equal records. Test that we converge to that record and bounce around
# it. Unlike the linspace test, the quantile-matching objective is very
# sharp at the optimum so a decaying learning rate is necessary.
@ -236,7 +257,8 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
learning_rate=learning_rate,
below_estimate_stddev=(0.0 if exact else 1e-2),
expected_num_records=(None if exact else num_records),
geometric_update=geometric)
geometric_update=geometric,
tree_aggregation=tree)
global_state = query.initial_global_state()
@ -258,6 +280,38 @@ class QuantileEstimatorQueryTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'scalar'):
query.accumulate_record(None, None, [1.0, 2.0])
def test_tree_noise_restart(self):
sample_num, tolerance, stddev = 1000, 0.3, 0.1
initial_estimate, expected_num_records = 5., 2.
record1 = tf.constant(1.)
record2 = tf.constant(10.)
query = _make_quantile_estimator_query(
initial_estimate=initial_estimate,
target_quantile=.5,
learning_rate=1.,
below_estimate_stddev=stddev,
expected_num_records=expected_num_records,
geometric_update=False,
tree_aggregation=True)
global_state = query.initial_global_state()
self.assertAllClose(global_state.current_estimate, initial_estimate)
# As the target quantile is accurate, there is no signal and only noise.
samples = []
for _ in range(sample_num):
noised_estimate, global_state = test_utils.run_query(
query, [record1, record2], global_state)
samples.append(noised_estimate.numpy())
global_state = query.reset_state(noised_estimate, global_state)
self.assertNotEqual(global_state.current_estimate, initial_estimate)
global_state = global_state._replace(current_estimate=initial_estimate)
self.assertAllClose(
np.std(samples), stddev / expected_num_records, rtol=tolerance)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,205 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements DPQuery interface for restarting the states of another query.
This query is used to compose with a DPQuery that has a `reset_state` function.
"""
import abc
import collections
from typing import Optional
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
class RestartIndicator(metaclass=abc.ABCMeta):
"""Base class establishing interface for restarting the tree state.
A `RestartIndicator` maintains a state, and each time `next` is called, a bool
value is generated to indicate whether to restart, and the indicator state is
advanced.
"""
@abc.abstractmethod
def initialize(self):
"""Makes an initialized state for `RestartIndicator`.
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
"""Gets next bool indicator and advances the `RestartIndicator` state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is bool indicator and new_state
is the advanced state.
"""
raise NotImplementedError
class PeriodicRoundRestartIndicator(RestartIndicator):
"""Indicator for resetting the tree state after every a few number of queries.
The indicator will maintain an internal counter as state.
"""
def __init__(self, frequency: int, warmup: Optional[int] = None):
"""Construct the `PeriodicRoundRestartIndicator`.
Args:
frequency: The `next` function will return `True` once every `frequency`
  calls to `next`.
warmup: The first `True` will be returned at the `warmup`-th call of `next`.
"""
if frequency < 1:
raise ValueError('Restart frequency should be equal or larger than 1, '
f'got {frequency}')
if warmup is None:
warmup = 0
elif warmup <= 0 or warmup >= frequency:
raise ValueError(
f'Warmup should be between 1 and `frequency-1={frequency-1}`, '
f'got {warmup}')
self.frequency = frequency
self.warmup = warmup
def initialize(self):
"""Returns initialized state of 0 for `PeriodicRoundRestartIndicator`."""
return tf.constant(0, tf.int32)
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
of `state+1`.
"""
frequency = tf.constant(self.frequency, tf.int32)
warmup = tf.constant(self.warmup, tf.int32)
state = state + tf.constant(1, tf.int32)
flag = tf.math.equal(tf.math.floormod(state, frequency), warmup)
return flag, state
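
For example, with `frequency=4` and no warmup the indicator fires on the 4th, 8th, 12th, ... calls, while `warmup=1` shifts the first firing to the 1st call. A minimal check, assuming eager execution:

from tensorflow_privacy.privacy.dp_query import restart_query

indicator = restart_query.PeriodicRoundRestartIndicator(frequency=4)
state = indicator.initialize()
flags = []
for _ in range(8):
  flag, state = indicator.next(state)
  flags.append(bool(flag))
# flags == [False, False, False, True, False, False, False, True]
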
class PeriodicTimeRestartIndicator(RestartIndicator):
"""Indicator for periodically resetting the tree state after a certain time.
The indicator will maintain a state to track the previous restart time.
"""
def __init__(self, period_seconds: float):
"""Construct the `PeriodicTimeRestartIndicator`.
Args:
period_seconds: The `next` function will return `True` if called more than
  `period_seconds` seconds after the previous restart (or initialization).
"""
if period_seconds <= 0:
raise ValueError('Restart period_seconds should be larger than 0, got '
f'{period_seconds}')
self.period_seconds = period_seconds
@tf.function
def initialize(self):
"""Returns initial time as state."""
return tf.timestamp()
@tf.function
def next(self, state):
"""Gets next bool indicator and advances the state.
Args:
state: The current state.
Returns:
A pair (value, new_state) where value is the bool indicator and new_state
of time.
"""
current_time = tf.timestamp()
current_period = current_time - state
reset_flag = tf.math.greater(
current_period,
tf.convert_to_tensor(self.period_seconds, current_period.dtype))
if reset_flag:
state = current_time
return reset_flag, state
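
This indicator is useful when restarts should follow wall-clock time rather than a step count. A minimal sketch, reusing the 23.5-hour period from the test below:

from tensorflow_privacy.privacy.dp_query import restart_query

# Restart roughly once a day (23.5 hours), regardless of how many steps ran.
indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 * 23.5)
state = indicator.initialize()       # records the current timestamp
flag, state = indicator.next(state)  # True only once the period has elapsed
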
class RestartQuery(dp_query.SumAggregationDPQuery):
"""`DPQuery` for `SumAggregationDPQuery` with a `reset_state` function."""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple(
'_GlobalState', ['inner_query_state', 'indicator_state'])
def __init__(self, inner_query: dp_query.SumAggregationDPQuery,
restart_indicator: RestartIndicator):
"""Initializes `RestartQuery`.
Args:
inner_query: A `SumAggregationDPQuery` that has a `reset_state` method.
restart_indicator: A `RestartIndicator` to generate the boolean indicator
for resetting the state.
"""
if not hasattr(inner_query, 'reset_state'):
raise ValueError(f'{type(inner_query)} must define `reset_state` to be '
'composed with `RestartQuery`.')
self._inner_query = inner_query
self._restart_indicator = restart_indicator
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return self._GlobalState(
inner_query_state=self._inner_query.initial_global_state(),
indicator_state=self._restart_indicator.initialize())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return self._inner_query.derive_sample_params(
global_state.inner_query_state)
def initial_sample_state(self, template):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self._inner_query.initial_sample_state(template)
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`."""
return self._inner_query.preprocess_record(params, record)
@tf.function
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`."""
noised_results, inner_state, event = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
restart_flag, indicator_state = self._restart_indicator.next(
global_state.indicator_state)
if restart_flag:
inner_state = self._inner_query.reset_state(noised_results, inner_state)
return (noised_results, self._GlobalState(inner_state,
indicator_state), event)
def derive_metrics(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_metrics`."""
return self._inner_query.derive_metrics(global_state.inner_query_state)
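
Putting the pieces together: a tree-based sum query wrapped so that its internal tree is restarted on a fixed schedule. A sketch assuming the `build_l2_gaussian_query` helper shown later in this change (record shape and frequency are illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

inner_query = tree_aggregation_query.TreeResidualSumQuery.build_l2_gaussian_query(
    clip_norm=1.0,
    noise_multiplier=1.0,
    record_specs=tf.TensorSpec([10]))
query = restart_query.RestartQuery(
    inner_query=inner_query,
    restart_indicator=restart_query.PeriodicRoundRestartIndicator(frequency=100))
# `query` behaves like `inner_query`, except that get_noised_result restarts
# the tree state whenever the indicator fires.
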

View file

@ -0,0 +1,180 @@
# Copyright 2021, Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `restart_query`."""
from absl.testing import parameterized
import mock
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
class RoundRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1))
def test_round_raise(self, frequency):
with self.assertRaisesRegex(
ValueError, 'Restart frequency should be equal or larger than 1'):
restart_query.PeriodicRoundRestartIndicator(frequency)
@parameterized.named_parameters(('zero', 0), ('negative', -1), ('equal', 2),
('large', 3))
def test_round_raise_warmup(self, warmup):
frequency = 2
with self.assertRaisesRegex(
ValueError,
f'Warmup should be between 1 and `frequency-1={frequency-1}`'):
restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
@parameterized.named_parameters(('f1', 1), ('f2', 2), ('f4', 4), ('f5', 5))
def test_round_indicator(self, frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == frequency - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
@parameterized.named_parameters(('f2', 2, 1), ('f4', 4, 3), ('f5', 5, 2))
def test_round_indicator_warmup(self, frequency, warmup):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency, warmup)
state = indicator.initialize()
for i in range(total_steps):
flag, state = indicator.next(state)
if i % frequency == warmup - 1:
self.assertTrue(flag)
else:
self.assertFalse(flag)
class TimeRestartIndicatorTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('zero', 0), ('negative', -1.))
def test_round_raise(self, secs):
with self.assertRaisesRegex(
ValueError, 'Restart period_seconds should be larger than 0'):
restart_query.PeriodicTimeRestartIndicator(secs)
def test_round_indicator(self):
indicator = restart_query.PeriodicTimeRestartIndicator(period_seconds=3600 *
23.5)
# TODO(b/193679963): use `tf.timestamp` as the default of a member of
# the `PeriodicTimeRestartIndicator` to unroll the mock test.
return_time = tf.Variable(
1627018868.452365) # 22:41pm PST 5:41am UTC, July 22, initialize
with mock.patch.object(
tf, 'timestamp', return_value=return_time) as mock_func:
time_stamps = [
1627022468.452365, # 23:41pm PST 5:41am UTC, July 22, 1 hr, False
1627105268.452365, # 22:41pm PST 5:41am UTC, July 23, 1 day, True
1627112468.452365, # 2 hr after restart, False
1627189508.452365, # 23.4 hr after restart, False
1627189904.452365, # 23.51 hr after restart, True
]
expected_values = [False, True, False, False, True]
state = indicator.initialize()
for v, t in zip(expected_values, time_stamps):
return_time.assign(t)
mock_func.return_value = return_time
flag, state = indicator.next(state)
self.assertEqual(v, flag.numpy())
def _get_l2_clip_fn():
def l2_clip_fn(record_as_list, clip_value):
clipped_record, _ = tf.clip_by_global_norm(record_as_list, clip_value)
return clipped_record
return l2_clip_fn
class RestartQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
query = restart_query.RestartQuery(query, indicator)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = (
scalar_value * (i + 1) +
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
self.assertEqual(query_result, expected)
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency):
total_steps = 20
indicator = restart_query.PeriodicRoundRestartIndicator(frequency)
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
query = restart_query.RestartQuery(query, indicator)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
# Expected value is the signal of the current round plus the residual of
# two consecutive tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = scalar_value + tree_node_value * (
bin(i % frequency + 1)[2:].count('1') -
bin(i % frequency)[2:].count('1'))
self.assertEqual(query_result, expected)
if __name__ == '__main__':
tf.test.main()

View file

@ -44,6 +44,7 @@ def run_query(query, records, global_state=None, weights=None):
sample_state = query.accumulate_record(params, sample_state, record)
else:
for weight, record in zip(weights, records):
sample_state = query.accumulate_record(
params, sample_state, record, weight)
return query.get_noised_result(sample_state, global_state)
sample_state = query.accumulate_record(params, sample_state, record,
weight)
result, global_state, _ = query.get_noised_result(sample_state, global_state)
return result, global_state

View file

@ -16,20 +16,23 @@
`TreeAggregator` and `EfficientTreeAggregator` compute cumulative sums of noise
based on tree aggregation. When using an appropriate noise function (e.g.,
Gaussian noise), it allows for efficient differentially private algorithms under
continual observation, without prior subsampling or shuffling assumptions.
`build_tree` constructs a tree given the leaf nodes by recursively summing the
children nodes to get the parent node. It allows for efficient range queries and
other statistics such as quantiles on the leaf nodes.
continual observation, without prior subsampling or shuffling assumptions. This
module implements the core logic of tree aggregation in TensorFlow, serving as
helper functions for `tree_aggregation_query`. This module and its helper
functions are publicly accessible.
"""
import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union
import attr
import tensorflow as tf
# TODO(b/192464750): find a proper place for the helper functions, privatize
# the tree aggregation logic, and encourage users to use the DPQuery API.
class ValueGenerator(metaclass=abc.ABCMeta):
"""Base class establishing interface for stateful value generation.
@ -44,6 +47,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
Returns:
An initial state.
"""
raise NotImplementedError
@abc.abstractmethod
def next(self, state):
@ -56,6 +60,7 @@ class ValueGenerator(metaclass=abc.ABCMeta):
A pair (value, new_state) where value is the next value and new_state
is the advanced state.
"""
raise NotImplementedError
class GaussianNoiseGenerator(ValueGenerator):
@ -65,6 +70,9 @@ class GaussianNoiseGenerator(ValueGenerator):
nested structure of `tf.TensorSpec`s.
"""
# pylint: disable=invalid-name
_GlobalState = collections.namedtuple('_GlobalState', ['seeds', 'stddev'])
def __init__(self,
noise_std: float,
specs: Collection[tf.TensorSpec],
@ -78,46 +86,57 @@ class GaussianNoiseGenerator(ValueGenerator):
seed: An optional integer seed. If None, generator is seeded from the
clock.
"""
self.noise_std = noise_std
self.specs = specs
self.seed = seed
self._noise_std = noise_std
self._specs = specs
self._seed = seed
def initialize(self):
"""Makes an initial state for the GaussianNoiseGenerator.
Returns:
An initial state.
A named tuple of (seeds, stddev).
"""
if self.seed is None:
return tf.cast(
tf.stack([
tf.math.floor(tf.timestamp() * 1e6),
tf.math.floor(tf.math.log(tf.timestamp() * 1e6))
]),
dtype=tf.int64)
if self._seed is None:
time_now = tf.timestamp()
residual = time_now - tf.math.floor(time_now)
return self._GlobalState(
tf.cast(
tf.stack([
tf.math.floor(tf.timestamp() * 1e6),
tf.math.floor(residual * 1e9)
]),
dtype=tf.int64), tf.constant(self._noise_std, dtype=tf.float32))
else:
return tf.constant(self.seed, dtype=tf.int64, shape=(2,))
return self._GlobalState(
tf.constant(self._seed, dtype=tf.int64, shape=(2,)),
tf.constant(self._noise_std, dtype=tf.float32))
def next(self, state):
"""Gets next value and advances the GaussianNoiseGenerator.
Args:
state: The current state.
state: The current state (seed, noise_std).
Returns:
A pair (sample, new_state) where sample is a new sample and new_state
is the advanced state.
A tuple of (sample, new_state) where sample is a new sample and new_state
is the advanced state (seed+1, noise_std).
"""
flat_structure = tf.nest.flatten(self.specs)
flat_seeds = [state + i for i in range(len(flat_structure))]
nest_seeds = tf.nest.pack_sequence_as(self.specs, flat_seeds)
flat_structure = tf.nest.flatten(self._specs)
flat_seeds = [state.seeds + i for i in range(len(flat_structure))]
nest_seeds = tf.nest.pack_sequence_as(self._specs, flat_seeds)
def _get_noise(spec, seed):
return tf.random.stateless_normal(
shape=spec.shape, seed=seed, stddev=self.noise_std)
shape=spec.shape, seed=seed, stddev=state.stddev)
nest_noise = tf.nest.map_structure(_get_noise, self.specs, nest_seeds)
return nest_noise, flat_seeds[-1] + 1
nest_noise = tf.nest.map_structure(_get_noise, self._specs, nest_seeds)
return nest_noise, self._GlobalState(flat_seeds[-1] + 1, state.stddev)
def make_state(self, seeds: tf.Tensor, stddev: tf.Tensor):
"""Returns a new named tuple of (seeds, stddev)."""
seeds = tf.ensure_shape(seeds, shape=(2,))
return self._GlobalState(
tf.cast(seeds, dtype=tf.int64), tf.cast(stddev, dtype=tf.float32))
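
The generator state is now a (seeds, stddev) named tuple, so the noise scale travels with the state and can be rebuilt via `make_state`. A small sketch, assuming eager execution (the seed and shapes are illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_aggregation

generator = tree_aggregation.GaussianNoiseGenerator(
    noise_std=1.0, specs=tf.TensorSpec(shape=[3]), seed=2021)

state = generator.initialize()        # _GlobalState(seeds=[2021, 2021], stddev=1.0)
noise, state = generator.next(state)  # stateless Gaussian sample; seeds advance
# Rebuild a state explicitly, e.g. to change the noise scale mid-stream.
state = generator.make_state(seeds=tf.constant([2021, 2022], dtype=tf.int64),
                             stddev=tf.constant(0.5))
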
class StatelessValueGenerator(ValueGenerator):
@ -170,6 +189,7 @@ class TreeState(object):
value_generator_state = attr.ib(type=Any)
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
@tf.function
def get_step_idx(state: TreeState) -> tf.Tensor:
"""Returns the current leaf node index based on `TreeState.level_buffer_idx`."""
@ -192,6 +212,14 @@ class TreeAggregator():
https://dl.acm.org/doi/pdf/10.1145/1806689.1806787. A buffer at the scale of
tree depth is maintained and updated when a new conceptual leaf node arrives.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = TreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -209,14 +237,8 @@ class TreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
"""
value_generator_state = self.value_generator.initialize()
def _get_init_state(self, value_generator_state) -> TreeState:
"""Returns initial `TreeState` given `value_generator_state`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -228,12 +250,28 @@ class TreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
return self._get_init_state(value_generator_state)
def reset_state(self, state: TreeState) -> TreeState:
"""Returns reset `TreeState` after restarting a new tree."""
return self._get_init_state(state.value_generator_state)
@tf.function
def _get_cumsum(self, level_buffer: Collection[tf.Tensor]) -> tf.Tensor:
return tf.nest.map_structure(lambda x: tf.reduce_sum(x, axis=0),
@ -242,7 +280,7 @@ class TreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated value and updated `TreeState` for one step.
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -253,10 +291,20 @@ class TreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
Returns:
Tuple of (noise, state) where `noise` is generated by the tree aggregation
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
"""
level_buffer_idx, level_buffer, value_generator_state = (
state.level_buffer_idx, state.level_buffer, state.value_generator_state)
# We only publicize a combined function for updating state and returning
# noised results because this DPQuery is designed for streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(level_buffer)
new_level_buffer = tf.nest.map_structure(
@ -315,6 +363,14 @@ class EfficientTreeAggregator():
`sigma * sqrt(2^{d-1}/(2^d-1))`, which becomes `sigma / sqrt(2)` when
the tree is very tall.
Example usage:
random_generator = GaussianNoiseGenerator(...)
tree_aggregator = EfficientTreeAggregator(random_generator)
state = tree_aggregator.init_state()
for leaf_node_idx in range(total_steps):
assert leaf_node_idx == get_step_idx(state)
noise, state = tree_aggregator.get_cumsum_and_update(state)
Attributes:
value_generator: A `ValueGenerator` or a no-arg function to generate a noise
value for each tree node.
@ -332,17 +388,8 @@ class EfficientTreeAggregator():
else:
self.value_generator = StatelessValueGenerator(value_generator)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
def _get_init_state(self, value_generator_state):
"""Returns initial buffer for `TreeState`."""
level_buffer_idx = tf.TensorArray(dtype=tf.int32, size=1, dynamic_size=True)
level_buffer_idx = level_buffer_idx.write(0, tf.constant(
0, dtype=tf.int32)).stack()
@ -354,12 +401,28 @@ class EfficientTreeAggregator():
new_val)
level_buffer = tf.nest.map_structure(lambda x, y: x.write(0, y).stack(),
level_buffer_structure, new_val)
return TreeState(
level_buffer=level_buffer,
level_buffer_idx=level_buffer_idx,
value_generator_state=value_generator_state)
def init_state(self) -> TreeState:
"""Returns initial `TreeState`.
Initializes `TreeState` for a tree of a single leaf node: the respective
initial node value in `TreeState.level_buffer` is generated by the value
generator function, and the node index is 0.
Returns:
An initialized `TreeState`.
"""
value_generator_state = self.value_generator.initialize()
return self._get_init_state(value_generator_state)
def reset_state(self, state: TreeState) -> TreeState:
"""Returns reset `TreeState` after restarting a new tree."""
return self._get_init_state(state.value_generator_state)
@tf.function
def _get_cumsum(self, state: TreeState) -> tf.Tensor:
"""Returns weighted cumulative sum of noise based on `TreeState`."""
@ -381,7 +444,7 @@ class EfficientTreeAggregator():
@tf.function
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated value and updated `TreeState` for one step.
"""Returns tree aggregated noise and updates `TreeState` for the next step.
`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
@ -394,7 +457,17 @@ class EfficientTreeAggregator():
Args:
state: `TreeState` for the current leaf node, index can be queried by
`tree_aggregation.get_step_idx(state.level_buffer_idx)`.
Returns:
Tuple of (noise, state) where `noise` is generated by the tree aggregation
protocol for the cumulative sum of streaming data, and `state` is the
updated `TreeState`.
"""
# We only publicize a combined function for updating state and returning
# noised results because this DPQuery is designed for streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised results can be queried, and the queries are
# expected to happen for every step in the streaming setting.
cumsum = self._get_cumsum(state)
level_buffer_idx, level_buffer, value_generator_state = (
@ -449,79 +522,3 @@ class EfficientTreeAggregator():
level_buffer_idx=new_level_buffer_idx,
value_generator_state=value_generator_state)
return cumsum, new_state
@tf.function
def build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes does not divide arity, the
leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: A `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree,
Raises:
ValueError: if parameters don't meet expectations. There are two situations
where the error is raised: (1) the input tensor has length smaller than 1;
(2) The arity is less than 2.
"""
if len(leaf_nodes) <= 0:
raise ValueError(
'The number of leaf nodes should at least be 1.'
f'However, an array of length {len(leaf_nodes)} is detected')
if arity <= 1:
raise ValueError('The branching factor should be at least 2.'
f'However, a branching factor of {arity} is detected.')
def pad_zero(leaf_nodes, size):
paddings = [[0, size - len(leaf_nodes)]]
return tf.pad(leaf_nodes, paddings)
leaf_nodes_size = tf.constant(len(leaf_nodes), dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.constant(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(leaf_nodes, tf.math.pow(float(arity), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree

View file

@ -11,31 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DPQuery for continual observation queries relying on `tree_aggregation`."""
"""`DPQuery`s for differentially private tree aggregation protocols.
`TreeCumulativeSumQuery` and `TreeResidualSumQuery` are `DPQuery`s for continual
online observation queries relying on `tree_aggregation`. 'Online' means that
the leaf nodes of the tree arrive one by one as time proceeds. The core
logic of tree aggregation is implemented in `tree_aggregation.TreeAggregator`
and `tree_aggregation.EfficientTreeAggregator`.
"""
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation
# TODO(b/193679963): define `RestartQuery` and move `RestartIndicator` to be
# in the same module.
class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for adding correlated noise through tree structure.
"""Returns private cumulative sums by clipping and adding correlated noise.
First clips and sums records in current sample, returns cumulative sum of
samples over time (instead of only current sample) with added noise for
cumulative sum proportional to log(T), T being the number of times the query
is called.
Consider calling `get_noised_result` T times, where each x_i (i=0,1,...,T-1) is
the private value returned by `accumulate_record`, i.e. x_i = sum_{j=0}^{n-1}
x_{i,j} where each x_{i,j} is a private record in the database. This class is
intended to make multiple queries, which release privatized values of the
cumulative sums s_i = sum_{k=0}^{i} x_k, for i=0,...,T-1.
Each call to `get_noised_result` releases the next cumulative sum s_i, which
is in contrast to the GaussianSumQuery that releases x_i. Noise for the
cumulative sums is added using the tree aggregation logic in
`tree_aggregation`, and is proportional to log(T).
Example usage:
query = TreeCumulativeSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j,sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_cumsum is privatized estimate of s_i
noised_cumsum, global_state, event = query.get_noised_result(
sample_state, global_state)
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the
corresponding record, e.g. clip_fn(flat_record, clip_value).
corresponding record, e.g. clip_fn(flat_record, clip_value).
clip_value: float indicating the value at which to clip the record.
record_specs: `Collection[tf.TensorSpec]` specifying shapes of records.
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with
user defined `noise_generator`. `noise_generator` is a
tree_aggregator: `tree_aggregation.TreeAggregator` initialized with user
defined `noise_generator`. `noise_generator` is a
`tree_aggregation.ValueGenerator` to generate the noise value for a tree
node. Noise standard deviation is specified outside the `dp_query` by the
user when defining `noise_fn` and should have order
@ -94,11 +122,10 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
initial_tree_state = self._tree_aggregator.init_state()
initial_samples_cumulative_sum = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape), self._record_specs)
initial_state = TreeCumulativeSumQuery.GlobalState(
return TreeCumulativeSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
samples_cumulative_sum=initial_samples_cumulative_sum)
return initial_state
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@ -139,13 +166,36 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
tf.add, global_state.samples_cumulative_sum, sample_state)
cumulative_sum_noise, new_tree_state = self._tree_aggregator.get_cumsum_and_update(
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
new_global_state = attr.evolve(
global_state,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
noised_cum_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
return noised_cum_sum, new_global_state
event = dp_event.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
This function will be used in `restart_query.RestartQuery` after calling
`get_noised_result` when the restarting condition is met.
Args:
noised_results: Noised cumulative sum returned by `get_noised_result`.
global_state: Updated global state returned by `get_noised_result`, which
has current sample's cumulative sum and tree state for the next
cumulative sum.
Returns:
New global state with current noised cumulative sum and restarted tree
state for the next cumulative sum.
"""
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
samples_cumulative_sum=noised_results,
tree_state=new_tree_state)
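
In other words, after a restart the released cumulative sum becomes the new baseline and a fresh tree is grown on top of it. A sketch of the intended composition with the `RestartQuery` added in this change, assuming `TreeCumulativeSumQuery.build_l2_gaussian_query` takes the same arguments as the residual-sum variant (values are illustrative):

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import restart_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query

cumsum_query = tree_aggregation_query.TreeCumulativeSumQuery.build_l2_gaussian_query(
    clip_norm=1.0,
    noise_multiplier=1.0,
    record_specs=tf.TensorSpec([10]))
# Restart the tree every 50 releases; RestartQuery calls reset_state so the
# released cumulative sum becomes the baseline for the next tree.
query = restart_query.RestartQuery(
    cumsum_query, restart_query.PeriodicRoundRestartIndicator(frequency=50))
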
@classmethod
def build_l2_gaussian_query(cls,
@ -194,22 +244,47 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for adding correlated noise through tree structure.
"""Implements DPQuery for adding correlated noise through tree structure.
Clips and sums records in current sample; returns the current sample adding
the noise residual from tree aggregation. The returned value is conceptually
equivalent to the following: calculates cumulative sum of samples over time
(instead of only current sample) with added noise for cumulative sum
proportional to log(T), T being the number of times the query is called;
returns the residual between the current noised cumsum and the previous one
when the query is called. Combining this query with a SGD optimizer can be
used to implement the DP-FTRL algorithm in
Clips and sums records in current sample x_i = sum_{j=0}^{n-1} x_{i,j};
returns the current sample adding the noise residual from tree aggregation.
The returned value is conceptually equivalent to the following: calculates
cumulative sum of samples over time s_i = sum_{k=0}^i x_k (instead of only
the current sample) with noise added by a tree aggregation protocol that is
proportional to log(T), T being the number of times the query is called;
returns the residual between the current noised cumsum noised(s_i) and the
previous one noised(s_{i-1}) when the query is called.
This can be used as a drop-in replacement for `GaussianSumQuery`, and can
offer stronger utility/privacy tradeoffs when amplification-via-sampling is not
possible, or when the privacy epsilon is relatively large. This may result in
more noise by a log(T) factor in each individual estimate of x_i, but if the
x_i are used in the underlying code to compute cumulative sums, the noise in
those sums can be less. That is, this allows us to adapt code that was written
to use a regular `SumQuery` to benefit from the tree aggregation protocol.
Combining this query with a SGD optimizer can be used to implement the
DP-FTRL algorithm in
"Practical and Private (Deep) Learning without Sampling or Shuffling".
Example usage:
query = TreeResidualSumQuery(...)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i, samples in enumerate(streaming_samples):
sample_state = query.initial_sample_state(samples[0])
# Compute x_i = sum_{j=0}^{n-1} x_{i,j}
for j, sample in enumerate(samples):
sample_state = query.accumulate_record(params, sample_state, sample)
# noised_sum is privatized estimate of x_i by conceptually postprocessing
# noised cumulative sum s_i
noised_sum, global_state, event = query.get_noised_result(
sample_state, global_state)
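A small standalone sketch of how many tree-node noise values enter a single residual noised(s_i) - noised(s_{i-1}); it matches the binary-representation arithmetic in the residual-query tests added in this change and is only an illustration.

def residual_noise_delta(step):
  # Net change in contributing tree-node noise values between consecutive
  # noised cumulative sums (1-indexed `step`); may be negative.
  return bin(step).count('1') - bin(step - 1).count('1')

# residual_noise_delta(4) == -1: moving from step 3 to step 4 replaces two
# noise nodes with a single higher-level one.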
Attributes:
clip_fn: Callable that specifies clipping function. `clip_fn` receives two
arguments: a flat list of vars in a record and a `clip_value` to clip the
corresponding record, e.g. clip_fn(flat_record, clip_value).
clip_value: float indicating the value at which to clip the record.
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
and shapes of records.
@ -242,10 +317,10 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
clip_fn,
clip_value,
use_efficient=True):
"""Initializes the `TreeResidualSumQuery`.
"""Initializes the `TreeCumulativeSumQuery`.
Consider using `build_l2_gaussian_query` for the construction of a
`TreeResidualSumQuery` with L2 norm clipping and Gaussian noise.
`TreeCumulativeSumQuery` with L2 norm clipping and Gaussian noise.
Args:
record_specs: A nested structure of `tf.TensorSpec`s specifying structure
@ -269,20 +344,39 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
else:
self._tree_aggregator = tree_aggregation.TreeAggregator(noise_generator)
def _zero_initial_noise(self):
return tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
self._record_specs)
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
initial_tree_state = self._tree_aggregator.init_state()
initial_noise = tf.nest.map_structure(lambda spec: tf.zeros(spec.shape),
self._record_specs)
return TreeResidualSumQuery.GlobalState(
tree_state=initial_tree_state,
clip_value=tf.constant(self._clip_value, tf.float32),
previous_tree_noise=initial_noise)
previous_tree_noise=self._zero_initial_noise())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return global_state.clip_value
def preprocess_record_l2_impl(self, params, record):
"""Clips the l2 norm, returning the clipped record and the l2 norm.
Args:
params: The parameters for the sample.
record: The record to be processed.
Returns:
A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
the structure of preprocessed tensors, and l2_norm is the total l2 norm
before clipping.
"""
l2_norm_clip = params
record_as_list = tf.nest.flatten(record)
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
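A short standalone illustration of the clipping performed above, assuming a small hand-made record; the values in the comments follow from the L2 norm of the flattened record.

import tensorflow as tf

record = {'w': tf.constant([3.0, 4.0]), 'b': tf.constant([0.0])}
flat_record = tf.nest.flatten(record)
clipped, norm = tf.clip_by_global_norm(flat_record, clip_norm=1.0)
clipped_record = tf.nest.pack_sequence_as(record, clipped)
# norm == 5.0 (the pre-clipping global L2 norm); clipped_record keeps the
# original structure and has global L2 norm 1.0.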
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
@ -318,7 +412,41 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
return noised_sample, new_global_state
event = dp_event.UnsupportedDpEvent()
return noised_sample, new_global_state, event
def reset_state(self, noised_results, global_state):
"""Returns state after resetting the tree.
This function will be used in `restart_query.RestartQuery` after calling
`get_noised_result` when the restarting condition is met.
Args:
noised_results: Noised results returned by `get_noised_result`.
global_state: Updated global state returned by `get_noised_result`, which
records noise for the conceptual cumulative sum of the current leaf
node, and tree state for the next conceptual cumulative sum.
Returns:
New global state with zero noise and restarted tree state.
"""
del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state
assert isinstance(self._tree_aggregator.value_generator,
tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve(
global_state.tree_state, value_generator_state=noise_generator_state)
return attr.evolve(
global_state, clip_value=clip_norm, tree_state=new_tree_state)
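One plausible way to combine the helper above with `reset_state` when restarting with an adapted clip norm and noise scale; this pairing is a sketch rather than part of this change, and `query`, `noised_results`, `new_clip` and `noise_multiplier` are placeholder names.

global_state = query.reset_l2_clip_gaussian_noise(
    global_state, clip_norm=new_clip, stddev=new_clip * noise_multiplier)
global_state = query.reset_state(noised_results, global_state)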
@classmethod
def build_l2_gaussian_query(cls,
@ -342,8 +470,8 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
aggregation algorithm based on the paper "Efficient Use of
Differentially Private Binary Trees".
"""
if clip_norm <= 0:
raise ValueError(f'`clip_norm` must be positive, got {clip_norm}.')
if clip_norm < 0:
raise ValueError(f'`clip_norm` must be non-negative, got {clip_norm}.')
if noise_multiplier < 0:
raise ValueError(

View file

@ -14,15 +14,12 @@
"""Tests for `tree_aggregation_query`."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import test_utils
from tensorflow_privacy.privacy.dp_query import tree_aggregation
from tensorflow_privacy.privacy.dp_query import tree_aggregation_query
STRUCT_RECORD = [
tf.constant([[2.0, 0.0], [0.0, 1.0]]),
tf.constant([-1.0, 0.0])
@ -55,6 +52,7 @@ def _get_noise_fn(specs, stddev=NOISE_STD, seed=1):
def _get_no_noise_fn(specs):
shape = tf.nest.map_structure(lambda spec: spec.shape, specs)
def no_noise_fn():
return tf.nest.map_structure(tf.zeros, shape)
@ -73,6 +71,7 @@ def _get_l2_clip_fn():
def _get_l_infty_clip_fn():
def l_infty_clip_fn(record_as_list, clip_value):
def clip(record):
return tf.clip_by_value(
record, clip_value_min=-clip_value, clip_value_max=clip_value)
@ -213,16 +212,16 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('two_records_noise_fn', [2.71828, 3.14159], _get_noise_fn),
('five_records_noise_fn', np.random.uniform(size=5).tolist(),
('five_records_noise_fn', np.random.uniform(low=0.1, size=5).tolist(),
_get_noise_fn),
('two_records_generator', [2.71828, 3.14159], _get_noise_generator),
('five_records_generator', np.random.uniform(size=5).tolist(),
('five_records_generator', np.random.uniform(low=0.1, size=5).tolist(),
_get_noise_generator),
)
def test_noisy_cumsum_and_state_update(self, records, value_generator):
num_trials = 200
record_specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
records[0])
num_trials, vector_size = 10, 100
record_specs = tf.TensorSpec([vector_size])
records = [tf.constant(r, shape=[vector_size]) for r in records]
noised_sums = []
for i in range(num_trials):
query = tree_aggregation_query.TreeCumulativeSumQuery(
@ -231,7 +230,7 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
noise_generator=value_generator(record_specs, seed=i),
record_specs=record_specs)
query_result, _ = test_utils.run_query(query, records)
noised_sums.append(query_result)
noised_sums.append(query_result.numpy())
result_stddev = np.std(noised_sums)
self.assertNear(result_stddev, NOISE_STD, 0.7) # value for chi-squared test
@ -257,18 +256,18 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
for scalar, expected_sum in zip(streaming_scalars, partial_sum):
sample_state = query.initial_sample_state(scalar)
sample_state = query.accumulate_record(params, sample_state, scalar)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
self.assertEqual(query_result, expected_sum)
@parameterized.named_parameters(
('s0t1step8', 0., 1., [1., 1., 2., 1., 2., 2., 3., 1.]),
('s1t1step8', 1., 1., [2., 3., 5., 5., 7., 8., 10., 9.]),
('s1t2step8', 1., 2., [3., 4., 7., 6., 9., 10., 13., 10.]),
('s0t1', 0., 1.),
('s1t1', 1., 1.),
('s1t2', 1., 2.),
)
def test_partial_sum_scalar_tree_aggregation(self, scalar_value,
tree_node_value,
expected_values):
tree_node_value):
total_steps = 8
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
@ -278,14 +277,53 @@ class TreeCumulativeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for val in expected_values:
# For each streaming step i, the expected value is roughly
# `scalar_value*i + tree_aggregation(tree_node_value, i)`
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state = query.get_noised_result(
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
self.assertEqual(query_result, val)
# For each streaming step i, the expected value is roughly
# `scalar_value*(i+1) + tree_aggregation(tree_node_value, i)`.
# The tree aggregation value can be inferred from the binary
# representation of the current step.
self.assertEqual(
query_result,
scalar_value * (i + 1) + tree_node_value * bin(i + 1)[2:].count('1'))
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
tree_node_value, frequency):
total_steps = 20
query = tree_aggregation_query.TreeCumulativeSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the combination of cumsum of signal; sum of trees
# that have been reset; current tree sum. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = (
scalar_value * (i + 1) +
i // frequency * tree_node_value * bin(frequency)[2:].count('1') +
tree_node_value * bin(i % frequency + 1)[2:].count('1'))
self.assertEqual(query_result, expected)
@parameterized.named_parameters(
('efficient', True, tree_aggregation.EfficientTreeAggregator),
@ -394,6 +432,41 @@ class TreeResidualQueryTest(tf.test.TestCase, parameterized.TestCase):
)
self.assertIsInstance(query._tree_aggregator, tree_class)
@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
('s0t1f5', 0., 1., 5),
('s1t1f5', 1., 1., 5),
('s1t2f2', 1., 2., 2),
('s1t5f6', 1., 5., 6),
)
def test_scalar_tree_aggregation_reset(self, scalar_value, tree_node_value,
frequency):
total_steps = 20
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=scalar_value + 1., # no clip
noise_generator=lambda: tree_node_value,
record_specs=tf.TensorSpec([]),
use_efficient=False)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
for i in range(total_steps):
sample_state = query.initial_sample_state(scalar_value)
sample_state = query.accumulate_record(params, sample_state, scalar_value)
query_result, global_state, _ = query.get_noised_result(
sample_state, global_state)
if i % frequency == frequency - 1:
global_state = query.reset_state(query_result, global_state)
# Expected value is the signal of the current round plus the residual of
# two continous tree aggregation values. The tree aggregation value can
# be inferred from the binary representation of the current step.
expected = scalar_value + tree_node_value * (
bin(i % frequency + 1)[2:].count('1') -
bin(i % frequency)[2:].count('1'))
print(i, query_result, expected)
self.assertEqual(query_result, expected)
if __name__ == '__main__':
tf.test.main()

View file

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for `tree_aggregation`."""
import math
import random
from absl.testing import parameterized
import tensorflow as tf
@ -297,7 +298,11 @@ class EfficientTreeAggregatorTest(tf.test.TestCase, parameterized.TestCase):
tf.nest.map_structure(self.assertAllClose, val, expected_result)
class GaussianNoiseGeneratorTest(tf.test.TestCase):
class GaussianNoiseGeneratorTest(tf.test.TestCase, parameterized.TestCase):
def assertStateEqual(self, state1, state2):
for s1, s2 in zip(tf.nest.flatten(state1), tf.nest.flatten(state2)):
self.assertAllEqual(s1, s2)
def test_random_generator_tf(self,
noise_mean=1.0,
@ -330,12 +335,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([]), seed=seed)
gstate2 = g.initialize()
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertAllEqual(value, value2)
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
def test_seed_state_nondeterministic(self, steps=32, noise_std=0.1):
g = tree_aggregation.GaussianNoiseGenerator(
@ -344,11 +349,12 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=tf.TensorSpec([]))
gstate2 = g2.initialize()
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertNotAllEqual(value, value2)
self.assertNotAllEqual(gstate, gstate2)
self.assertNotAllEqual(gstate.seeds, gstate2.seeds)
def test_seed_state_structure(self, seed=1, steps=32, noise_std=0.1):
specs = [tf.TensorSpec([]), tf.TensorSpec([1]), tf.TensorSpec([2, 2])]
@ -358,45 +364,36 @@ class GaussianNoiseGeneratorTest(tf.test.TestCase):
g2 = tree_aggregation.GaussianNoiseGenerator(
noise_std=noise_std, specs=specs, seed=seed)
gstate2 = g2.initialize()
self.assertStateEqual(gstate, gstate2)
for _ in range(steps):
value, gstate = g.next(gstate)
value2, gstate2 = g2.next(gstate2)
self.assertAllClose(value, value2)
self.assertAllEqual(gstate, gstate2)
self.assertStateEqual(gstate, gstate2)
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
@parameterized.named_parameters(
('increase', range(10), 1),
('decrease', range(30, 20, -2), 2),
('flat', [3.0] * 5, 1),
('small', [0.1**x for x in range(4)], 4),
('random', [random.uniform(1, 10) for _ in range(5)], 4),
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
@parameterized.named_parameters(('negative_arity', [1], -1),
('empty_hist', [], 2))
def test_value_error_raises(self, leaf_nodes, arity):
"""Test whether `build_tree_from_leaf` will raise the correct error when the input is illegal."""
with self.assertRaises(ValueError):
tree_aggregation.build_tree_from_leaf(leaf_nodes, arity)
def test_adaptive_stddev(self, stddev_list, reset_frequency):
# The stddev estimation follows a chi distribution. The confidence for
# `sample_num` samples should be high, and we use a relatively large
# tolerance to guard against numerical instability for small stddev values.
sample_num, tolerance = 10000, 0.05
g = tree_aggregation.GaussianNoiseGenerator(
noise_std=1., specs=tf.TensorSpec([sample_num]), seed=2021)
gstate = g.initialize()
for stddev in stddev_list:
gstate = g.make_state(gstate.seeds, tf.constant(stddev, dtype=tf.float32))
for _ in range(reset_frequency):
prev_gstate = gstate
value, gstate = g.next(gstate)
print(tf.math.reduce_std(value), stddev)
self.assertAllClose(tf.math.reduce_std(value), stddev, rtol=tolerance)
self.assertNotAllEqual(gstate.seeds, prev_gstate.seeds)
if __name__ == '__main__':

View file

@ -0,0 +1,283 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""`DPQuery`s for offline differentially private tree aggregation protocols.
'Offline' means all the leaf nodes are ready before the protocol starts.
"""
import distutils
import math
import attr
import tensorflow as tf
from tensorflow_privacy.privacy.analysis import dp_event
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query
def _build_tree_from_leaf(leaf_nodes: tf.Tensor, arity: int) -> tf.RaggedTensor:
"""A function constructs a complete tree given all the leaf nodes.
The function takes a 1-D array representing the leaf nodes of a tree and the
tree's arity, and constructs a complete tree by recursively summing the
adjacent children to get the parent until reaching the root node. Because we
assume a complete tree, if the number of leaf nodes is not a power of the
arity, the leaf nodes will be padded with zeros.
Args:
leaf_nodes: A 1-D array storing the leaf nodes of the tree.
arity: An `int` for the branching factor of the tree, i.e. the number of
children for each internal node.
Returns:
`tf.RaggedTensor` representing the tree. For example, if
`leaf_nodes=tf.Tensor([1, 2, 3, 4])` and `arity=2`, then the returned value
should be `tree=tf.RaggedTensor([[10],[3,7],[1,2,3,4]])`. In this way,
`tree[layer][index]` can be used to access the node indexed by (layer,
index) in the tree.
"""
def pad_zero(leaf_nodes, size):
paddings = tf.zeros(
shape=(size - leaf_nodes.shape[0],), dtype=leaf_nodes.dtype)
return tf.concat((leaf_nodes, paddings), axis=0)
leaf_nodes_size = tf.constant(leaf_nodes.shape[0], dtype=tf.float32)
num_layers = tf.math.ceil(
tf.math.log(leaf_nodes_size) /
tf.math.log(tf.cast(arity, dtype=tf.float32))) + 1
leaf_nodes = pad_zero(
leaf_nodes, tf.math.pow(tf.cast(arity, dtype=tf.float32), num_layers - 1))
def _shrink_layer(layer: tf.Tensor, arity: int) -> tf.Tensor:
return tf.reduce_sum((tf.reshape(layer, (-1, arity))), 1)
# The following `tf.while_loop` constructs the tree from bottom up by
# iteratively applying `_shrink_layer` to each layer of the tree. The reason
# for the choice of TF1.0-style `tf.while_loop` is that @tf.function does not
# support auto-translation from python loop to tf loop when loop variables
# contain a `RaggedTensor` whose shape changes across iterations.
idx = tf.identity(num_layers)
loop_cond = lambda i, h: tf.less_equal(2.0, i)
def _loop_body(i, h):
return [
tf.add(i, -1.0),
tf.concat(([_shrink_layer(h[0], arity)], h), axis=0)
]
_, tree = tf.while_loop(
loop_cond,
_loop_body, [idx, tf.RaggedTensor.from_tensor([leaf_nodes])],
shape_invariants=[
idx.get_shape(),
tf.RaggedTensorSpec(dtype=leaf_nodes.dtype, ragged_rank=1)
])
return tree
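The expected behaviour, taken directly from the docstring above:

leaf_nodes = tf.constant([1., 2., 3., 4.])
tree = _build_tree_from_leaf(leaf_nodes, arity=2)
# tree is the tf.RaggedTensor [[10.], [3., 7.], [1., 2., 3., 4.]], so
# tree[0][0] is the root (total sum) and tree[-1] recovers the leaves.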
class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
"""Implements dp_query for accurate range queries using tree aggregation.
Implements a variant of the tree aggregation protocol from "Is interaction
necessary for distributed private learning?" (Adam Smith, Abhradeep Thakurta,
Jalaj Upadhyay). Builds a tree on top of the input record and adds noise to
the tree for differential privacy. Any range query can be decomposed into the
sum of O(log(n)) nodes in the tree compared to O(n) when using a histogram.
Improves efficiency and reduces noise scale.
"""
@attr.s(frozen=True)
class GlobalState(object):
"""Class defining global state for TreeRangeSumQuery.
Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
arity: int = 2):
"""Initializes the `TreeRangeSumQuery`.
Args:
inner_query: The inner `DPQuery` that adds noise to the tree.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
self._inner_query = inner_query
self._arity = arity
if self._arity < 2:
raise ValueError(f'Invalid arity={arity} smaller than 2.')
def initial_global_state(self):
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
return TreeRangeSumQuery.GlobalState(
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())
def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,
self._inner_query.derive_sample_params(
global_state.inner_query_state))
def preprocess_record(self, params, record):
"""Implements `tensorflow_privacy.DPQuery.preprocess_record`.
This method builds the tree, flattens it and applies
`inner_query.preprocess_record` to the flattened tree.
Args:
params: Hyper-parameters for preprocessing record.
record: A histogram representing the leaf nodes of the tree.
Returns:
A `tf.Tensor` representing the flattened version of the preprocessed tree.
"""
arity, inner_query_params = params
preprocessed_record = _build_tree_from_leaf(record, arity).flat_values
# The following code reshapes the output vector so that the output shape can
# be statically inferred. This is useful when used with
# `tff.aggregators.DifferentiallyPrivateFactory` because it needs to know
# the output shape of this function statically and explicitly.
preprocessed_record_shape = [
(self._arity**(math.ceil(math.log(record.shape[0], self._arity)) + 1) -
1) // (self._arity - 1)
]
preprocessed_record = tf.reshape(preprocessed_record,
preprocessed_record_shape)
preprocessed_record = self._inner_query.preprocess_record(
inner_query_params, preprocessed_record)
return preprocessed_record
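A quick check of the flattened-tree size used above: a complete arity-ary tree with ceil(log_arity(n)) + 1 layers over a length-n histogram has (arity**layers - 1) / (arity - 1) nodes.

import math

def flat_tree_size(n, arity=2):
  layers = math.ceil(math.log(n, arity)) + 1
  return (arity**layers - 1) // (arity - 1)

# flat_tree_size(4, arity=2) == 7 and flat_tree_size(4, arity=3) == 13,
# matching the expected trees in the unit tests for this query.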
def get_noised_result(self, sample_state, global_state):
"""Implements `tensorflow_privacy.DPQuery.get_noised_result`.
This function re-constructs the `tf.RaggedTensor` from the flattened tree
output by `preprocess_records.`
Args:
sample_state: A `tf.Tensor` for the flattened tree.
global_state: The global state of the protocol.
Returns:
A `tf.RaggedTensor` representing the tree.
"""
# The [0] is needed because of how tf.RaggedTensor.from_row_splits works.
# print(tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6],
# row_splits=[0, 4, 4, 7, 8, 8]))
# <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
# This part is not written in tensorflow and will be executed on the server
# side instead of the client side if used with
# tff.aggregators.DifferentiallyPrivateFactory for federated learning.
sample_state, inner_query_state, _ = self._inner_query.get_noised_result(
sample_state, global_state.inner_query_state)
new_global_state = TreeRangeSumQuery.GlobalState(
arity=global_state.arity, inner_query_state=inner_query_state)
row_splits = [0] + [
(self._arity**(x + 1) - 1) // (self._arity - 1) for x in range(
math.floor(math.log(sample_state.shape[0], self._arity)) + 1)
]
tree = tf.RaggedTensor.from_row_splits(
values=sample_state, row_splits=row_splits)
event = dp_event.UnsupportedDpEvent()
return tree, new_global_state, event
@classmethod
def build_central_gaussian_query(cls,
l2_norm_clip: float,
stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_clip: Each record should be clipped so that it has L2 norm at most
`l2_norm_clip`.
stddev: Stddev of the central Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_clip <= 0:
raise ValueError(f'`l2_norm_clip` must be positive, got {l2_norm_clip}.')
if stddev < 0:
raise ValueError(f'`stddev` must be non-negative, got {stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = gaussian_query.GaussianSumQuery(l2_norm_clip, stddev)
return cls(arity=arity, inner_query=inner_query)
@classmethod
def build_distributed_discrete_gaussian_query(cls,
l2_norm_bound: float,
local_stddev: float,
arity: int = 2):
"""Returns `TreeRangeSumQuery` with central Gaussian noise.
Args:
l2_norm_bound: Each record should be clipped so that it has L2 norm at
most `l2_norm_bound`.
local_stddev: Scale/stddev of the local discrete Gaussian noise.
arity: The branching factor of the tree (i.e. the number of children each
internal node has). Defaults to 2.
"""
if l2_norm_bound <= 0:
raise ValueError(
    f'`l2_norm_bound` must be positive, got {l2_norm_bound}.')
if local_stddev < 0:
raise ValueError(
f'`local_stddev` must be non-negative, got {local_stddev}.')
if arity < 2:
raise ValueError(f'`arity` must be at least 2, got {arity}.')
inner_query = distributed_discrete_gaussian_query.DistributedDiscreteGaussianSumQuery(
l2_norm_bound, local_stddev)
return cls(arity=arity, inner_query=inner_query)
def _get_add_noise(stddev, seed: int = None):
"""Utility function to decide which `add_noise` to use according to tf version."""
if distutils.version.LooseVersion(
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
# The seed should be only used for testing purpose.
if seed is not None:
tf.random.set_seed(seed)
def add_noise(v):
return v + tf.random.normal(
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
else:
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
def add_noise(v):
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
return add_noise
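A small usage sketch for the helper above (the seed is fixed only for testing, as the comment notes); the concrete values are placeholders.

add_noise = _get_add_noise(stddev=2.0, seed=1)
noised = add_noise(tf.zeros([10000]))
# tf.math.reduce_std(noised) should be close to 2.0.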

View file

@ -0,0 +1,182 @@
# Copyright 2021, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `tree_range_query`."""
import math
from absl.testing import parameterized
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import tree_range_query
class BuildTreeTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
leaf_nodes_size=[1, 2, 3, 4, 5],
arity=[2, 3],
dtype=[tf.int32, tf.float32],
)
def test_build_tree_from_leaf(self, leaf_nodes_size, arity, dtype):
"""Test whether `_build_tree_from_leaf` will output the correct tree."""
leaf_nodes = tf.cast(tf.range(leaf_nodes_size), dtype)
depth = math.ceil(math.log(leaf_nodes_size, arity)) + 1
tree = tree_range_query._build_tree_from_leaf(leaf_nodes, arity)
self.assertEqual(depth, tree.shape[0])
for layer in range(depth):
reverse_depth = tree.shape[0] - layer - 1
span_size = arity**reverse_depth
for idx in range(arity**layer):
left = idx * span_size
right = (idx + 1) * span_size
expected_value = sum(leaf_nodes[left:right])
self.assertEqual(tree[layer][idx], expected_value)
class TreeRangeSumQueryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.product(
inner_query=['central', 'distributed'],
params=[(0., 1., 2), (1., -1., 2), (1., 1., 1)],
)
def test_raises_error(self, inner_query, params):
clip_norm, stddev, arity = params
with self.assertRaises(ValueError):
if inner_query == 'central':
tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0])
def test_initial_global_state_type(self, inner_query, clip_norm, stddev):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev)
global_state = query.initial_global_state()
self.assertIsInstance(global_state,
tree_range_query.TreeRangeSumQuery.GlobalState)
@parameterized.product(
inner_query=['central', 'distributed'],
clip_norm=[0.1, 1.0, 10.0],
stddev=[0.1, 1.0, 10.0],
arity=[2, 3, 4])
def test_derive_sample_params(self, inner_query, clip_norm, stddev, arity):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
clip_norm, stddev, arity)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
clip_norm, stddev, arity)
global_state = query.initial_global_state()
derived_arity, inner_query_state = query.derive_sample_params(global_state)
self.assertAllClose(derived_arity, arity)
if inner_query == 'central':
self.assertAllClose(inner_query_state, clip_norm)
elif inner_query == 'distributed':
self.assertAllClose(inner_query_state.l2_norm_bound, clip_norm)
self.assertAllClose(inner_query_state.local_stddev, stddev)
@parameterized.product(
(dict(arity=2, expected_tree=[1, 1, 0, 1, 0, 0, 0]),
dict(arity=3, expected_tree=[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])),
inner_query=['central', 'distributed'],
)
def test_preprocess_record(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(preprocessed_record, expected_tree)
@parameterized.named_parameters(
('stddev_1', 1, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
('stddev_0_1', 4, tf.constant([1, 0], dtype=tf.int32), [1, 1, 0]),
)
def test_distributed_preprocess_record_with_noise(self, local_stddev, record,
expected_tree):
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., local_stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
self.assertAllClose(
preprocessed_record, expected_tree, atol=10 * local_stddev)
@parameterized.product(
(dict(
arity=2,
expected_tree=tf.ragged.constant([[1], [1, 0], [1, 0, 0, 0]])),
dict(
arity=3,
expected_tree=tf.ragged.constant([[1], [1, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0]]))),
inner_query=['central', 'distributed'],
)
def test_get_noised_result(self, inner_query, arity, expected_tree):
if inner_query == 'central':
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.float32)
expected_tree = tf.cast(expected_tree, tf.float32)
elif inner_query == 'distributed':
query = tree_range_query.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
10., 0., arity)
record = tf.constant([1, 0, 0, 0], dtype=tf.int32)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, record)
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(sample_state, expected_tree)
@parameterized.product(stddev=[0.1, 1.0, 10.0])
def test_central_get_noised_result_with_noise(self, stddev):
query = tree_range_query.TreeRangeSumQuery.build_central_gaussian_query(
10., stddev)
global_state = query.initial_global_state()
params = query.derive_sample_params(global_state)
preprocessed_record = query.preprocess_record(params, tf.constant([1., 0.]))
sample_state, global_state, _ = query.get_noised_result(
preprocessed_record, global_state)
self.assertAllClose(
sample_state, tf.ragged.constant([[1.], [1., 0.]]), atol=10 * stddev)
if __name__ == '__main__':
tf.test.main()

View file

@ -21,7 +21,6 @@ from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
@ -99,6 +98,7 @@ def make_optimizer_class(cls):
dp_sum_query,
num_microbatches=None,
unroll_microbatches=False,
while_loop_parallel_iterations=10,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initializes the DPOptimizerClass.
@ -112,6 +112,10 @@ def make_optimizer_class(cls):
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception.
while_loop_parallel_iterations: The number of iterations allowed to run
in parallel. It must be a positive integer. Applicable only when
unroll_microbatches is set to False. It gives users some control over
memory consumption.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
@ -123,6 +127,7 @@ def make_optimizer_class(cls):
# Beware: When num_microbatches is large (>100), enabling this parameter
# may cause an OOM error.
self._unroll_microbatches = unroll_microbatches
self._while_loop_parallel_iterations = while_loop_parallel_iterations
self._was_compute_gradients_called = False
def compute_gradients(self,
@ -165,9 +170,9 @@ def make_optimizer_class(cls):
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
def normalize(v):
return v / tf.cast(self._num_microbatches, tf.float32)
@ -178,10 +183,6 @@ def make_optimizer_class(cls):
return grads_and_vars
else:
# TF is running in graph mode. Check we did not receive a gradient tape.
if gradient_tape:
raise ValueError('When in graph mode, a tape should not be passed.')
# Note: it would be closer to the correct i.i.d. sampling of records if
# we sampled each microbatch from the appropriate binomial distribution,
# although that still wouldn't be quite correct because it would be
@ -197,8 +198,8 @@ def make_optimizer_class(cls):
"""Process one microbatch (record) with privacy helper."""
self_super = super(DPOptimizerClass, self)
mean_loss = tf.reduce_mean(input_tensor=tf.gather(
microbatches_losses, [i]))
mean_loss = tf.reduce_mean(
input_tensor=tf.gather(microbatches_losses, [i]))
if hasattr(self_super, 'compute_gradients'):
# This case covers optimizers in tf.train.
@ -207,10 +208,15 @@ def make_optimizer_class(cls):
# This case covers Keras optimizers from optimizers_v2.
compute_gradients_fn = self_super._compute_gradients # pylint: disable=protected-access
grads, _ = zip(*compute_gradients_fn(
mean_loss, var_list, gate_gradients,
aggregation_method, colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
if gradient_tape:
# This is intended to work for TF2 and may not work for TF1.
with gradient_tape.stop_recording():
grads_list = list(gradient_tape.gradient(mean_loss, var_list))
else:
grads, _ = zip(*compute_gradients_fn(
mean_loss, var_list, gate_gradients, aggregation_method,
colocate_gradients_with_ops, grad_loss))
grads_list = list(grads)
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, grads_list)
@ -218,8 +224,8 @@ def make_optimizer_class(cls):
if var_list is None:
var_list = (
tf.trainable_variables() + tf.get_collection(
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
tf.trainable_variables() +
tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
sample_state = self._dp_sum_query.initial_sample_state(var_list)
@ -234,11 +240,14 @@ def make_optimizer_class(cls):
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
idx = tf.constant(0)
_, sample_state = tf.while_loop(
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
cond=cond_fn,
body=body_fn,
loop_vars=[idx, sample_state],
parallel_iterations=self._while_loop_parallel_iterations)
grad_sums, self._global_state = (
self._dp_sum_query.get_noised_result(
sample_state, self._global_state))
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
def normalize(v):
try:
@ -307,9 +316,7 @@ def make_gaussian_optimizer_class(cls):
```
""").format(
'tf.compat.v1.train.' + cls.__name__,
cls.__name__,
cls.__name__,
'tf.compat.v1.train.' + cls.__name__, cls.__name__, cls.__name__,
'DP' + cls.__name__.replace('Optimizer', 'GaussianOptimizer'))
def __init__(
@ -317,7 +324,6 @@ def make_gaussian_optimizer_class(cls):
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
ledger=None,
unroll_microbatches=False,
*args, # pylint: disable=keyword-arg-before-vararg
**kwargs):
@ -329,7 +335,6 @@ def make_gaussian_optimizer_class(cls):
num_microbatches: Number of microbatches into which each minibatch is
split. If `None`, will default to the size of the minibatch, and
per-example gradients will be computed.
ledger: Defaults to `None`. An instance of `tf_privacy.PrivacyLedger`.
unroll_microbatches: If true, processes microbatches within a Python
loop instead of a `tf.while_loop`. Can be used if using a
`tf.while_loop` raises an exception.
@ -344,16 +349,9 @@ def make_gaussian_optimizer_class(cls):
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
if ledger:
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query,
ledger=ledger)
super(DPGaussianOptimizerClass, self).__init__(
dp_sum_query,
num_microbatches,
unroll_microbatches,
*args,
**kwargs)
super(DPGaussianOptimizerClass,
self).__init__(dp_sum_query, num_microbatches, unroll_microbatches,
*args, **kwargs)
def get_config(self):
"""Creates configuration for Keras serialization.
@ -370,16 +368,14 @@ def make_gaussian_optimizer_class(cls):
config.update({
'l2_norm_clip': self._l2_norm_clip,
'noise_multiplier': self._noise_multiplier,
'num_microbatches': self._num_microbatches})
'num_microbatches': self._num_microbatches
})
return config
@property
def ledger(self):
return self._dp_sum_query.ledger
return DPGaussianOptimizerClass
AdagradOptimizer = tf.train.AdagradOptimizer
AdamOptimizer = tf.train.AdamOptimizer
GradientDescentOptimizer = tf.train.GradientDescentOptimizer

View file

@ -22,7 +22,6 @@ import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@ -56,13 +55,9 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
@ -85,7 +80,6 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
@ -109,7 +103,6 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)

View file

@ -49,7 +49,7 @@ def make_keras_optimizer_class(cls):
```python
# Create optimizer.
opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
<standard arguments>)
```
@ -81,9 +81,43 @@ def make_keras_optimizer_class(cls):
model.fit(...)
```
""".format(base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
In DP-SGD training, a larger batch size typically helps to achieve a better
privacy/utility tradeoff. However, there is typically a maximum batch size
imposed by hardware.
This optimizer can emulate large batch sizes on hardware with limited
memory by accumulating gradients for several steps before actually
applying them to update model weights.
Constructor argument `gradient_accumulation_steps` controls the number
of steps for which gradients are accumulated before updating
the model weights.
Below is an example which demonstrates how to use this feature:
```python
# Create an optimizer that accumulates gradients for 4 steps
# and then performs an update of the model weights.
opt = {dp_keras_class}(l2_norm_clip=1.0,
noise_multiplier=0.5,
num_microbatches=1,
gradient_accumulation_steps=4,
<standard arguments>)
# Use optimizer in a regular way.
# First three calls to opt.minimize won't update model weights and will
# only accumulate gradients. Model weights will be updated on the fourth
# call to opt.minimize
opt.minimize(loss, var_list=[var])
```
Note that when using this feature the effective batch size is
`gradient_accumulation_steps * one_step_batch_size`, where
`one_step_batch_size` is the size of the batch passed to a single step of
the optimizer. Thus the user may have to adjust the learning rate, weight
decay and possibly other training hyperparameters accordingly.
""".format(
base_class='tf.keras.optimizers.' + cls.__name__,
short_base_class=cls.__name__,
dp_keras_class='DPKeras' + cls.__name__)
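A sketch of large batch emulation in a Keras training loop, using the constructor arguments documented above; `model`, `loss_fn` and `train_data` are placeholder names.

opt = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=1,
    gradient_accumulation_steps=4,
    learning_rate=0.25)
model.compile(optimizer=opt, loss=loss_fn)
# Each step sees 32 examples, but weights only change every 4th step, for an
# effective batch size of 4 * 32 = 128.
model.fit(train_data, batch_size=32)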
# The class tf.keras.optimizers.Optimizer has two methods to compute
# gradients, `_compute_gradients` and `get_gradients`. The first works
@ -99,6 +133,7 @@ def make_keras_optimizer_class(cls):
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
gradient_accumulation_steps=1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initialize the DPOptimizerClass.
@ -106,12 +141,22 @@ def make_keras_optimizer_class(cls):
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch
is split.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None`, which means that the number of microbatches
is equal to the batch size (i.e. each microbatch contains exactly one
example). If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to
`num_microbatches * gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1, the optimizer will
accumulate gradients for this number of optimizer steps before
applying them to update the model weights. If set to 1, updates are
applied on each optimizer step.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
super(DPOptimizerClass, self).__init__(*args, **kwargs)
self.gradient_accumulation_steps = gradient_accumulation_steps
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
@ -120,6 +165,69 @@ def make_keras_optimizer_class(cls):
self._global_state = None
self._was_dp_gradients_called = False
def _create_slots(self, var_list):
super(DPOptimizerClass, self)._create_slots(var_list)
if self.gradient_accumulation_steps > 1:
for var in var_list:
self.add_slot(var, 'grad_acc')
def _prepare_local(self, var_device, var_dtype, apply_state):
super(DPOptimizerClass, self)._prepare_local(
var_device, var_dtype, apply_state)
if self.gradient_accumulation_steps > 1:
apply_update = tf.math.equal(
tf.math.floormod(self.iterations + 1,
self.gradient_accumulation_steps),
0)
grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype)
apply_state[(var_device, var_dtype)].update(
{
'apply_update': apply_update,
'grad_scaler': grad_scaler
})
def _resource_apply_dense(self, grad, var, apply_state=None):
if self.gradient_accumulation_steps > 1:
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
grad_acc = self.get_slot(var, 'grad_acc')
def _update_grad():
apply_grad_op = super(DPOptimizerClass, self)._resource_apply_dense(
grad_acc + grad * coefficients['grad_scaler'], var, apply_state)
with tf.control_dependencies([apply_grad_op]):
return grad_acc.assign(tf.zeros_like(grad_acc),
use_locking=self._use_locking,
read_value=False)
def _accumulate():
return grad_acc.assign_add(grad * coefficients['grad_scaler'],
use_locking=self._use_locking,
read_value=False)
return tf.cond(coefficients['apply_update'], _update_grad, _accumulate)
else:
return super(DPOptimizerClass, self)._resource_apply_dense(
grad, var, apply_state)
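A minimal framework-free sketch of the accumulate-or-apply pattern implemented above: scaled gradients are summed over k steps and only applied on every k-th step, after which the accumulator is reset.

def accumulate_or_apply(step, grad, grad_acc, k, apply_fn):
  grad_acc = grad_acc + grad / k      # mirrors grad * grad_scaler
  if (step + 1) % k == 0:             # mirrors the apply_update condition
    apply_fn(grad_acc)                # update weights with the averaged grad
    grad_acc = 0.0                    # mirrors grad_acc.assign(zeros_like)
  return grad_acc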
def _resource_apply_sparse_duplicate_indices(self, *args, **kwargs):
if self.gradient_accumulation_steps > 1:
raise NotImplementedError(
'Sparse gradients are not supported with large batch emulation.')
else:
return super(DPOptimizerClass,
self)._resource_apply_sparse_duplicate_indices(
*args, **kwargs)
def _resource_apply_sparse(self, *args, **kwargs):
if self.gradient_accumulation_steps > 1:
raise NotImplementedError(
'Sparse gradients are not supported with large batch emulation.')
else:
return super(DPOptimizerClass, self)._resource_apply_sparse(
*args, **kwargs)
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP-SGD version of base class method."""
@ -210,7 +318,7 @@ def make_keras_optimizer_class(cls):
sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)
grad_sums, self._global_state = (
grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
@ -224,6 +332,25 @@ def make_keras_optimizer_class(cls):
return final_grads
def get_config(self):
"""Returns the config of the optimizer.
An optimizer config is a Python dictionary (serializable)
containing the configuration of an optimizer.
The same optimizer can be reinstantiated later
(without any saved state) from this configuration.
Returns:
Python dictionary.
"""
config = super(DPOptimizerClass, self).get_config()
config.update({
'l2_norm_clip': self._l2_norm_clip,
'noise_multiplier': self._noise_multiplier,
'num_microbatches': self._num_microbatches,
})
return config
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""DP-SGD version of base class method."""
assert self._was_dp_gradients_called, (

View file

@ -394,6 +394,87 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
grads_and_vars = tf.Variable([0.0])
opt.apply_gradients(grads_and_vars)
def testLargeBatchEmulationNoNoise(self):
# Test for emulation of large batch training.
# It tests that updates are only done every gradient_accumulation_steps
# steps.
# In this test we set the noise multiplier to zero and the clipping norm to
# a high value, so that the optimizer essentially behaves as a non-DP
# optimizer. This makes it easier to check how the values of the variables
# change.
#
# This test optimizes loss var0*x + var1
# Gradients of this loss are computed as:
# d(loss)/d(var0) = x
# d(loss)/d(var1) = 1
var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
var1 = tf.Variable([3.0], dtype=tf.float32)
x1 = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
loss1 = lambda: tf.matmul(var0, x1, transpose_b=True) + var1
x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
opt = dp_optimizer_keras.DPKerasSGDOptimizer(
l2_norm_clip=100.0,
noise_multiplier=0.0,
gradient_accumulation_steps=2,
learning_rate=1.0)
# before any call to optimizer
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss1, [var0, var1])
# After first call to optimizer values didn't change
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss2, [var0, var1])
# After second call to optimizer updates were applied
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1)
opt.minimize(loss2, [var0, var1])
# After third call to optimizer values didn't change
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1)
opt.minimize(loss2, [var0, var1])
# After fourth call to optimizer updates were applied again
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
self.assertAllCloseAccordingToType([1.0], var1)
@parameterized.named_parameters(
('DPKerasSGDOptimizer 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
('DPKerasSGDOptimizer 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
('DPKerasSGDOptimizer 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
('DPKerasAdamOptimizer 2',
dp_optimizer_keras.DPKerasAdamOptimizer, 1),
('DPKerasAdagradOptimizer 2',
dp_optimizer_keras.DPKerasAdagradOptimizer, 2),
)
def testLargeBatchEmulation(self, cls, gradient_accumulation_steps):
# Tests various optimizers with large batch emulation.
# Uses clipping and noise, thus does not test specific values
# of the variables and only tests how often variables are updated.
var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
var1 = tf.Variable([3.0], dtype=tf.float32)
x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1
opt = cls(
l2_norm_clip=100.0,
noise_multiplier=0.0,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=1.0)
for _ in range(gradient_accumulation_steps):
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1)
opt.minimize(loss, [var0, var1])
self.assertNotAllClose([[1.0, 2.0]], var0)
self.assertNotAllClose([3.0], var1)
if __name__ == '__main__':
tf.test.main()

View file

@ -24,7 +24,6 @@ import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@ -36,6 +35,24 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
return 0.5 * tf.reduce_sum(
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
def _compute_expected_gradients(self, per_example_gradients,
l2_norm_clip, num_microbatches):
batch_size, num_vars = per_example_gradients.shape
microbatch_gradients = np.mean(
np.reshape(per_example_gradients,
[num_microbatches,
np.int(batch_size / num_microbatches), num_vars]),
axis=1)
microbatch_gradients_norms = np.linalg.norm(microbatch_gradients, axis=1)
def scale(x):
return 1.0 if x < l2_norm_clip else l2_norm_clip / x
scales = np.array(list(map(scale, microbatch_gradients_norms)))
mean_clipped_gradients = np.mean(
microbatch_gradients * scales[:, None], axis=0)
return mean_clipped_gradients
# Parameters for testing: optimizer, num_microbatches, expected answer.
@parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
@ -51,9 +68,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]),
('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5]),
('DPRMSPropOptimizer 1', dp_optimizer.DPRMSPropOptimizer, 1,
[-2.5, -2.5]),
('DPRMSPropOptimizer 2', dp_optimizer.DPRMSPropOptimizer, 2,
[-2.5, -2.5]),
[-2.5, -2.5]), ('DPRMSPropOptimizer 2', dp_optimizer.DPRMSPropOptimizer,
2, [-2.5, -2.5]),
('DPRMSPropOptimizer 4', dp_optimizer.DPRMSPropOptimizer, 4, [-2.5, -2.5])
)
def testBaseline(self, cls, num_microbatches, expected_answer):
@ -62,13 +78,9 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
num_microbatches=num_microbatches,
learning_rate=2.0)
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
@ -91,7 +103,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
@ -105,19 +116,56 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer),
('DPAdam', dp_optimizer.DPAdamOptimizer),
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer))
def testNoiseMultiplier(self, cls):
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
)
def testClippingNormWithMicrobatches(self, cls, num_microbatches):
with self.cached_session() as sess:
var0 = tf.Variable([0.0, 0.0])
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0], [-9.0, -12.0],
[-12.0, -16.0]])
l2_norm_clip = 1.0
dp_sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip, 0.0)
opt = cls(dp_sum_query, num_microbatches=num_microbatches,
learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
var_np = self.evaluate(var0)
self.assertAllClose([0.0, 0.0], var_np)
# Compute the expected gradients from the per-example differences (var - data).
data_np = self.evaluate(data0)
per_example_gradients = var_np - data_np
mean_clipped_gradients = self._compute_expected_gradients(
per_example_gradients, l2_norm_clip, num_microbatches)
# Compare actual with expected gradients.
gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
grads_and_vars = sess.run(gradient_op)
print('mean_clipped_gradients: ', mean_clipped_gradients)
self.assertAllCloseAccordingToType(mean_clipped_gradients,
grads_and_vars[0][0])
@parameterized.named_parameters(
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1),
('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2),
('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4),
('DPAdagrad', dp_optimizer.DPAdagradOptimizer, 1),
('DPAdam', dp_optimizer.DPAdamOptimizer, 1),
('DPRMSPropOptimizer', dp_optimizer.DPRMSPropOptimizer, 1))
def testNoiseMultiplier(self, cls, num_microbatches):
with self.cached_session() as sess:
var0 = tf.Variable([0.0])
data0 = tf.Variable([[0.0]])
data0 = tf.Variable([[0.0], [0.0], [0.0], [0.0]])
dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
opt = cls(
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
self.evaluate(tf.global_variables_initializer())
# Fetch params to validate initial values
@@ -130,7 +178,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
grads.append(grads_and_vars[0][0])
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
self.assertNear(np.std(grads), 2.0 * 4.0 / num_microbatches, 0.5)
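The updated expectation follows from how the noise enters the gradient: GaussianSumQuery(4.0, 8.0) adds Gaussian noise with standard deviation l2_norm_clip * noise_multiplier = 4.0 * 2.0 = 8.0 to the summed microbatch gradients, and the optimizer then divides that noisy sum by num_microbatches, so the per-coordinate standard deviation of the returned gradient should be close to 8.0 / num_microbatches, which is what the revised assertion checks.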
@mock.patch('absl.logging.warning')
def testComputeGradientsOverrideWarning(self, mock_logging):
@@ -157,11 +205,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
vector_loss = tf.math.squared_difference(labels, preds)
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
optimizer = dp_optimizer.DPGradientDescentOptimizer(
dp_sum_query,
num_microbatches=1,
learning_rate=1.0)
dp_sum_query, num_microbatches=1, learning_rate=1.0)
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(
@@ -201,8 +246,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
num_microbatches = 4
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(
dp_sum_query, 1e6, num_microbatches / 1e6)
opt = cls(
dp_sum_query,
@@ -283,8 +326,6 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
extra_variable = tf.Variable('foo', trainable=True, dtype=tf.string)
dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0)
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6,
num_microbatches / 1e6)
opt = cls(
dp_sum_query, num_microbatches=num_microbatches, learning_rate=2.0)
@@ -298,27 +339,26 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
sess.run(minimize_op)
def _testWriteOutAndReload(self, optimizer_cls):
optimizer = optimizer_cls(l2_norm_clip=1.0,
noise_multiplier=0.01,
num_microbatches=1)
optimizer = optimizer_cls(
l2_norm_clip=1.0, noise_multiplier=0.01, num_microbatches=1)
test_dir = self.get_temp_dir()
model_path = os.path.join(test_dir, 'model')
model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(1, 1)),
tf.keras.layers.Dense(units=1,
activation='softmax')])
model.compile(optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True))
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1, 1)),
tf.keras.layers.Dense(units=1, activation='softmax')
])
model.compile(
optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
tf.keras.models.save_model(model, filepath=model_path,
include_optimizer=True)
tf.keras.models.save_model(
model, filepath=model_path, include_optimizer=True)
optimizer_cls_str = optimizer_cls.__name__
tf.keras.models.load_model(model_path,
custom_objects={
optimizer_cls_str: optimizer_cls})
tf.keras.models.load_model(
model_path, custom_objects={optimizer_cls_str: optimizer_cls})
return
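Passing the optimizer class through custom_objects is what makes the reload work here: a model saved with include_optimizer=True records the optimizer by its class name, and Keras needs the custom_objects mapping to resolve that name back to the DP optimizer class, which is not a built-in optimizer it can look up on its own.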

View file

@@ -15,12 +15,13 @@
# Lint as: python3
"""Data structures representing attack inputs, configuration, outputs."""
import collections
import dataclasses
import enum
import glob
import os
import pickle
from typing import Any, Iterable, Union
from dataclasses import dataclass
import numpy as np
import pandas as pd
from scipy import special
@@ -37,7 +38,7 @@ class SlicingFeature(enum.Enum):
CORRECTLY_CLASSIFIED = 'correctly_classified'
@dataclass
@dataclasses.dataclass
class SingleSliceSpec:
"""Specifies a slice.
@@ -64,7 +65,7 @@ class SingleSliceSpec:
return '%s=%s' % (self.feature.name, self.value)
@dataclass
@dataclasses.dataclass
class SlicingSpec:
"""Specification of a slicing procedure.
@@ -165,7 +166,7 @@ def _log_value(probs, small_value=1e-30):
return -np.log(np.maximum(probs, small_value))
@dataclass
@dataclasses.dataclass
class AttackInputData:
"""Input data for running an attack.
@@ -334,9 +335,11 @@ class AttackInputData:
'labels_train and labels_test should both be either set or unset')
if (self.labels_train is None and self.loss_train is None and
self.logits_train is None and self.entropy_train is None):
self.logits_train is None and self.entropy_train is None and
self.probs_train is None):
raise ValueError(
'At least one of labels, logits, losses or entropy should be set')
'At least one of labels, logits, losses, probabilities or entropy should be set'
)
if self.labels_train is not None and not _is_integer_type_array(
self.labels_train):
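With the relaxed check, an attack input that supplies only prediction probabilities (plus labels) is accepted. A hypothetical construction, using dummy arrays and the probs_*/labels_* fields defined elsewhere in this class:

import numpy as np

attack_input = AttackInputData(
    probs_train=np.random.rand(100, 10),              # softmax outputs, train split
    probs_test=np.random.rand(100, 10),               # softmax outputs, test split
    labels_train=np.random.randint(0, 10, size=100),
    labels_test=np.random.randint(0, 10, size=100))
attack_input.validate()  # Probabilities alone now satisfy the check above.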
@@ -390,7 +393,7 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
@dataclass
@dataclasses.dataclass
class RocCurve:
"""Represents ROC curve of a membership inference classifier."""
# Thresholds used to define points on ROC curve.
@@ -433,7 +436,7 @@ class RocCurve:
DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
@dataclass
@dataclasses.dataclass
class SingleAttackResult:
"""Results from running a single attack."""
@@ -488,7 +491,7 @@ class SingleAttackResult:
])
@dataclass
@dataclasses.dataclass
class SingleMembershipProbabilityResult:
"""Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).
@@ -578,7 +581,7 @@ class SingleMembershipProbabilityResult:
return summary
@dataclass
@dataclasses.dataclass
class MembershipProbabilityResults:
"""Membership probability results from multiple data slices."""
@@ -593,7 +596,7 @@ class MembershipProbabilityResults:
return '\n'.join(summary)
@dataclass
@dataclasses.dataclass
class PrivacyReportMetadata:
"""Metadata about the evaluated model.
@@ -622,7 +625,7 @@ class AttackResultsDFColumns(enum.Enum):
return '%s' % self.value
@dataclass
@dataclasses.dataclass
class AttackResults:
"""Results from running multiple attacks."""
single_attack_results: Iterable[SingleAttackResult]
@@ -759,7 +762,7 @@ class AttackResults:
return pickle.load(inp)
@dataclass
@dataclasses.dataclass
class AttackResultsCollection:
"""A collection of AttackResults."""
attack_results_list: Iterable[AttackResults]

View file

@@ -0,0 +1,50 @@
# Copyright 2020, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Privacy library v1 imports.
This module includes classes designed to be compatible with TF1, based on
`tf.compat.v1.train.Optimizer` and `tf.estimator.Estimator`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
pass
else:
# Estimators
from tensorflow_privacy.privacy.estimators.v1.dnn import DNNClassifier as DNNClassifierV1
# Optimizers
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer import make_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD
from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import make_vectorized_optimizer_class
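For orientation, a minimal sketch of using one of the re-exported TF1 optimizers in graph mode; the toy loss and hyperparameter values are illustrative only, and the key point is that the loss passed to minimize() is per-example (a vector) so it can be split into microbatches:

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer

tf.disable_eager_execution()  # Build a TF1-style graph.

var = tf.Variable([0.0, 0.0])
data = tf.constant(np.random.randn(8, 2), dtype=tf.float32)
# One scalar loss per example; the optimizer averages, clips and noises per microbatch.
vector_loss = 0.5 * tf.reduce_sum(tf.math.squared_difference(data, var), axis=1)

optimizer = DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,        # clip each microbatch gradient to this L2 norm
    noise_multiplier=1.1,    # noise stddev = noise_multiplier * l2_norm_clip
    num_microbatches=4,      # must evenly divide the batch size (8 here)
    learning_rate=0.15)
train_op = optimizer.minimize(vector_loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)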

View file

@@ -13,4 +13,4 @@
# limitations under the License.
"""TensorFlow Privacy version."""
__version__ = '0.6.1'
__version__ = '0.7.3'

View file

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training a language model (recurrent neural network) with DP-SGD optimizer.
This tutorial uses a corpus of text from TensorFlow datasets unless a
@@ -44,7 +43,6 @@ import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from tensorflow_privacy.privacy.analysis import privacy_ledger
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer
@@ -92,27 +90,20 @@ def rnn_model_fn(features, labels, mode): # pylint: disable=unused-argument
if mode == tf.estimator.ModeKeys.TRAIN:
if FLAGS.dpsgd:
ledger = privacy_ledger.PrivacyLedger(
population_size=NB_TRAIN,
selection_probability=(FLAGS.batch_size / NB_TRAIN))
optimizer = dp_optimizer.DPAdamGaussianOptimizer(
l2_norm_clip=FLAGS.l2_norm_clip,
noise_multiplier=FLAGS.noise_multiplier,
num_microbatches=FLAGS.microbatches,
ledger=ledger,
learning_rate=FLAGS.learning_rate,
unroll_microbatches=True)
opt_loss = vector_loss
else:
optimizer = tf.train.AdamOptimizer(
learning_rate=FLAGS.learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
opt_loss = scalar_loss
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
train_op=train_op)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, train_op=train_op)
# Add evaluation metrics (for EVAL mode).
elif mode == tf.estimator.ModeKeys.EVAL:
@@ -122,9 +113,8 @@ def rnn_model_fn(features, labels, mode): # pylint: disable=unused-argument
labels=tf.cast(x[:, 1:], dtype=tf.int32),
predictions=tf.argmax(input=logits, axis=2))
}
return tf.estimator.EstimatorSpec(mode=mode,
loss=scalar_loss,
eval_metric_ops=eval_metric_ops)
return tf.estimator.EstimatorSpec(
mode=mode, loss=scalar_loss, eval_metric_ops=eval_metric_ops)
def load_data():
@@ -132,13 +122,13 @@ def load_data():
if not FLAGS.data_dir:
print('FLAGS.data_dir containing train.txt and test.txt was not specified, '
'using a substitute dataset from the tensorflow_datasets module.')
train_dataset = tfds.load(name='lm1b/subwords8k',
split=tfds.Split.TRAIN,
batch_size=NB_TRAIN,
shuffle_files=True)
test_dataset = tfds.load(name='lm1b/subwords8k',
split=tfds.Split.TEST,
batch_size=10000)
train_dataset = tfds.load(
name='lm1b/subwords8k',
split=tfds.Split.TRAIN,
batch_size=NB_TRAIN,
shuffle_files=True)
test_dataset = tfds.load(
name='lm1b/subwords8k', split=tfds.Split.TEST, batch_size=10000)
train_data = next(iter(tfds.as_numpy(train_dataset)))
test_data = next(iter(tfds.as_numpy(test_dataset)))
train_data = train_data['text'].flatten()
@@ -162,10 +152,11 @@ def compute_epsilon(steps):
return float('inf')
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = FLAGS.batch_size / NB_TRAIN
rdp = compute_rdp(q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
rdp = compute_rdp(
q=sampling_probability,
noise_multiplier=FLAGS.noise_multiplier,
steps=steps,
orders=orders)
# Delta is set to 1e-5 because Penn TreeBank has 60000 training points.
return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
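As a concrete illustration of the accountant call above, the same two functions can be invoked directly; the batch size, dataset size, noise multiplier and epoch count below are hypothetical:

from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
sampling_probability = 256 / 60000               # batch_size / number of examples
steps = 15 * 60000 // 256                        # 15 epochs of training
rdp = compute_rdp(
    q=sampling_probability,
    noise_multiplier=1.1,
    steps=steps,
    orders=orders)
eps = get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
print('epsilon =', eps)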
@@ -180,9 +171,8 @@ def main(unused_argv):
# Instantiate the tf.Estimator.
conf = tf.estimator.RunConfig(save_summary_steps=1000)
lm_classifier = tf.estimator.Estimator(model_fn=rnn_model_fn,
model_dir=FLAGS.model_dir,
config=conf)
lm_classifier = tf.estimator.Estimator(
model_fn=rnn_model_fn, model_dir=FLAGS.model_dir, config=conf)
# Create tf.Estimator input functions for the training and test data.
batch_len = FLAGS.batch_size * SEQ_LEN
@@ -221,5 +211,6 @@ def main(unused_argv):
else:
print('Trained with vanilla non-private SGD optimizer')
if __name__ == '__main__':
app.run(main)