Implementation of Differentially Private Logistic Regression.

PiperOrigin-RevId: 458266079
This commit is contained in:
A. Unique TensorFlower 2022-06-30 11:00:33 -07:00
parent 77d962e0fc
commit c665281c55
4 changed files with 17 additions and 15 deletions

View file

@ -29,7 +29,7 @@ Includes two types of datasets:
"""
import dataclasses
from typing import Tuple
from typing import Tuple, Optional
import numpy as np
from sklearn import preprocessing
@ -46,9 +46,13 @@ class RegressionDataset:
labels: array of shape (num_examples,) containing the corresponding labels,
each belonging to the set {0,1,...,num_classes-1}, where num_classes is
the number of classes.
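weights: matrix of shape (dimension, num_classes) containing the coefficients
of the linear separator used to generate the labels, where dimension is the
dimension of each point and num_classes is the number of classes. None when
no such separator is available (e.g., for the MNIST dataset).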
"""
points: np.ndarray
labels: np.ndarray
weights: Optional[np.ndarray]
def linearly_separable_labeled_examples(
@ -71,7 +75,7 @@ def linearly_separable_labeled_examples(
points = preprocessing.normalize(points_non_normalized)
# Compute labels.
labels = np.argmax(np.matmul(points, weights), axis=1)
return RegressionDataset(points, labels)
return RegressionDataset(points, labels, weights)
def synthetic_linearly_separable_data(
@ -122,5 +126,5 @@ def mnist_dataset() -> Tuple[RegressionDataset, RegressionDataset]:
(num_test, -1))
train_points = preprocessing.normalize(train_points_non_normalized)
test_points = preprocessing.normalize(test_points_non_normalized)
return (RegressionDataset(train_points, train_labels),
RegressionDataset(test_points, test_labels))
return (RegressionDataset(train_points, train_labels, None),
RegressionDataset(test_points, test_labels, None))

View file

@ -33,7 +33,7 @@ class MultinomialLogisticRegressionTest(parameterized.TestCase):
tolerance):
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
num_train, num_test, dimension, num_classes)
accuracy = multinomial_logistic.logistic_objective_perturbation(
_, accuracy = multinomial_logistic.logistic_objective_perturbation(
train_dataset, test_dataset, epsilon, delta, epochs, num_classes, 1)
# Since the synthetic data is linearly separable, we expect the test
# accuracy to come arbitrarily close to 1 as the number of training examples
@ -67,11 +67,9 @@ class MultinomialLogisticRegressionTest(parameterized.TestCase):
num_microbatches, clipping_norm):
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
num_train, num_test, dimension, num_classes)
accuracy = multinomial_logistic.logistic_dpsgd(train_dataset, test_dataset,
epsilon, delta, epochs,
num_classes, batch_size,
num_microbatches,
clipping_norm)
_, accuracy = multinomial_logistic.logistic_dpsgd(
train_dataset, test_dataset, epsilon, delta, epochs, num_classes,
batch_size, num_microbatches, clipping_norm)
# Since the synthetic data is linearly separable, we expect the test
# accuracy to come arbitrarily close to 1 as the number of training examples
# grows.

View file

@ -13,7 +13,7 @@
# limitations under the License.
"""Implementation of a single-layer softmax classifier."""
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple, Any
import tensorflow as tf
from tensorflow_privacy.privacy.logistic_regression import datasets
@ -28,7 +28,7 @@ def single_layer_softmax_classifier(
loss: Union[tf.keras.losses.Loss, str] = 'categorical_crossentropy',
batch_size: int = 32,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> List[float]:
) -> Tuple[Any, List[float]]:
"""Trains a single layer neural network classifier with softmax activation.
Args:
@ -63,4 +63,5 @@ def single_layer_softmax_classifier(
epochs=epochs,
validation_data=(test_dataset.points, one_hot_test_labels),
verbose=0)
return history.history['val_accuracy']
weights = model.layers[0].weights
return weights, history.history['val_accuracy']

View file

@ -31,10 +31,9 @@ class SingleLayerSoftmaxTest(parameterized.TestCase):
num_classes, tolerance):
(train_dataset, test_dataset) = datasets.synthetic_linearly_separable_data(
num_train, num_test, dimension, num_classes)
accuracy = single_layer_softmax.single_layer_softmax_classifier(
_, accuracy = single_layer_softmax.single_layer_softmax_classifier(
train_dataset, test_dataset, epochs, num_classes, 'sgd')
self.assertAlmostEqual(accuracy[-1], 1, delta=tolerance)
if __name__ == '__main__':
unittest.main()