Adds per-example membership scores to trained attackers.

PiperOrigin-RevId: 431615160
This commit is contained in:
Shuang Song 2022-02-28 23:51:55 -08:00 committed by A. Unique TensorFlower
parent a33afde0c1
commit 767788e9cf
5 changed files with 232 additions and 134 deletions

View file

@ -23,6 +23,7 @@ from absl import app
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from sklearn import metrics
import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
@ -69,69 +70,69 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1):
return (features, labels)
# Hint: Play with "noise_scale" for different levels of overlap between
# the generated clusters. More noise makes the classification harder.
noise_scale = 2
training_features, training_labels = generate_features_and_labels(
samples_per_cluster=250, scale=noise_scale)
test_features, test_labels = generate_features_and_labels(
samples_per_cluster=250, scale=noise_scale)
def get_models(num_clusters):
"""Get the two models we will be using."""
# Hint: play with the number of layers to achieve different level of
# over-fitting and observe its effects on membership inference performance.
three_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
three_layer_model.compile(
optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=["accuracy"])
num_clusters = int(round(np.max(training_labels))) + 1
# Hint: play with the number of layers to achieve different level of
# over-fitting and observe its effects on membership inference performance.
three_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
three_layer_model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
two_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
two_layer_model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
def crossentropy(true_labels, predictions):
return tf.keras.backend.eval(
tf.keras.metrics.binary_crossentropy(
tf.keras.backend.variable(
tf.keras.utils.to_categorical(true_labels, num_clusters)),
tf.keras.backend.variable(predictions)))
two_layer_model = tf.keras.Sequential([
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(300, activation="relu"),
tf.keras.layers.Dense(num_clusters, activation="relu"),
tf.keras.layers.Softmax()
])
two_layer_model.compile(
optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=["accuracy"])
return three_layer_model, two_layer_model
def main(unused_argv):
epoch_results = data_structures.AttackResultsCollection([])
# Hint: Play with "noise_scale" for different levels of overlap between
# the generated clusters. More noise makes the classification harder.
noise_scale = 2
training_features, training_labels = generate_features_and_labels(
samples_per_cluster=250, scale=noise_scale)
test_features, test_labels = generate_features_and_labels(
samples_per_cluster=250, scale=noise_scale)
num_epochs = 2
num_clusters = int(round(np.max(training_labels))) + 1
three_layer_model, two_layer_model = get_models(num_clusters)
models = {
"two layer model": two_layer_model,
"three layer model": three_layer_model,
"two_layer_model": two_layer_model,
"three_layer_model": three_layer_model,
}
for model_name in models:
# Incrementally train the model and store privacy metrics every num_epochs.
for i in range(1, 6):
models[model_name].fit(
num_epochs_per_round = 20
epoch_results = data_structures.AttackResultsCollection([])
for model_name, model in models.items():
print(f"Train {model_name}.")
# Incrementally train the model and store privacy metrics
# every num_epochs_per_round.
for i in range(5):
model.fit(
training_features,
tf.keras.utils.to_categorical(training_labels, num_clusters),
validation_data=(test_features,
tf.keras.utils.to_categorical(
test_labels, num_clusters)),
training_labels,
validation_data=(test_features, test_labels),
batch_size=64,
epochs=num_epochs,
epochs=num_epochs_per_round,
shuffle=True)
training_pred = models[model_name].predict(training_features)
test_pred = models[model_name].predict(test_features)
training_pred = model.predict(training_features)
test_pred = model.predict(test_features)
# Add metadata to generate a privacy report.
privacy_report_metadata = data_structures.PrivacyReportMetadata(
@ -139,7 +140,7 @@ def main(unused_argv):
training_labels, np.argmax(training_pred, axis=1)),
accuracy_test=metrics.accuracy_score(test_labels,
np.argmax(test_pred, axis=1)),
epoch_num=num_epochs * i,
epoch_num=num_epochs_per_round * (i + 1),
model_variant_label=model_name)
attack_results = mia.run_attacks(
@ -147,9 +148,7 @@ def main(unused_argv):
labels_train=training_labels,
labels_test=test_labels,
probs_train=training_pred,
probs_test=test_pred,
loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)),
probs_test=test_pred),
data_structures.SlicingSpec(entire_dataset=True, by_class=True),
attack_types=(data_structures.AttackType.THRESHOLD_ATTACK,
data_structures.AttackType.LOGISTIC_REGRESSION),
@ -216,6 +215,39 @@ def main(unused_argv):
# For saving a figure into a file:
# plotting.save_plot(figure, <file_path>)
# Let's look at the per-example membership scores. We'll look at how the
# scores from threshold and logistic regression attackers correlate.
# We take the MIA result of the final three layer model
sample_model = epoch_results.attack_results_list[-1]
print("We will look at the membership scores of",
sample_model.privacy_report_metadata.model_variant_label, "at epoch",
sample_model.privacy_report_metadata.epoch_num)
sample_results = sample_model.single_attack_results
# The first two entries of sample_results are from the threshold and
# logistic regression attackers on the whole dataset.
print("Correlation between the scores of the following two attackers:", "\n ",
sample_results[0].slice_spec, sample_results[0].attack_type, "\n ",
sample_results[1].slice_spec, sample_results[1].attack_type)
threshold_results = np.concatenate( # scores by threshold attacker
(sample_results[0].membership_scores_train,
sample_results[0].membership_scores_test))
lr_results = np.concatenate( # scores by logistic regression attacker
(sample_results[1].membership_scores_train,
sample_results[1].membership_scores_test))
# Order the scores and plot them
threshold_orders = scipy.stats.rankdata(threshold_results)
lr_orders = scipy.stats.rankdata(lr_results)
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
axes.scatter(threshold_orders, lr_orders, alpha=0.2, linewidths=0)
m, b = np.polyfit(threshold_orders, lr_orders, 1) # linear fit
axes.plot(threshold_orders, m * threshold_orders + b, color="orange")
axes.set_aspect("equal", adjustable="box")
fig.show()
if __name__ == "__main__":
app.run(main)

