forked from 626_privacy/tensorflow_privacy
Implementation of Differentially Private Logistic Regression.
PiperOrigin-RevId: 381904153
This commit is contained in:
parent
af87581387
commit
392c506c62
6 changed files with 589 additions and 0 deletions
125
tensorflow_privacy/privacy/logistic_regression/datasets.py
Normal file
125
tensorflow_privacy/privacy/logistic_regression/datasets.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Functions for generating train and test data for logistic regression models.
|
||||
|
||||
Includes two types of datasets:
|
||||
- Synthetic linearly separable labeled examples.
|
||||
Here, in the binary classification case, we generate training examples by
|
||||
first sampling a random weight vector w from a multivariate Gaussian
|
||||
distribution. Then, for each training example, we randomly sample a point x,
|
||||
also from a multivariate Gaussian distribution, and then set the label y equal
|
||||
to 1 if the inner product of w and x is positive, and equal to 0 otherwise. As
|
||||
such, the training data is linearly separable.
|
||||
More generally, in the case where there are num_classes many classes, we
|
||||
sample num_classes different w vectors. After sampling x, we will set its
|
||||
class label y to the class for which the corresponding w vector has the
|
||||
largest inner product with x.
|
||||
- MNIST 10-class classification dataset.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from sklearn import preprocessing
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegressionDataset:
|
||||
"""Class for storing labeled examples for a regression dataset.
|
||||
|
||||
Attributes:
|
||||
points: array of shape (num_examples, dimension) containing the points to
|
||||
be classified.
|
||||
labels: array of shape (num_examples,) containing the corresponding labels,
|
||||
each belonging to the set {0,1,...,num_classes-1}, where num_classes is
|
||||
the number of classes.
|
||||
"""
|
||||
points: np.ndarray
|
||||
labels: np.ndarray
|
||||
|
||||
|
||||
def linearly_separable_labeled_examples(
|
||||
num_examples: int, weights: np.ndarray)-> RegressionDataset:
|
||||
"""Generates num_examples labeled examples using separator given by weights.
|
||||
|
||||
Args:
|
||||
num_examples: number of labeled examples to generate.
|
||||
weights: dimension by num_classes matrix containing coefficients of linear
|
||||
separator, where dimension is the dimension and num_classes is the number
|
||||
of classes.
|
||||
|
||||
Returns:
|
||||
RegressionDataset consisting of points and labels. Each point has unit
|
||||
l2-norm.
|
||||
"""
|
||||
dimension = weights.shape[0]
|
||||
# Generate points and normalize each to have unit l2-norm.
|
||||
points_non_normalized = np.random.normal(size=(num_examples, dimension))
|
||||
points = preprocessing.normalize(points_non_normalized)
|
||||
# Compute labels.
|
||||
labels = np.argmax(np.matmul(points, weights), axis=1)
|
||||
return RegressionDataset(points, labels)
|
||||
|
||||
|
||||
def synthetic_linearly_separable_data(
|
||||
num_train: int, num_test: int, dimension: int,
|
||||
num_classes: int)-> Tuple[RegressionDataset, RegressionDataset]:
|
||||
"""Generates synthetic train and test data for logistic regression.
|
||||
|
||||
Args:
|
||||
num_train: number of training data points.
|
||||
num_test: number of test data points.
|
||||
dimension: the dimension of the classification problem.
|
||||
num_classes: number of classes, assumed to be at least 2.
|
||||
|
||||
Returns:
|
||||
train_dataset: num_train labeled examples, with unit l2-norm points.
|
||||
test_dataset: num_test labeled examples, with unit l2-norm points.
|
||||
"""
|
||||
if num_classes < 2:
|
||||
raise ValueError(f'num_classes must be at least 2. It is {num_classes}.')
|
||||
|
||||
# Generate weight vector.
|
||||
weights = np.random.normal(size=(dimension, num_classes))
|
||||
|
||||
# Generate train labeled examples.
|
||||
train_dataset = linearly_separable_labeled_examples(num_train, weights)
|
||||
|
||||
# Generate test labeled examples.
|
||||
test_dataset = linearly_separable_labeled_examples(num_test, weights)
|
||||
|
||||
return (train_dataset, test_dataset)
|
||||
|
||||
|
||||
def mnist_dataset()-> Tuple[RegressionDataset, RegressionDataset]:
|
||||
"""Generates (normalized) train and test data for MNIST.
|
||||
|
||||
Returns:
|
||||
train_dataset: MNIST labeled examples, with unit l2-norm points.
|
||||
test_dataset: MNIST labeled examples, with unit l2-norm points.
|
||||
"""
|
||||
train_data, test_data = tf.keras.datasets.mnist.load_data()
|
||||
train_points_non_normalized, train_labels = train_data
|
||||
test_points_non_normalized, test_labels = test_data
|
||||
num_train = train_points_non_normalized.shape[0]
|
||||
num_test = test_points_non_normalized.shape[0]
|
||||
train_points_non_normalized = train_points_non_normalized.reshape(
|
||||
(num_train, -1))
|
||||
test_points_non_normalized = test_points_non_normalized.reshape(
|
||||
(num_test, -1))
|
||||
train_points = preprocessing.normalize(train_points_non_normalized)
|
||||
test_points = preprocessing.normalize(test_points_non_normalized)
|
||||
return (RegressionDataset(train_points, train_labels),
|
||||
RegressionDataset(test_points, test_labels))
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for tensorflow_privacy.privacy.logistic_regression.datasets."""
|
||||
|
||||
import unittest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow_privacy.privacy.logistic_regression import datasets
|
||||
|
||||
|
||||
class DatasetsTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
(1, np.array([[1],])),
|
||||
(2, np.array([[1],])),
|
||||
(5, np.array([[-1, 1], [1, -1]])),
|
||||
(15, np.array([[-1, 1.5, 2.1], [1.3, -3.3, -7.1], [1.3, -3.3, -7.1]])))
|
||||
def test_linearly_separable_labeled_examples(self, num_examples, weights):
|
||||
dimension, num_classes = weights.shape
|
||||
dataset = datasets.linearly_separable_labeled_examples(num_examples,
|
||||
weights)
|
||||
self.assertEqual(dataset.points.shape, (num_examples, dimension))
|
||||
self.assertEqual(dataset.labels.shape, (num_examples,))
|
||||
product = np.matmul(dataset.points, weights)
|
||||
for i in range(num_examples):
|
||||
for j in range(num_classes):
|
||||
self.assertGreaterEqual(product[i, dataset.labels[i]], product[i, j])
|
||||
|
||||
@parameterized.parameters(
|
||||
(1, 1, 1, 2),
|
||||
(20, 5, 1, 2),
|
||||
(20, 5, 2, 2),
|
||||
(1000, 10, 15, 10))
|
||||
def test_synthetic(self, num_train, num_test, dimension, num_classes):
|
||||
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
|
||||
num_train, num_test, dimension, num_classes)
|
||||
self.assertEqual(train_dataset.points.shape, (num_train, dimension))
|
||||
self.assertEqual(train_dataset.labels.shape, (num_train,))
|
||||
self.assertEqual(test_dataset.points.shape, (num_test, dimension))
|
||||
self.assertEqual(test_dataset.labels.shape, (num_test,))
|
||||
# Check that each train and test point has unit l2-norm.
|
||||
for i in range(num_train):
|
||||
self.assertAlmostEqual(np.linalg.norm(train_dataset.points[i, :]), 1)
|
||||
for i in range(num_test):
|
||||
self.assertAlmostEqual(np.linalg.norm(test_dataset.points[i, :]), 1)
|
||||
# Check that each train and test label is in {0,...,num_classes-1}.
|
||||
self.assertTrue(np.all(np.isin(train_dataset.labels, range(num_classes))))
|
||||
self.assertTrue(np.all(np.isin(test_dataset.labels, range(num_classes))))
|
||||
|
||||
def test_mnist_dataset(self):
|
||||
(train_dataset, test_dataset) = datasets.mnist_dataset()
|
||||
self.assertEqual(train_dataset.points.shape, (60000, 784))
|
||||
self.assertEqual(train_dataset.labels.shape, (60000,))
|
||||
self.assertEqual(test_dataset.points.shape, (10000, 784))
|
||||
self.assertEqual(test_dataset.labels.shape, (10000,))
|
||||
# Check that each train and test point has unit l2-norm.
|
||||
for i in range(train_dataset.points.shape[0]):
|
||||
self.assertAlmostEqual(np.linalg.norm(train_dataset.points[i, :]), 1)
|
||||
for i in range(test_dataset.points.shape[0]):
|
||||
self.assertAlmostEqual(np.linalg.norm(test_dataset.points[i, :]), 1)
|
||||
# Check that each train and test label is in {0,...,9}.
|
||||
self.assertTrue(np.all(np.isin(train_dataset.labels, range(10))))
|
||||
self.assertTrue(np.all(np.isin(test_dataset.labels, range(10))))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,208 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of differentially private multinomial logistic regression.
|
||||
|
||||
Algorithms include:
|
||||
|
||||
- Based on the differentially private objective perturbation method of Kifer et
|
||||
al. (Colt 2012): http://proceedings.mlr.press/v23/kifer12/kifer12.pdf
|
||||
Their algorithm can be used for convex optimization problems in general, and in
|
||||
the case of multinomial logistic regression in particular.
|
||||
|
||||
- Training procedure based on the Differentially Private Stochastic Gradient
|
||||
Descent (DP-SGD) implementation in TensorFlow Privacy, which is itself based on
|
||||
the algorithm of Abadi et al.: https://arxiv.org/pdf/1607.00133.pdf%20.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import compute_dp_sgd_privacy as compute_epsilon
|
||||
from tensorflow_privacy.privacy.logistic_regression import datasets
|
||||
from tensorflow_privacy.privacy.logistic_regression import single_layer_softmax
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras
|
||||
from differential_privacy.python.accounting import common
|
||||
|
||||
|
||||
@tf.keras.utils.register_keras_serializable(package='Custom', name='Kifer')
|
||||
class KiferRegularizer(tf.keras.regularizers.Regularizer):
|
||||
"""Class corresponding to the regularizer in Algorithm 1 of Kifer et al.
|
||||
|
||||
Attributes:
|
||||
l2_regularizer: scalar coefficient for l2-regularization term.
|
||||
num_train: number of training examples.
|
||||
b: tensor of shape (d,num_classes) linearly translating the objective.
|
||||
"""
|
||||
|
||||
def __init__(self, num_train: int, dimension: int, epsilon: float,
|
||||
delta: float, num_classes: int, input_clipping_norm: float):
|
||||
self._num_train = num_train
|
||||
(self._l2_regularizer,
|
||||
variance) = self.logistic_objective_perturbation_parameters(
|
||||
num_train, epsilon, delta, num_classes, input_clipping_norm)
|
||||
self._b = tf.random.normal(shape=[dimension, num_classes], mean=0.0,
|
||||
stddev=math.sqrt(variance),
|
||||
dtype=tf.dtypes.float32)
|
||||
|
||||
def __call__(self, x):
|
||||
return (tf.reduce_sum(self._l2_regularizer*tf.square(x)) +
|
||||
(1/self._num_train)*tf.reduce_sum(tf.multiply(x, self._b)))
|
||||
|
||||
def get_config(self):
|
||||
return {'l2_regularizer': self._l2_regularizer,
|
||||
'num_train': self._num_train, 'b': self._b}
|
||||
|
||||
def logistic_objective_perturbation_parameters(
|
||||
self, num_train: int, epsilon: float, delta: float, num_classes: int,
|
||||
input_clipping_norm: float)-> Tuple[float, float]:
|
||||
"""Computes l2-regularization coefficient and Gaussian noise variance.
|
||||
|
||||
The setting is based on Algorithm 1 of Kifer et al.
|
||||
|
||||
Args:
|
||||
num_train: number of input training points.
|
||||
epsilon: epsilon parameter in (epsilon, delta)-DP.
|
||||
delta: delta parameter in (epsilon, delta)-DP.
|
||||
num_classes: number of classes.
|
||||
input_clipping_norm: l2-norm according to which input points are clipped.
|
||||
|
||||
Returns:
|
||||
l2-regularization coefficient and variance of Gaussian noise added in
|
||||
Algorithm 1 of Kifer et al.
|
||||
"""
|
||||
# zeta is an upper bound on the l2-norm of the loss function gradient.
|
||||
zeta = input_clipping_norm
|
||||
# variance is based on line 5 from Algorithm 1 of Kifer et al. (page 6):
|
||||
variance = zeta*zeta*(8*np.log(2/delta)+4*epsilon)/(epsilon*epsilon)
|
||||
# lambda_coefficient is an upper bound on the spectral norm of the Hessian
|
||||
# of the loss function.
|
||||
lambda_coefficient = math.sqrt(2*num_classes)*(input_clipping_norm**2)/4
|
||||
l2_regularizer = lambda_coefficient/(epsilon*num_train)
|
||||
return (l2_regularizer, variance)
|
||||
|
||||
|
||||
def logistic_objective_perturbation(train_dataset: datasets.RegressionDataset,
|
||||
test_dataset: datasets.RegressionDataset,
|
||||
epsilon: float, delta: float,
|
||||
epochs: int, num_classes: int,
|
||||
input_clipping_norm: float)-> List[float]:
|
||||
"""Trains and validates differentially private logistic regression model.
|
||||
|
||||
The training is based on the Algorithm 1 of Kifer et al.
|
||||
|
||||
Args:
|
||||
train_dataset: consists of num_train many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
test_dataset: consists of num_test many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
epsilon: epsilon parameter in (epsilon, delta)-DP.
|
||||
delta: delta parameter in (epsilon, delta)-DP.
|
||||
epochs: number of training epochs.
|
||||
num_classes: number of classes.
|
||||
input_clipping_norm: l2-norm according to which input points are clipped.
|
||||
|
||||
Returns:
|
||||
List of test accuracies (one for each epoch) on test_dataset of model
|
||||
trained on train_dataset.
|
||||
"""
|
||||
num_train, dimension = train_dataset.points.shape
|
||||
# Normalize each training point (i.e., row of train_dataset.points) to have
|
||||
# l2-norm at most input_clipping_norm.
|
||||
train_dataset.points = tf.clip_by_norm(train_dataset.points,
|
||||
input_clipping_norm, [1]).numpy()
|
||||
optimizer = 'sgd'
|
||||
loss = 'categorical_crossentropy'
|
||||
kernel_regularizer = KiferRegularizer(num_train, dimension, epsilon, delta,
|
||||
num_classes, input_clipping_norm)
|
||||
return single_layer_softmax.single_layer_softmax_classifier(
|
||||
train_dataset, test_dataset, epochs, num_classes, optimizer, loss,
|
||||
kernel_regularizer=kernel_regularizer)
|
||||
|
||||
|
||||
def compute_dpsgd_noise_multiplier(
|
||||
num_train: int, epsilon: float, delta: float, epochs: int,
|
||||
batch_size: int, tolerance: float = 1e-2) -> Optional[float]:
|
||||
"""Computes the noise multiplier for DP-SGD given privacy parameters.
|
||||
|
||||
The algorithm performs binary search on the values of epsilon.
|
||||
|
||||
Args:
|
||||
num_train: number of input training points.
|
||||
epsilon: epsilon parameter in (epsilon, delta)-DP.
|
||||
delta: delta parameter in (epsilon, delta)-DP.
|
||||
epochs: number of training epochs.
|
||||
batch_size: the number of examples in each batch of gradient descent.
|
||||
tolerance: an upper bound on the absolute difference between the input
|
||||
(desired) epsilon and the epsilon value corresponding to the
|
||||
noise_multiplier that is output.
|
||||
|
||||
Returns:
|
||||
noise_multiplier: the smallest noise multiplier value (within plus or minus
|
||||
the given tolerance) for which using DPKerasAdamOptimizer will result in an
|
||||
(epsilon, delta)-differentially private trained model.
|
||||
"""
|
||||
search_parameters = common.BinarySearchParameters(lower_bound=0,
|
||||
upper_bound=math.inf,
|
||||
initial_guess=1,
|
||||
tolerance=tolerance)
|
||||
return common.inverse_monotone_function(
|
||||
lambda x: compute_epsilon(num_train, batch_size, x, epochs, delta)[0],
|
||||
epsilon, search_parameters)
|
||||
|
||||
|
||||
def logistic_dpsgd(train_dataset: datasets.RegressionDataset,
|
||||
test_dataset: datasets.RegressionDataset,
|
||||
epsilon: float, delta: float, epochs: int, num_classes: int,
|
||||
batch_size: int, num_microbatches: int,
|
||||
clipping_norm: float)-> List[float]:
|
||||
"""Trains and validates private logistic regression model via DP-SGD.
|
||||
|
||||
The training is based on the differentially private stochasstic gradient
|
||||
descent algorithm implemented in TensorFlow Privacy.
|
||||
|
||||
Args:
|
||||
train_dataset: consists of num_train many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
test_dataset: consists of num_test many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
epsilon: epsilon parameter in (epsilon, delta)-DP.
|
||||
delta: delta parameter in (epsilon, delta)-DP.
|
||||
epochs: number of training epochs.
|
||||
num_classes: number of classes.
|
||||
batch_size: the number of examples in each batch of gradient descent.
|
||||
num_microbatches: the number of microbatches in gradient descent.
|
||||
clipping_norm: the gradients will be normalized by DPKerasAdamOptimizer
|
||||
to have l2-norm at most clipping_norm.
|
||||
|
||||
Returns:
|
||||
List of test accuracies (one for each epoch) on test_dataset of model
|
||||
trained on train_dataset.
|
||||
"""
|
||||
num_train = train_dataset.points.shape[0]
|
||||
remainder = num_train % batch_size
|
||||
if remainder != 0:
|
||||
train_dataset.points = train_dataset.points[:-remainder, :]
|
||||
train_dataset.labels = train_dataset.labels[:-remainder]
|
||||
num_train -= remainder
|
||||
noise_multiplier = compute_dpsgd_noise_multiplier(num_train, epsilon, delta,
|
||||
epochs, batch_size)
|
||||
optimizer = dp_optimizer_keras.DPKerasAdamOptimizer(
|
||||
l2_norm_clip=clipping_norm, noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches)
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
reduction=tf.losses.Reduction.NONE)
|
||||
return single_layer_softmax.single_layer_softmax_classifier(
|
||||
train_dataset, test_dataset, epochs, num_classes, optimizer, loss,
|
||||
batch_size)
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for tensorflow_privacy.privacy.logistic_regression.multinomial_logistic."""
|
||||
|
||||
import unittest
|
||||
from absl.testing import parameterized
|
||||
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import compute_dp_sgd_privacy
|
||||
from tensorflow_privacy.privacy.logistic_regression import datasets
|
||||
from tensorflow_privacy.privacy.logistic_regression import multinomial_logistic
|
||||
|
||||
|
||||
class MultinomialLogisticRegressionTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
(5000, 500, 3, 1, 1e-5, 40, 2, 0.05),
|
||||
(5000, 500, 4, 1, 1e-5, 40, 2, 0.05),
|
||||
(10000, 1000, 3, 1, 1e-5, 40, 4, 0.1),
|
||||
(10000, 1000, 4, 1, 1e-5, 40, 4, 0.1),
|
||||
)
|
||||
def test_logistic_objective_perturbation(self, num_train, num_test, dimension,
|
||||
epsilon, delta, epochs, num_classes,
|
||||
tolerance):
|
||||
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
|
||||
num_train, num_test, dimension, num_classes)
|
||||
accuracy = multinomial_logistic.logistic_objective_perturbation(
|
||||
train_dataset, test_dataset, epsilon, delta, epochs, num_classes, 1)
|
||||
# Since the synthetic data is linearly separable, we expect the test
|
||||
# accuracy to come arbitrarily close to 1 as the number of training examples
|
||||
# grows.
|
||||
self.assertAlmostEqual(accuracy[-1], 1, delta=tolerance)
|
||||
|
||||
@parameterized.parameters(
|
||||
(1, 1, 1e-5, 40, 1, 1e-2),
|
||||
(500, 0.1, 1e-5, 40, 50, 1e-2),
|
||||
(5000, 10, 1e-5, 40, 10, 1e-3),
|
||||
)
|
||||
def test_compute_dpsgd_noise_multiplier(self, num_train, epsilon, delta,
|
||||
epochs, batch_size, tolerance):
|
||||
noise_multiplier = multinomial_logistic.compute_dpsgd_noise_multiplier(
|
||||
num_train, epsilon, delta, epochs, batch_size, tolerance)
|
||||
epsilon_lower_bound = compute_dp_sgd_privacy(num_train, batch_size,
|
||||
noise_multiplier + tolerance,
|
||||
epochs, delta)[0]
|
||||
epsilon_upper_bound = compute_dp_sgd_privacy(num_train, batch_size,
|
||||
noise_multiplier - tolerance,
|
||||
epochs, delta)[0]
|
||||
self.assertLess(epsilon_lower_bound, epsilon)
|
||||
self.assertLess(epsilon, epsilon_upper_bound)
|
||||
|
||||
@parameterized.parameters(
|
||||
(5000, 500, 3, 1, 1e-5, 40, 2, 0.05, 10, 10, 1),
|
||||
(5000, 500, 4, 1, 1e-5, 40, 2, 0.05, 10, 10, 1),
|
||||
(5000, 500, 3, 2, 1e-4, 40, 4, 0.1, 10, 10, 1),
|
||||
(5000, 500, 4, 2, 1e-4, 40, 4, 0.1, 10, 10, 1),
|
||||
)
|
||||
def test_logistic_dpsgd(self, num_train, num_test, dimension, epsilon,
|
||||
delta, epochs, num_classes, tolerance,
|
||||
batch_size, num_microbatches, clipping_norm):
|
||||
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
|
||||
num_train, num_test, dimension, num_classes)
|
||||
accuracy = multinomial_logistic.logistic_dpsgd(
|
||||
train_dataset, test_dataset, epsilon, delta, epochs, num_classes,
|
||||
batch_size, num_microbatches, clipping_norm)
|
||||
# Since the synthetic data is linearly separable, we expect the test
|
||||
# accuracy to come arbitrarily close to 1 as the number of training examples
|
||||
# grows.
|
||||
self.assertAlmostEqual(accuracy[-1], 1, delta=tolerance)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of a single-layer softmax classifier.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.logistic_regression import datasets
|
||||
|
||||
|
||||
def single_layer_softmax_classifier(
|
||||
train_dataset: datasets.RegressionDataset,
|
||||
test_dataset: datasets.RegressionDataset,
|
||||
epochs: int, num_classes: int, optimizer: tf.keras.optimizers.Optimizer,
|
||||
loss: tf.keras.losses.Loss = 'categorical_crossentropy',
|
||||
batch_size: int = 32,
|
||||
kernel_regularizer: tf.keras.regularizers.Regularizer = None)-> List[float]:
|
||||
"""Trains a single layer neural network classifier with softmax activation.
|
||||
|
||||
Args:
|
||||
train_dataset: consists of num_train many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
test_dataset: consists of num_test many labeled examples, where the labels
|
||||
are in {0,1,...,num_classes-1}.
|
||||
epochs: the number of epochs.
|
||||
num_classes: the number of classes.
|
||||
optimizer: a tf.keras optimizer.
|
||||
loss: a tf.keras loss function.
|
||||
batch_size: a positive integer.
|
||||
kernel_regularizer: a regularization function.
|
||||
|
||||
Returns:
|
||||
List of test accuracies (one for each epoch) on test_dataset of model
|
||||
trained on train_dataset.
|
||||
"""
|
||||
one_hot_train_labels = tf.one_hot(train_dataset.labels, num_classes)
|
||||
one_hot_test_labels = tf.one_hot(test_dataset.labels, num_classes)
|
||||
model = tf.keras.Sequential()
|
||||
model.add(tf.keras.layers.Dense(units=num_classes,
|
||||
activation='softmax',
|
||||
kernel_regularizer=kernel_regularizer))
|
||||
model.compile(optimizer, loss=loss, metrics=['accuracy'])
|
||||
history = model.fit(train_dataset.points, one_hot_train_labels,
|
||||
batch_size=batch_size, epochs=epochs,
|
||||
validation_data=(test_dataset.points,
|
||||
one_hot_test_labels),
|
||||
verbose=0)
|
||||
return history.history['val_accuracy']
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for tensorflow_privacy.privacy.logistic_regression.single_layer_softmax."""
|
||||
|
||||
import unittest
|
||||
from absl.testing import parameterized
|
||||
from tensorflow_privacy.privacy.logistic_regression import datasets
|
||||
from tensorflow_privacy.privacy.logistic_regression import single_layer_softmax
|
||||
|
||||
|
||||
class SingleLayerSoftmaxTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
(5000, 500, 3, 40, 2, 0.05),
|
||||
(5000, 500, 4, 40, 2, 0.05),
|
||||
(10000, 1000, 3, 40, 4, 0.1),
|
||||
(10000, 1000, 4, 40, 4, 0.1),
|
||||
)
|
||||
def test_single_layer_softmax(self, num_train, num_test, dimension, epochs,
|
||||
num_classes, tolerance):
|
||||
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
|
||||
num_train, num_test, dimension, num_classes)
|
||||
accuracy = single_layer_softmax.single_layer_softmax_classifier(
|
||||
train_dataset, test_dataset, epochs, num_classes, 'sgd')
|
||||
self.assertAlmostEqual(accuracy[-1], 1, delta=tolerance)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in a new issue