PiperOrigin-RevId: 424922009
This commit is contained in:
Michael Reneer 2022-01-28 11:50:07 -08:00 committed by A. Unique TensorFlower
parent 7396ad62da
commit 07230a161a
4 changed files with 31 additions and 36 deletions

View file

@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_privacy.privacy.logistic_regression.multinomial_logistic."""
import unittest
from absl.testing import parameterized
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import compute_dp_sgd_privacy
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
from tensorflow_privacy.privacy.logistic_regression import datasets
from tensorflow_privacy.privacy.logistic_regression import multinomial_logistic
@ -49,12 +49,10 @@ class MultinomialLogisticRegressionTest(parameterized.TestCase):
epochs, batch_size, tolerance):
noise_multiplier = multinomial_logistic.compute_dpsgd_noise_multiplier(
num_train, epsilon, delta, epochs, batch_size, tolerance)
epsilon_lower_bound = compute_dp_sgd_privacy(num_train, batch_size,
noise_multiplier + tolerance,
epochs, delta)[0]
epsilon_upper_bound = compute_dp_sgd_privacy(num_train, batch_size,
noise_multiplier - tolerance,
epochs, delta)[0]
epsilon_lower_bound = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
num_train, batch_size, noise_multiplier + tolerance, epochs, delta)[0]
epsilon_upper_bound = compute_dp_sgd_privacy_lib.compute_dp_sgd_privacy(
num_train, batch_size, noise_multiplier - tolerance, epochs, delta)[0]
self.assertLess(epsilon_lower_bound, epsilon)
self.assertLess(epsilon, epsilon_upper_bound)

View file

@ -24,12 +24,8 @@ import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
@ -91,31 +87,32 @@ num_clusters = int(round(np.max(training_labels))) + 1
# Hint: play with the number of layers to achieve different level of
# over-fitting and observe its effects on membership inference performance.
three_layer_model = keras.models.Sequential([
layers.Dense(300, activation="relu"),
layers.Dense(300, activation="relu"),
layers.Dense(300, activation="relu"),
layers.Dense(num_clusters, activation="relu"),
layers.Softmax()
three_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
three_layer_model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
two_layer_model = keras.models.Sequential([
layers.Dense(300, activation="relu"),
layers.Dense(300, activation="relu"),
layers.Dense(num_clusters, activation="relu"),
layers.Softmax()
two_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
two_layer_model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
def crossentropy(true_labels, predictions):
return keras.backend.eval(
keras.losses.binary_crossentropy(
keras.backend.variable(to_categorical(true_labels, num_clusters)),
keras.backend.variable(predictions)))
return tf.keras.backend.eval(
tf.keras.metrics.binary_crossentropy(
tf.keras.backend.variable(
tf.keras.utils.to_categorical(true_labels, num_clusters)),
tf.keras.backend.variable(predictions)))
def main(unused_argv):
@ -131,9 +128,10 @@ def main(unused_argv):
for i in range(1, 6):
models[model_name].fit(
training_features,
to_categorical(training_labels, num_clusters),
tf.keras.utils.to_categorical(training_labels, num_clusters),
validation_data=(test_features,
to_categorical(test_labels, num_clusters)),
tf.keras.utils.to_categorical(
test_labels, num_clusters)),
batch_size=64,
epochs=num_epochs,
shuffle=True)

View file

@ -13,10 +13,10 @@
# limitations under the License.
"""Plotting code for ML Privacy Reports."""
from typing import Iterable
from typing import Iterable, Sequence
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsDFColumns
@ -30,7 +30,7 @@ TRAIN_ACCURACY_STR = 'Train accuracy'
def plot_by_epochs(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]) -> plt.Figure:
privacy_metrics: Sequence[PrivacyMetric]) -> plt.Figure:
"""Plots privacy vulnerabilities vs epoch numbers.
In case multiple privacy metrics are specified, the plot will feature
@ -55,7 +55,7 @@ def plot_by_epochs(results: AttackResultsCollection,
def plot_privacy_vs_accuracy(results: AttackResultsCollection,
privacy_metrics: Iterable[PrivacyMetric]):
privacy_metrics: Sequence[PrivacyMetric]):
"""Plots privacy vulnerabilities vs accuracy plots.
In case multiple privacy metrics are specified, the plot will feature
@ -105,7 +105,7 @@ def _calculate_combined_df_with_metadata(results: Iterable[AttackResults]):
def _generate_subplots(all_results_df: pd.DataFrame, x_axis_metric: str,
figure_title: str,
privacy_metrics: Iterable[PrivacyMetric]):
privacy_metrics: Sequence[PrivacyMetric]):
"""Create one subplot per privacy metric for a specified x_axis_metric."""
fig, axes = plt.subplots(
1, len(privacy_metrics), figsize=(5 * len(privacy_metrics) + 3, 5))

View file

@ -20,7 +20,6 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy_lib
from tensorflow_privacy.privacy.optimizers import dp_optimizer
import mnist_dpsgd_tutorial_common as common