View file

@ -21,6 +21,7 @@ from typing import Iterable
import numpy as np
from sklearn import metrics
from sklearn import model_selection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
@ -44,49 +45,61 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
return SingleSliceSpec()
# TODO(b/220394926): Allow users to specify their own attack models.
def _run_trained_attack(attack_input: AttackInputData,
attack_type: AttackType,
balance_attacker_training: bool = True):
balance_attacker_training: bool = True,
cross_validation_folds: int = 2):
"""Classification attack done by ML models."""
attacker = None
if attack_type == AttackType.LOGISTIC_REGRESSION:
attacker = models.LogisticRegressionAttacker()
elif attack_type == AttackType.MULTI_LAYERED_PERCEPTRON:
attacker = models.MultilayerPerceptronAttacker()
elif attack_type == AttackType.RANDOM_FOREST:
attacker = models.RandomForestAttacker()
elif attack_type == AttackType.K_NEAREST_NEIGHBORS:
attacker = models.KNearestNeighborsAttacker()
else:
raise NotImplementedError('Attack type %s not implemented yet.' %
attack_type)
prepared_attacker_data = models.create_attacker_data(
attack_input, balance=balance_attacker_training)
indices = prepared_attacker_data.fold_indices
left_out_indices = prepared_attacker_data.left_out_indices
features = prepared_attacker_data.features_all
labels = prepared_attacker_data.labels_all
attacker.train_model(prepared_attacker_data.features_train,
prepared_attacker_data.is_training_labels_train)
# We are going to train multiple models on disjoint subsets of the data
# (`features`, `labels`), so we can get the membership scores of all samples,
# and each example gets its score assigned only once.
# An alternative implementation is to train multiple models on overlapping
# subsets of the data, and take an average to get the score for each sample.
# `scores` will record the membership score of each sample, initialized to nan
scores = np.full(features.shape[0], np.nan)
# Run the attacker on (permuted) test examples.
predictions_test = attacker.predict(prepared_attacker_data.features_test)
# We use StratifiedKFold to create disjoint subsets of samples. Notice that
# the index it returns is with respect to the samples shuffled with `indices`.
kf = model_selection.StratifiedKFold(cross_validation_folds, shuffle=False)
for train_indices_in_shuffled, test_indices_in_shuffled in kf.split(
features[indices], labels[indices]):
# `train_indices_in_shuffled` is with respect to the data shuffled with
# `indices`. We convert it to `train_indices` to work with the original
# data (`features` and 'labels').
train_indices = indices[train_indices_in_shuffled]
test_indices = indices[test_indices_in_shuffled]
# Make sure one sample only got score predicted once
assert np.all(np.isnan(scores[test_indices]))
# Generate ROC curves with predictions.
fpr, tpr, thresholds = metrics.roc_curve(
prepared_attacker_data.is_training_labels_test, predictions_test)
attacker = models.create_attacker(attack_type)
attacker.train_model(features[train_indices], labels[train_indices])
scores[test_indices] = attacker.predict(features[test_indices])
# Predict the left out with the last attacker
if left_out_indices.size:
assert np.all(np.isnan(scores[left_out_indices]))
scores[left_out_indices] = attacker.predict(features[left_out_indices])
assert not np.any(np.isnan(scores))
# Generate ROC curves with scores.
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
# NOTE: In the current setup we can't obtain membership scores for all
# samples, since some of them were used to train the attacker. This can be
# fixed by training several attackers to ensure each sample was left out
# in exactly one attacker (basically, this means performing cross-validation).
# TODO(b/175870479): Implement membership scores for predicted attackers.
in_train_indices = (labels == 0)
return SingleAttackResult(
slice_spec=_get_slice_spec(attack_input),
data_size=prepared_attacker_data.data_size,
attack_type=attack_type,
membership_scores_train=scores[in_train_indices],
membership_scores_test=scores[~in_train_indices],
roc_curve=roc_curve)
@ -107,8 +120,8 @@ def _run_threshold_attack(attack_input: AttackInputData):
slice_spec=_get_slice_spec(attack_input),
data_size=DataSize(ntrain=ntrain, ntest=ntest),
attack_type=AttackType.THRESHOLD_ATTACK,
membership_scores_train=-attack_input.get_loss_train(),
membership_scores_test=-attack_input.get_loss_test(),
membership_scores_train=attack_input.get_loss_train(),
membership_scores_test=attack_input.get_loss_test(),
roc_curve=roc_curve)

