diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/README.md b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/README.md
index ab3fa95..94f9c0a 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/README.md
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/README.md
@@ -1,16 +1,16 @@
 # Membership inference attack
 
-A good privacy-preserving model learns from the training data, but
-doesn't memorize it. This library provides empirical tests for measuring
-potential memorization.
+A good privacy-preserving model learns from the training data, but doesn't
+memorize it. This library provides empirical tests for measuring potential
+memorization.
 
 Technically, the tests build classifiers that infer whether a particular sample
 was present in the training set. The more accurate such classifier is, the more
 memorization is present and thus the less privacy-preserving the model is.
 
-The privacy vulnerability (or memorization potential) is measured
-via the area under the ROC-curve (`auc`) or via max{|fpr - tpr|} (`advantage`)
-of the attack classifier. These measures are very closely related.
+The privacy vulnerability (or memorization potential) is measured via the area
+under the ROC-curve (`auc`) or via max{|fpr - tpr|} (`advantage`) of the attack
+classifier. These measures are very closely related.
 
 The tests provided by the library are "black box". That is, only the outputs of
 the model are used (e.g., losses, logits, predictions). Neither model internals
@@ -69,7 +69,8 @@ print(attacks_result.summary())
 
 ### Other codelabs
 
-Please head over to the [codelabs](https://github.com/tensorflow/privacy/tree/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs)
+Please head over to the
+[codelabs](https://github.com/tensorflow/privacy/tree/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs)
 section for an overview of the library in action.
 
 ### Advanced usage
@@ -77,11 +78,10 @@ section for an overview of the library in action.
 #### Specifying attacks to run
 
 Sometimes, we have more information about the data, such as the logits and the
-labels,
-and we may want to have finer-grained control of the attack, such as using more
-complicated classifiers instead of the simple threshold attack, and looks at the
-attack results by examples' class.
-In thoses cases, we can provide more information to `run_attacks`.
+labels, and we may want to have finer-grained control of the attack, such as
+using more complicated classifiers instead of the simple threshold attack, and
+looking at the attack results by examples' class. In those cases, we can
+provide more information to `run_attacks`.
 
 ```python
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
@@ -109,15 +109,13 @@ attack_input = AttackInputData(
     labels_test = labels_test)
 ```
 
-Instead of `logits`, you can also specify
-`probs_train` and `probs_test` as the predicted probabilty vectors of each
-example.
+Instead of `logits`, you can also specify `probs_train` and `probs_test` as the
+predicted probability vectors of each example.
 
-Then, we specify some details of the attack.
-The first part includes the specifications of the slicing of the data. For
-example, we may want to evaluate the result on the whole dataset, or by class,
-percentiles, or the correctness of the model's classification.
-These can be specified by a `SlicingSpec` object.
+Then, we specify some details of the attack. The first part includes the
+specifications of the slicing of the data. For example, we may want to evaluate
+the result on the whole dataset, or by class, percentiles, or the correctness of
+the model's classification. These can be specified by a `SlicingSpec` object.
 
 ```python
 slicing_spec = SlicingSpec(
@@ -127,16 +125,13 @@ slicing_spec = SlicingSpec(
     by_classification_correctness = True)
 ```
 
-The second part specifies the classifiers for the attacker to use.
-Currently, our API supports five classifiers, including
-`AttackType.THRESHOLD_ATTACK` for simple threshold attack,
-`AttackType.LOGISTIC_REGRESSION`,
-`AttackType.MULTI_LAYERED_PERCEPTRON`,
-`AttackType.RANDOM_FOREST`, and
-`AttackType.K_NEAREST_NEIGHBORS`
-which use the corresponding machine learning models.
-For some model, different classifiers can yield pertty different results.
-We can put multiple classifers in a list:
+The second part specifies the classifiers for the attacker to use. Currently,
+our API supports five classifiers, including `AttackType.THRESHOLD_ATTACK` for
+the simple threshold attack, `AttackType.LOGISTIC_REGRESSION`,
+`AttackType.MULTI_LAYERED_PERCEPTRON`, `AttackType.RANDOM_FOREST`, and
+`AttackType.K_NEAREST_NEIGHBORS`, which use the corresponding machine learning
+models. For some models, different classifiers can yield pretty different
+results. We can put multiple classifiers in a list:
 
 ```python
 attack_types = [
@@ -187,7 +182,6 @@ print(attacks_result.summary(by_slices = True))
 # THRESHOLD_ATTACK achieved an advantage of 0.38
 ```
 
-
 #### Viewing and plotting the attack results
 
 We have seen an example of using `summary()` to view the attack results as text.
@@ -199,6 +193,7 @@ To get the attack that achieves the maximum attacker advantage or AUC, we can do
 max_auc_attacker = attacks_result.get_result_with_max_auc()
 max_advantage_attacker = attacks_result.get_result_with_max_attacker_advantage()
 ```
+
 Then, for individual attack, such as `max_auc_attacker`, we can check its type,
 attacker advantage and AUC by
@@ -210,6 +205,7 @@ print("Attack type with max AUC: %s, AUC of %.2f, Attacker advantage of %.2f" %
 # Example output:
 # -> Attack type with max AUC: THRESHOLD_ATTACK, AUC of 0.75, Attacker advantage of 0.38
 ```
+
 We can also plot its ROC curve by
 
 ```python
@@ -217,6 +213,7 @@ import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.plot
 
 figure = plotting.plot_roc_curve(max_auc_attacker.roc_curve)
 ```
+
 which would give a figure like the one below
 
 ![roc_fig](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelab_roc_fig.png?raw=true)
@@ -241,16 +238,41 @@ print(attacks_result.calculate_pd_dataframe())
 #  25  correctly_classfied           False    lr            0.370713  0.737148
 ```
 
+#### Advanced Membership Inference Attacks
+
+Threshold MIA uses the intuition that training samples usually have lower loss
+than test samples, and it thus predicts samples with loss lower than a
+threshold as in-training / members. However, some data samples might be
+intrinsically harder than others. For example, a hard sample might have pretty
+high loss even when included in the training set, and an easy sample might get
+low loss even when it's not. So using the same threshold for all samples might
+be suboptimal.
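+
+As a minimal illustration of such a global threshold attack (the losses below
+are made-up numbers; in practice they would come from the target model):
+
+```python
+import numpy as np
+from sklearn import metrics
+
+# Hypothetical per-example losses for members (train) and non-members (test).
+loss_train = np.array([0.1, 0.2, 0.3, 1.5])
+loss_test = np.array([0.4, 0.9, 1.2, 2.0])
+
+# Predicting "member" whenever the loss is below a threshold, and sweeping the
+# threshold, traces out the ROC curve of the attack.
+y_true = np.concatenate([np.ones_like(loss_train), np.zeros_like(loss_test)])
+y_score = -np.concatenate([loss_train, loss_test])  # lower loss, more member-like
+fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
+print('auc =', metrics.auc(fpr, tpr))
+print('advantage =', np.max(np.abs(tpr - fpr)))
+```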
+
+People have considered customizing the membership prediction criteria for
+different examples by looking at how they behave when included in or excluded
+from the training set. To do that, we can train a few shadow models whose
+training sets are different subsets of all the samples. Then, for each sample,
+we know what its loss looks like when it is a member and when it is a
+non-member, and we can compare its loss under the target model to those from
+the shadow models. For example, if the average loss is `x` when the sample is
+a member, and `y` when it's not, we might adjust the target loss by
+subtracting `(x+y)/2`. We can expect the adjusted losses of different samples
+to be on a more comparable scale than the original target losses. This gives
+us potentially better estimates of membership.
+
+In `advanced_mia.py`, we provide the method described above, and another
+method that uses a more advanced technique, distribution fitting, to estimate
+membership. `advanced_mia_example.py` shows an example of running these
+advanced membership inference attacks on a CIFAR-10 task.
+
 ### External guides / press mentions
 
-* [Introductory blog post](https://franziska-boenisch.de/posts/2021/01/membership-inference/)
-to the theory and the library by Franziska Boenisch from the Fraunhofer AISEC
-institute.
-* [Google AI Blog Post](https://ai.googleblog.com/2021/01/google-research-looking-back-at-2020.html#ResponsibleAI)
-* [TensorFlow Blog Post](https://blog.tensorflow.org/2020/06/introducing-new-privacy-testing-library.html)
-* [VentureBeat article](https://venturebeat.com/2020/06/24/google-releases-experimental-tensorflow-module-that-tests-the-privacy-of-ai-models/)
-* [Tech Xplore article](https://techxplore.com/news/2020-06-google-tensorflow-privacy-module.html)
-
+* [Introductory blog post](https://franziska-boenisch.de/posts/2021/01/membership-inference/)
+  to the theory and the library by Franziska Boenisch from the Fraunhofer
+  AISEC institute.
+* [Google AI Blog Post](https://ai.googleblog.com/2021/01/google-research-looking-back-at-2020.html#ResponsibleAI)
+* [TensorFlow Blog Post](https://blog.tensorflow.org/2020/06/introducing-new-privacy-testing-library.html)
+* [VentureBeat article](https://venturebeat.com/2020/06/24/google-releases-experimental-tensorflow-module-that-tests-the-privacy-of-ai-models/)
+* [Tech Xplore article](https://techxplore.com/news/2020-06-google-tensorflow-privacy-module.html)
 
 ## Contact / Feedback
 
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py
new file mode 100644
index 0000000..4dd109a
--- /dev/null
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py
@@ -0,0 +1,254 @@
+# Copyright 2022, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions for advanced membership inference attacks.""" + +import functools +from typing import Sequence, Union +import numpy as np +import scipy.stats +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss + + +def replace_nan_with_column_mean(a: np.ndarray): + """Replaces each NaN with the mean of the corresponding column.""" + mean = np.nanmean(a, axis=0) # get the column-wise mean + for i in range(a.shape[1]): + np.nan_to_num(a[:, i], copy=False, nan=mean[i]) + + +def compute_score_offset(stat_target: Union[np.ndarray, Sequence[float]], + stat_in: Sequence[np.ndarray], + stat_out: Sequence[np.ndarray], + option: str = 'both', + median_or_mean: str = 'median') -> np.ndarray: + """Computes score of each sample as stat_target - some offset. + + Args: + stat_target: a list or numpy array where stat_target[i] is the statistics of + example i computed from the target model. stat_target[i] is an array of k + scalars for k being the number of augmentations for each sample. + stat_in: a list where stat_in[i] is the in-training statistics of example i. + stat_in[i] is a m by k numpy array where m is the number of shadow models + and k is the number of augmentations for each sample. m can be different + for different examples. + stat_out: a list where stat_out[i] is the out-training statistics of example + i. stat_out[i] is a m by k numpy array where m is the number of shadow + models and k is the number of augmentations for each sample. + option: using stat_in ("in"), stat_out ("out"), or both ("both"). + median_or_mean: use median or mean across shadow models. + + Returns: + The score of each sample as stat_target - some offset, where offset is + computed with stat_in, stat_out, or both depending on the option. + The relation between the score and the membership depends on that + between stat_target and membership. 
+ """ + if option not in ['both', 'in', 'out']: + raise ValueError('option should be "both", "in", or "out".') + if median_or_mean not in ['median', 'mean']: + raise ValueError('median_or_mean should be either "median" or "mean".') + if option in ['in', 'both']: + if any([s.ndim != 2 for s in stat_in]): + raise ValueError('Each element in stat_in should be a 2-d numpy array.') + if any([s.shape[1] != stat_in[0].shape[1] for s in stat_in]): + raise ValueError('Each element in stat_in should have the same size ' + 'in the second dimension.') + if option in ['out', 'both']: + if any([s.ndim != 2 for s in stat_out]): + raise ValueError('Each element in stat_out should be a 2-d numpy array.') + if any([s.shape[1] != stat_out[0].shape[1] for s in stat_out]): + raise ValueError('Each element in stat_out should have the same size ' + 'in the second dimension.') + func_avg = functools.partial( + np.nanmedian if median_or_mean == 'median' else np.nanmean, axis=0) + + if option == 'both': # use the average of the in-score and out-score + avg_in = np.array(list(map(func_avg, stat_in))) + avg_out = np.array(list(map(func_avg, stat_out))) + # use average in case of NaN + replace_nan_with_column_mean(avg_in) + replace_nan_with_column_mean(avg_out) + offset = (avg_in + avg_out) / 2 + elif option == 'in': # use in-score only + offset = np.array(list(map(func_avg, stat_in))) + replace_nan_with_column_mean(offset) + else: # use out-score only + offset = np.array(list(map(func_avg, stat_out))) + replace_nan_with_column_mean(offset) + scores = (stat_target - offset).mean(axis=1) + return scores + + +def compute_score_lira(stat_target: Union[np.ndarray, Sequence[float]], + stat_in: Sequence[np.ndarray], + stat_out: Sequence[np.ndarray], + option: str = 'both', + fix_variance: bool = False, + median_or_mean: str = 'median') -> np.ndarray: + """Computes score of each sample using Gaussian distribution fitting. + + Args: + stat_target: a list or numpy array where stat_target[i] is the statistics of + example i computed from the target model. stat_target[i] is an array of k + scalars for k being the number of augmentations for each sample. + stat_in: a list where stat_in[i] is the in-training statistics of example i. + stat_in[i] is a m by k numpy array where m is the number of shadow models + and k is the number of augmentations for each sample. m can be different + for different examples. + stat_out: a list where stat_out[i] is the out-training statistics of example + i. stat_out[i] is a m by k numpy array where m is the number of shadow + models and k is the number of augmentations for each sample. + option: using stat_in ("in"), stat_out ("out"), or both ("both"). + fix_variance: whether to use the same variance for all examples. + median_or_mean: use median or mean across shadow models. + + Returns: + log(Pr(out)) - log(Pr(in)), log(Pr(out)), or -log(Pr(in)) depending on the + option. In-training sample is expected to have small value. + The idea is from https://arxiv.org/pdf/2112.03570.pdf. 
+ """ + # median of statistics across shadow models + if option not in ['both', 'in', 'out']: + raise ValueError('option should be "both", "in", or "out".') + if median_or_mean not in ['median', 'mean']: + raise ValueError('median_or_mean should be either "median" or "mean".') + if option in ['in', 'both']: + if any([s.ndim != 2 for s in stat_in]): + raise ValueError('Each element in stat_in should be a 2-d numpy array.') + if any([s.shape[1] != stat_in[0].shape[1] for s in stat_in]): + raise ValueError('Each element in stat_in should have the same size ' + 'in the second dimension.') + if option in ['out', 'both']: + if any([s.ndim != 2 for s in stat_out]): + raise ValueError('Each element in stat_out should be a 2-d numpy array.') + if any([s.shape[1] != stat_out[0].shape[1] for s in stat_out]): + raise ValueError('Each element in stat_out should have the same size ' + 'in the second dimension.') + + func_avg = functools.partial( + np.nanmedian if median_or_mean == 'median' else np.nanmean, axis=0) + if option in ['in', 'both']: + avg_in = np.array(list(map(func_avg, stat_in))) # n by k array + replace_nan_with_column_mean(avg_in) # use column average in case of NaN + if option in ['out', 'both']: + avg_out = np.array(list(map(func_avg, stat_out))) + replace_nan_with_column_mean(avg_out) + + if fix_variance: + # standard deviation of statistics across shadow models and examples + if option in ['in', 'both']: + std_in = np.nanstd( + np.concatenate([l - m[np.newaxis] for l, m in zip(stat_in, avg_in)])) + if option in ['out', 'both']: + std_out = np.nanstd( + np.concatenate([l - m[np.newaxis] for l, m in zip(stat_out, avg_out) + ])) + else: + # standard deviation of statistics across shadow models + func_std = functools.partial(np.nanstd, axis=0) + if option in ['in', 'both']: + std_in = np.array(list(map(func_std, stat_in))) + replace_nan_with_column_mean(std_in) + if option in ['out', 'both']: + std_out = np.array(list(map(func_std, stat_out))) + replace_nan_with_column_mean(std_out) + + stat_target = np.array(stat_target) + if option in ['in', 'both']: + log_pr_in = scipy.stats.norm.logpdf(stat_target, avg_in, std_in + 1e-30) + if option in ['out', 'both']: + log_pr_out = scipy.stats.norm.logpdf(stat_target, avg_out, std_out + 1e-30) + + if option == 'both': + scores = -(log_pr_in - log_pr_out).mean(axis=1) + elif option == 'in': + scores = -log_pr_in.mean(axis=1) + else: + scores = log_pr_out.mean(axis=1) + return scores + + +def convert_logit_to_prob(logit: np.ndarray) -> np.ndarray: + """Converts logits to probability vectors. + + Args: + logit: n by c array where n is the number of samples and c is the number of + classes. + + Returns: + The probability vectors as n by c array + """ + prob = logit - np.max(logit, axis=1, keepdims=True) + prob = np.array(np.exp(prob), dtype=np.float64) + prob = prob / np.sum(prob, axis=1, keepdims=True) + return prob + + +def calculate_statistic(pred: np.ndarray, + labels: np.ndarray, + is_logits: bool = True, + option: str = 'logit', + small_value: float = 1e-45): + """Calculates the statistics of each sample. + + The statistics is: + for option="conf with prob", p, the probability of the true class; + for option="xe", the cross-entropy loss; + for option="logit", log(p / (1 - p)); + for option="conf with logit", max(logits); + for option="hinge", logit of the true class - max(logits of the other + classes). + + Args: + pred: the logits or probability vectors, depending on the value of is_logit. 
+      is_logits. An array of size n by c, where n is the number of samples
+      and c is the number of classes.
+    labels: true labels of samples (integer valued)
+    is_logits: whether pred is logits or probability vectors
+    option: confidence using probability, xe loss, logit of confidence,
+      confidence using logits, hinge loss
+    small_value: a small value to avoid numerical issues
+
+  Returns:
+    The computed statistics as a size n array.
+  """
+  if option not in [
+      'conf with prob', 'xe', 'logit', 'conf with logit', 'hinge'
+  ]:
+    raise ValueError(
+        'option should be one of ["conf with prob", "xe", "logit", "conf with logit", "hinge"].'
+    )
+  if option in ['conf with logit', 'hinge']:
+    if not is_logits:  # the input needs to be the logits
+      raise ValueError('To compute statistics with option "conf with logit" '
+                       'or "hinge", the input must be logits instead of '
+                       'probability vectors.')
+  elif is_logits:
+    pred = convert_logit_to_prob(pred)
+
+  n = labels.size  # number of samples
+  if option in ['conf with prob', 'conf with logit']:
+    return pred[range(n), labels]
+  if option == 'xe':
+    return log_loss(labels, pred)
+  if option == 'logit':
+    pred = pred.copy()  # copy so the in-place update below doesn't mutate the caller's array
+    p_true = pred[range(n), labels]
+    pred[range(n), labels] = 0
+    p_other = pred.sum(axis=1)
+    return np.log(p_true + small_value) - np.log(p_other + small_value)
+  if option == 'hinge':
+    pred = pred.copy()  # copy so the in-place update below doesn't mutate the caller's array
+    l_true = pred[range(n), labels]
+    pred[range(n), labels] = -np.inf
+    return l_true - pred.max(axis=1)
+  raise ValueError
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py
new file mode 100644
index 0000000..a0e570c
--- /dev/null
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py
@@ -0,0 +1,219 @@
+# Copyright 2022, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
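+
+# Example invocation (illustrative; the flags are defined in this file):
+#   python advanced_mia_example.py --num_shadows=10 --model_dir=/tmp/advanced_mia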
+"""An example for using advanced_mia.""" + +import functools +import gc +import os +from absl import app +from absl import flags +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils +from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData + +FLAGS = flags.FLAGS +_LR = flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training') +_BATCH = flags.DEFINE_integer('batch_size', 250, 'Batch size') +_EPOCHS = flags.DEFINE_integer('epochs', 20, 'Number of epochs') +_NUM_SHADOWS = flags.DEFINE_integer('num_shadows', 10, + 'Number of shadow models.') +_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory.') + + +def small_cnn(): + """Setup a small CNN for image classification.""" + model = tf.keras.models.Sequential() + # Add a layer to do random horizontal augmentation. + model.add(tf.keras.layers.RandomFlip('horizontal')) + model.add(tf.keras.layers.Input(shape=(32, 32, 3))) + + for _ in range(3): + model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu')) + model.add(tf.keras.layers.MaxPooling2D()) + + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(10)) + return model + + +def load_cifar10(): + """Loads CIFAR10, with training and test combined.""" + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() + x = np.concatenate([x_train, x_test]).astype(np.float32) / 255 + y = np.concatenate([y_train, y_test]).astype(np.int32).squeeze() + return x, y + + +def plot_curve_with_area(x, y, xlabel, ylabel, ax, label, title=None): + ax.plot([0, 1], [0, 1], 'k-', lw=1.0) + ax.plot(x, y, lw=2, label=label) + ax.set(xlabel=xlabel, ylabel=ylabel) + ax.set(aspect=1, xscale='log', yscale='log') + ax.title.set_text(title) + + +def get_stat_and_loss_aug(model, x, y, batch_size=4096): + """A helper function to get the statistics and losses. + + Here we get the statistics and losses for the original and + horizontally flipped image, as we are going to train the model with + random horizontal flip. + + Args: + model: model to make prediction + x: samples + y: true labels of samples (integer valued) + batch_size: the batch size for model.predict + + Returns: + the statistics and cross-entropy losses + """ + losses, stat = [], [] + for data in [x, x[:, :, ::-1, :]]: + prob = amia.convert_logit_to_prob( + model.predict(data, batch_size=batch_size)) + losses.append(utils.log_loss(y, prob)) + stat.append(amia.calculate_statistic(prob, y, convert_to_prob=False)) + return np.vstack(stat).transpose(1, 0), np.vstack(losses).transpose(1, 0) + + +def main(unused_argv): + del unused_argv # unused argument + seed = 123 + np.random.seed(seed) + if _MODEL_DIR.value and not os.path.exists(_MODEL_DIR.value): + os.mkdir(_MODEL_DIR.value) + + # Load data. + x, y = load_cifar10() + n = x.shape[0] + + # Train the target and shadow models. We will use one of the model in `models` + # as target and the rest as shadow. + # Here we use the same architecture and optimizer. 
+  # In practice, they might differ between the target and shadow models.
+  in_indices = []  # a list of in-training indices for all models
+  stat = []  # a list of statistics for all models
+  losses = []  # a list of losses for all models
+  for i in range(_NUM_SHADOWS.value + 1):
+    if _MODEL_DIR.value:
+      model_path = os.path.join(
+          _MODEL_DIR.value,
+          f'model{i}_lr{_LR.value}_b{_BATCH.value}_e{_EPOCHS.value}_sd{seed}.h5'
+      )
+
+    # Generate a binary array indicating which examples to include for
+    # training.
+    in_indices.append(np.random.binomial(1, 0.5, n).astype(bool))
+
+    model = small_cnn()
+    if _MODEL_DIR.value and os.path.exists(model_path):  # Load if exists
+      model(x[:1])  # call the model once so that `load_weights` works
+      model.load_weights(model_path)
+      print(f'Loaded model #{i} with {in_indices[-1].sum()} examples.')
+    else:  # Otherwise, train the model
+      model.compile(
+          optimizer=tf.keras.optimizers.SGD(_LR.value, momentum=0.9),
+          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+          metrics=['accuracy'])
+      model.fit(
+          x[in_indices[-1]],
+          y[in_indices[-1]],
+          validation_data=(x[~in_indices[-1]], y[~in_indices[-1]]),
+          epochs=_EPOCHS.value,
+          batch_size=_BATCH.value,
+          verbose=2)
+      if _MODEL_DIR.value:
+        model.save_weights(model_path)
+      print(f'Trained model #{i} with {in_indices[-1].sum()} examples.')
+
+    # Get the statistics of the current model.
+    s, l = get_stat_and_loss_aug(model, x, y)
+    stat.append(s)
+    losses.append(l)
+
+    # Avoid OOM
+    tf.keras.backend.clear_session()
+    gc.collect()
+
+  # Now we do MIA for each model
+  for idx in range(_NUM_SHADOWS.value + 1):
+    print(f'Target model is #{idx}')
+    stat_target = stat[idx]  # statistics of target model, shape (n, k)
+    in_indices_target = in_indices[idx]  # ground-truth membership, shape (n,)
+
+    # `stat_shadow` contains statistics of the shadow models, with shape
+    # (num_shadows, n, k). `in_indices_shadow` contains membership of the
+    # shadow models, with shape (num_shadows, n). We will use them to get a
+    # list `stat_in` and a list `stat_out`, where stat_in[j] (resp.
+    # stat_out[j]) is an (m, k) array, where m is the number of shadow models
+    # trained with (resp. without) the j-th example, and k is the number of
+    # augmentations (2 in our case).
+    stat_shadow = np.array(stat[:idx] + stat[idx + 1:])
+    in_indices_shadow = np.array(in_indices[:idx] + in_indices[idx + 1:])
+    stat_in = [stat_shadow[:, j][in_indices_shadow[:, j]] for j in range(n)]
+    stat_out = [stat_shadow[:, j][~in_indices_shadow[:, j]] for j in range(n)]
+
+    # Compute the scores and use them for MIA
+    scores = amia.compute_score_lira(
+        stat_target, stat_in, stat_out, fix_variance=True)
+    attack_input = AttackInputData(
+        loss_train=scores[in_indices_target],
+        loss_test=scores[~in_indices_target])
+    result_lira = mia.run_attacks(attack_input).single_attack_results[0]
+    print('Better MIA attack with Gaussian:',
+          f'auc = {result_lira.get_auc():.4f}',
+          f'adv = {result_lira.get_attacker_advantage():.4f}')
+
+    # We also try using `compute_score_offset` to compute the score. We take
+    # the negative of the score, because a higher statistic corresponds to a
+    # higher probability of being in-training, which is the opposite of loss.
+    scores = -amia.compute_score_offset(stat_target, stat_in, stat_out)
+    attack_input = AttackInputData(
+        loss_train=scores[in_indices_target],
+        loss_test=scores[~in_indices_target])
+    result_offset = mia.run_attacks(attack_input).single_attack_results[0]
+    print('Better MIA attack with offset:',
+          f'auc = {result_offset.get_auc():.4f}',
+          f'adv = {result_offset.get_attacker_advantage():.4f}')
+
+    # Compare with the baseline MIA using the loss of the target model
+    loss_target = losses[idx][:, 0]
+    attack_input = AttackInputData(
+        loss_train=loss_target[in_indices_target],
+        loss_test=loss_target[~in_indices_target])
+    result_baseline = mia.run_attacks(attack_input).single_attack_results[0]
+    print('Baseline MIA attack:', f'auc = {result_baseline.get_auc():.4f}',
+          f'adv = {result_baseline.get_attacker_advantage():.4f}')
+
+    # Plot and save the ROC curves of the three methods.
+    _, ax = plt.subplots(1, 1, figsize=(5, 5))
+    for res, title in zip([result_baseline, result_lira, result_offset],
+                          ['baseline', 'LiRA', 'offset']):
+      label = f'{title} auc={res.get_auc():.4f}'
+      mia_plotting.plot_roc_curve(
+          res.roc_curve,
+          functools.partial(plot_curve_with_area, ax=ax, label=label))
+    plt.legend()
+    plt.savefig('advanced_mia_demo.png')
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py
new file mode 100644
index 0000000..1865d84
--- /dev/null
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py
@@ -0,0 +1,198 @@
+# Copyright 2022, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for advanced_mia."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as np
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
+
+
+class TestScoreOffset(parameterized.TestCase):
+  """Tests compute_score_offset."""
+
+  def setUp(self):
+    super().setUp()
+    self.stat_target = np.array([[-0.1, 0.1, 0], [0, 0, 27],
+                                 [0, 0, 0]])  # 3 samples with 3 augmentations
+    self.stat_in = [
+        np.array([[1, 2, -3.]]),  # 1 shadow model
+        np.array([[-2., 4, 6], [0, 0, 0], [5, -7, -9]]),  # 3 shadow models
+        np.empty((0, 3))  # no shadow models
+    ]
+    self.stat_out = [-s + 10 for s in self.stat_in]
+
+  @parameterized.named_parameters(
+      ('both_mean', 'both', 'mean', np.array([-5., 4., -5.])),
+      ('both_median', 'both', 'median', np.array([-5., 4., -5.])),
+      ('in_median', 'in', 'median', np.array([0., 9., 0.])),
+      ('out_median', 'out', 'median', np.array([-10., -1., -10.])),
+      ('in_mean', 'in', 'mean', np.array([0, 28. / 3, 1. / 6])),
+      ('out_mean', 'out', 'mean', np.array([-10, -4. / 3, -61. / 6])))
+  def test_compute_score_offset(self, option, median_or_mean, expected):
+    scores = amia.compute_score_offset(self.stat_target, self.stat_in,
+                                       self.stat_out, option, median_or_mean)
+    np.testing.assert_allclose(scores, expected, atol=1e-7)
+    # If `option` is "in" (resp. "out"), test with `stat_out` (resp. `stat_in`)
+    # set to an empty list.
+    if option == 'in':
+      scores = amia.compute_score_offset(self.stat_target, self.stat_in, [],
+                                         option, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+    elif option == 'out':
+      scores = amia.compute_score_offset(self.stat_target, [], self.stat_out,
+                                         option, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+
+
+class TestLiRA(parameterized.TestCase):
+  """Tests compute_score_lira."""
+
+  @parameterized.named_parameters(
+      ('in_median', 'in', False, 'median',
+       np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
+      ('in_mean', 'in', False, 'mean',
+       np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
+      ('out_median', 'out', False, 'median',
+       -np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
+      ('out_mean', 'out', False, 'mean',
+       -np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
+      ('in_median_fix', 'in', True, 'median',
+       np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
+      ('in_mean_fix', 'in', True, 'mean',
+       np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
+      ('out_median_fix', 'out', True, 'median',
+       -np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
+      ('out_mean_fix', 'out', True, 'mean',
+       -np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
+      ('both_median_fix', 'both', True, 'median', np.array([0, 0, 0, 0.])),
+      ('both_mean_fix', 'both', True, 'mean', np.array([0, 0, 0, 0.])),
+      ('both_median', 'both', False, 'median', np.array([0, 0, 0, 0.])),
+      ('both_mean', 'both', False, 'mean', np.array([0, 0, 0, 0.])),
+  )
+  def test_with_one_augmentation(self, option, fix_variance, median_or_mean,
+                                 expected):
+    stat_target = np.array([[1.], [0.], [0.], [0.]])
+    stat_in = [
+        np.array([[-1], [1.]]),
+        np.array([[0], [2.]]),
+        np.array([[0], [20.]]),
+        np.empty((0, 1))
+    ]
+    stat_out = [-s for s in stat_in]
+
+    scores = amia.compute_score_lira(stat_target, stat_in, stat_out, option,
+                                     fix_variance, median_or_mean)
+    np.testing.assert_allclose(scores, expected, atol=1e-7)
+    # If `option` is "in" (resp. "out"), test with `stat_out` (resp. `stat_in`)
+    # set to an empty list.
+    if option == 'in':
+      scores = amia.compute_score_lira(stat_target, stat_in, [], option,
+                                       fix_variance, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+    elif option == 'out':
+      scores = amia.compute_score_lira(stat_target, [], stat_out, option,
+                                       fix_variance, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+
+  @parameterized.named_parameters(
+      ('in_median', 'in', False, 'median', 2.57023108),
+      ('in_mean', 'in', False, 'mean', 2.57023108),
+      ('out_median', 'out', False, 'median', -2.57023108),
+      ('out_mean', 'out', False, 'mean', -2.57023108),
+      ('both_median', 'both', False, 'median', 0),
+      ('both_mean', 'both', False, 'mean', 0))
+  def test_two_augmentations(self, option, fix_variance, median_or_mean,
+                             expected):
+    stat_target = np.array([[1., 0.]])
+    stat_in = [np.array([[-1, 0], [1., 20]])]
+    stat_out = [-s for s in stat_in]
+
+    scores = amia.compute_score_lira(stat_target, stat_in, stat_out, option,
+                                     fix_variance, median_or_mean)
+    np.testing.assert_allclose(scores, expected, atol=1e-7)
+    # If `option` is "in" (resp. "out"), test with `stat_out` (resp. `stat_in`)
+    # set to an empty list.
+    if option == 'in':
+      scores = amia.compute_score_lira(stat_target, stat_in, [], option,
+                                       fix_variance, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+    elif option == 'out':
+      scores = amia.compute_score_lira(stat_target, [], stat_out, option,
+                                       fix_variance, median_or_mean)
+      np.testing.assert_allclose(scores, expected, atol=1e-7)
+
+
+class TestLogitProbConversion(absltest.TestCase):
+  """Tests convert_logit_to_prob."""
+
+  def test_convert_logit_to_prob(self):
+    """Test convert_logit_to_prob."""
+    logit = np.array([[10, -1, 0.], [-10, 0, -11]])
+    prob = amia.convert_logit_to_prob(logit)
+    expected = np.array([[9.99937902e-01, 1.67006637e-05, 4.53971105e-05],
+                         [4.53971105e-05, 9.99937902e-01, 1.67006637e-05]])
+    np.testing.assert_allclose(prob, expected, atol=1e-5)
+
+
+class TestCalculateStatistic(absltest.TestCase):
+  """Tests calculate_statistic."""
+
+  def test_calculate_statistic_logit(self):
+    """Test calculate_statistic with input as logit."""
+    is_logits = True
+    logit = np.array([[1, 2, -3.], [-1, 1, 0]])
+    # expected probability vector
+    # array([[0.26762315, 0.72747516, 0.00490169],
+    #        [0.09003057, 0.66524096, 0.24472847]])
+    labels = np.array([1, 2])
+
+    stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with prob')
+    np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
+
+    stat = amia.calculate_statistic(logit, labels, is_logits, 'xe')
+    np.testing.assert_allclose(stat, np.array([0.31817543, 1.40760596]))
+
+    stat = amia.calculate_statistic(logit, labels, is_logits, 'logit')
+    np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
+
+    stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with logit')
+    np.testing.assert_allclose(stat, np.array([2, 0.]))
+
+    stat = amia.calculate_statistic(logit, labels, is_logits, 'hinge')
+    np.testing.assert_allclose(stat, np.array([1, -1.]))
+
+  def test_calculate_statistic_prob(self):
+    """Test calculate_statistic with input as probability vector."""
+    is_logits = False
+    prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
+    labels = np.array([1, 2])
+
+    stat = amia.calculate_statistic(prob, labels, is_logits, 'conf with prob')
+    np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
+
+    stat = amia.calculate_statistic(prob, labels, is_logits, 'xe')
+    np.testing.assert_allclose(stat, np.array([0.16251893, 0.91629073]))
+
+    stat = amia.calculate_statistic(prob, labels, is_logits, 'logit')
+    np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
+
+    np.testing.assert_raises(ValueError, amia.calculate_statistic, prob,
+                             labels, is_logits, 'conf with logit')
+    np.testing.assert_raises(ValueError, amia.calculate_statistic, prob,
+                             labels, is_logits, 'hinge')
+
+
+if __name__ == '__main__':
+  absltest.main()