View file

@ -90,6 +90,31 @@ class RunAttacksTest(absltest.TestCase):
self.assertLen(result.membership_scores_train, 100)
self.assertLen(result.membership_scores_test, 50)
def test_run_attack_trained_sets_membership_scores(self):
attack_input = AttackInputData(
logits_train=np.tile([500., -500.], (100, 1)),
logits_test=np.tile([0., 0.], (50, 1)))
result = mia._run_trained_attack(
attack_input,
AttackType.LOGISTIC_REGRESSION,
balance_attacker_training=True)
self.assertLen(result.membership_scores_train, 100)
self.assertLen(result.membership_scores_test, 50)
# Scores for all training (resp. test) examples should be close
np.testing.assert_allclose(
result.membership_scores_train,
result.membership_scores_train[0],
rtol=1e-3)
np.testing.assert_allclose(
result.membership_scores_test,
result.membership_scores_test[0],
rtol=1e-3)
# Training score should be smaller than test score
self.assertLess(result.membership_scores_train[0],
result.membership_scores_test[0])
def test_run_attack_threshold_calculates_correct_auc(self):
result = mia._run_attack(
AttackInputData(

View file

@ -15,7 +15,6 @@
import dataclasses
from typing import Optional
import numpy as np
from sklearn import ensemble
from sklearn import linear_model
@ -23,30 +22,34 @@ from sklearn import model_selection
from sklearn import neighbors
from sklearn import neural_network
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
@dataclasses.dataclass
class AttackerData:
"""Input data for an ML classifier attack.
This includes only the data, and not configuration.
Labels in this class correspond to whether an example was in the
train or test set.
"""
# Features of in-training and out-of-training examples.
features_all: Optional[np.ndarray] = None
# Indicator for whether the example is in-training (0) or out-of-training (1).
labels_all: Optional[np.ndarray] = None
features_train: Optional[np.ndarray] = None
# element-wise boolean array denoting if the example was part of training.
is_training_labels_train: Optional[np.ndarray] = None
# Indices for `features_all` and `labels_all` that are going to be used for
# training the attackers.
fold_indices: Optional[np.ndarray] = None
features_test: Optional[np.ndarray] = None
# element-wise boolean array denoting if the example was part of training.
is_training_labels_test: Optional[np.ndarray] = None
# Indices for `features_all` and `labels_all` that were left out due to
# balancing. Disjoint from `fold_indices`.
left_out_indices: Optional[np.ndarray] = None
data_size: Optional[DataSize] = None
# Number of in-training and out-of-training examples.
data_size: Optional[data_structures.DataSize] = None
def create_attacker_data(attack_input_data: AttackInputData,
test_fraction: float = 0.25,
def create_attacker_data(attack_input_data: data_structures.AttackInputData,
balance: bool = True) -> AttackerData:
"""Prepare AttackInputData to train ML attackers.
@ -54,7 +57,6 @@ def create_attacker_data(attack_input_data: AttackInputData,
Args:
attack_input_data: Original AttackInputData
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
@ -67,25 +69,49 @@ def create_attacker_data(attack_input_data: AttackInputData,
attack_input_test = _column_stack(attack_input_data.logits_or_probs_test,
attack_input_data.get_loss_test())
if balance:
min_size = min(attack_input_data.get_train_size(),
attack_input_data.get_test_size())
attack_input_train = _sample_multidimensional_array(attack_input_train,
min_size)
attack_input_test = _sample_multidimensional_array(attack_input_test,
min_size)
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
features_all = np.concatenate((attack_input_train, attack_input_test))
labels_all = np.concatenate((np.zeros(ntrain), np.ones(ntest)))
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
fold_indices = np.arange(ntrain + ntest)
left_out_indices = np.asarray([], dtype=np.int32)
# Perform a train-test split
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test,
DataSize(ntrain=ntrain, ntest=ntest))
if balance:
idx_train, idx_test = range(ntrain), range(ntrain, ntrain + ntest)
min_size = min(ntrain, ntest)
if ntrain > min_size:
left_out_size = ntrain - min_size
perm_train = np.random.permutation(idx_train) # shuffle training
left_out_indices = perm_train[:left_out_size]
fold_indices = np.concatenate((perm_train[left_out_size:], idx_test))
elif ntest > min_size:
left_out_size = ntest - min_size
perm_test = np.random.permutation(idx_test) # shuffle test
left_out_indices = perm_test[:left_out_size]
fold_indices = np.concatenate((perm_test[left_out_size:], idx_train))
# Shuffle indices for the downstream attackers.
fold_indices = np.random.permutation(fold_indices)
return AttackerData(
features_all=features_all,
labels_all=labels_all,
fold_indices=fold_indices,
left_out_indices=left_out_indices,
data_size=data_structures.DataSize(ntrain=ntrain, ntest=ntest))
def create_attacker(attack_type):
"""Returns the corresponding attacker for the provided attack_type."""
if attack_type == data_structures.AttackType.LOGISTIC_REGRESSION:
return LogisticRegressionAttacker()
if attack_type == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON:
return MultilayerPerceptronAttacker()
if attack_type == data_structures.AttackType.RANDOM_FOREST:
return RandomForestAttacker()
if attack_type == data_structures.AttackType.K_NEAREST_NEIGHBORS:
return KNearestNeighborsAttacker()
raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)
def _sample_multidimensional_array(array, size):

View file

@ -33,9 +33,8 @@ class TrainedAttackerTest(absltest.TestCase):
def test_create_attacker_data_loss_only(self):
attack_input = AttackInputData(
loss_train=np.array([1, 3]), loss_test=np.array([2, 4]))
attacker_data = models.create_attacker_data(attack_input, 0.5)
self.assertLen(attacker_data.features_test, 2)
self.assertLen(attacker_data.features_train, 2)
attacker_data = models.create_attacker_data(attack_input, 2)
self.assertLen(attacker_data.features_all, 4)
def test_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
@ -43,15 +42,22 @@ class TrainedAttackerTest(absltest.TestCase):
logits_test=np.array([[10, 11], [14, 15]]),
loss_train=np.array([3, 7, 10]),
loss_test=np.array([12, 16]))
attacker_data = models.create_attacker_data(
attack_input, 0.25, balance=False)
self.assertLen(attacker_data.features_test, 2)
self.assertLen(attacker_data.features_train, 3)
attacker_data = models.create_attacker_data(attack_input, balance=False)
self.assertLen(attacker_data.features_all, 5)
self.assertLen(attacker_data.fold_indices, 5)
self.assertEmpty(attacker_data.left_out_indices)
for i, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 3) # each feature has two logits and one loss
expected = feature[:2] not in attack_input.logits_train
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
def test_unbalanced_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
logits_test=np.array([[10, 11], [14, 15]]),
loss_train=np.array([3, 7, 10]),
loss_test=np.array([12, 16]))
attacker_data = models.create_attacker_data(attack_input, balance=True)
self.assertLen(attacker_data.features_all, 5)
self.assertLen(attacker_data.fold_indices, 4)
self.assertLen(attacker_data.left_out_indices, 1)
self.assertIn(attacker_data.left_out_indices[0], [0, 1, 2])
def test_balanced_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
@ -59,14 +65,10 @@ class TrainedAttackerTest(absltest.TestCase):
logits_test=np.array([[10, 11], [14, 15], [17, 18]]),
loss_train=np.array([3, 7, 10]),
loss_test=np.array([12, 16, 19]))
attacker_data = models.create_attacker_data(attack_input, 0.33)
self.assertLen(attacker_data.features_test, 2)
self.assertLen(attacker_data.features_train, 4)
for i, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 3) # each feature has two logits and one loss
expected = feature[:2] not in attack_input.logits_train
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
attacker_data = models.create_attacker_data(attack_input)
self.assertLen(attacker_data.features_all, 6)
self.assertLen(attacker_data.fold_indices, 6)
self.assertEmpty(attacker_data.left_out_indices)
if __name__ == '__main__':