Functions for advanced MIAs.
PiperOrigin-RevId: 428111799
This commit is contained in:
parent
13a79f419f
commit
560926ea22
4 changed files with 732 additions and 39 deletions
|
@ -1,16 +1,16 @@
|
||||||
# Membership inference attack
|
# Membership inference attack
|
||||||
|
|
||||||
A good privacy-preserving model learns from the training data, but
|
A good privacy-preserving model learns from the training data, but doesn't
|
||||||
doesn't memorize it. This library provides empirical tests for measuring
|
memorize it. This library provides empirical tests for measuring potential
|
||||||
potential memorization.
|
memorization.
|
||||||
|
|
||||||
Technically, the tests build classifiers that infer whether a particular sample
|
Technically, the tests build classifiers that infer whether a particular sample
|
||||||
was present in the training set. The more accurate such classifier is, the more
|
was present in the training set. The more accurate such classifier is, the more
|
||||||
memorization is present and thus the less privacy-preserving the model is.
|
memorization is present and thus the less privacy-preserving the model is.
|
||||||
|
|
||||||
The privacy vulnerability (or memorization potential) is measured
|
The privacy vulnerability (or memorization potential) is measured via the area
|
||||||
via the area under the ROC-curve (`auc`) or via max{|fpr - tpr|} (`advantage`)
|
under the ROC-curve (`auc`) or via max{|fpr - tpr|} (`advantage`) of the attack
|
||||||
of the attack classifier. These measures are very closely related.
|
classifier. These measures are very closely related.
|
||||||
|
|
||||||
The tests provided by the library are "black box". That is, only the outputs of
|
The tests provided by the library are "black box". That is, only the outputs of
|
||||||
the model are used (e.g., losses, logits, predictions). Neither model internals
|
the model are used (e.g., losses, logits, predictions). Neither model internals
|
||||||
|
@ -69,7 +69,8 @@ print(attacks_result.summary())
|
||||||
|
|
||||||
### Other codelabs
|
### Other codelabs
|
||||||
|
|
||||||
Please head over to the [codelabs](https://github.com/tensorflow/privacy/tree/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs)
|
Please head over to the
|
||||||
|
[codelabs](https://github.com/tensorflow/privacy/tree/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs)
|
||||||
section for an overview of the library in action.
|
section for an overview of the library in action.
|
||||||
|
|
||||||
### Advanced usage
|
### Advanced usage
|
||||||
|
@ -77,11 +78,10 @@ section for an overview of the library in action.
|
||||||
#### Specifying attacks to run
|
#### Specifying attacks to run
|
||||||
|
|
||||||
Sometimes, we have more information about the data, such as the logits and the
|
Sometimes, we have more information about the data, such as the logits and the
|
||||||
labels,
|
labels, and we may want to have finer-grained control of the attack, such as
|
||||||
and we may want to have finer-grained control of the attack, such as using more
|
using more complicated classifiers instead of the simple threshold attack, and
|
||||||
complicated classifiers instead of the simple threshold attack, and looks at the
|
looks at the attack results by examples' class. In thoses cases, we can provide
|
||||||
attack results by examples' class.
|
more information to `run_attacks`.
|
||||||
In thoses cases, we can provide more information to `run_attacks`.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
@ -109,15 +109,13 @@ attack_input = AttackInputData(
|
||||||
labels_test = labels_test)
|
labels_test = labels_test)
|
||||||
```
|
```
|
||||||
|
|
||||||
Instead of `logits`, you can also specify
|
Instead of `logits`, you can also specify `probs_train` and `probs_test` as the
|
||||||
`probs_train` and `probs_test` as the predicted probabilty vectors of each
|
predicted probabilty vectors of each example.
|
||||||
example.
|
|
||||||
|
|
||||||
Then, we specify some details of the attack.
|
Then, we specify some details of the attack. The first part includes the
|
||||||
The first part includes the specifications of the slicing of the data. For
|
specifications of the slicing of the data. For example, we may want to evaluate
|
||||||
example, we may want to evaluate the result on the whole dataset, or by class,
|
the result on the whole dataset, or by class, percentiles, or the correctness of
|
||||||
percentiles, or the correctness of the model's classification.
|
the model's classification. These can be specified by a `SlicingSpec` object.
|
||||||
These can be specified by a `SlicingSpec` object.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
slicing_spec = SlicingSpec(
|
slicing_spec = SlicingSpec(
|
||||||
|
@ -127,16 +125,13 @@ slicing_spec = SlicingSpec(
|
||||||
by_classification_correctness = True)
|
by_classification_correctness = True)
|
||||||
```
|
```
|
||||||
|
|
||||||
The second part specifies the classifiers for the attacker to use.
|
The second part specifies the classifiers for the attacker to use. Currently,
|
||||||
Currently, our API supports five classifiers, including
|
our API supports five classifiers, including `AttackType.THRESHOLD_ATTACK` for
|
||||||
`AttackType.THRESHOLD_ATTACK` for simple threshold attack,
|
simple threshold attack, `AttackType.LOGISTIC_REGRESSION`,
|
||||||
`AttackType.LOGISTIC_REGRESSION`,
|
`AttackType.MULTI_LAYERED_PERCEPTRON`, `AttackType.RANDOM_FOREST`, and
|
||||||
`AttackType.MULTI_LAYERED_PERCEPTRON`,
|
`AttackType.K_NEAREST_NEIGHBORS` which use the corresponding machine learning
|
||||||
`AttackType.RANDOM_FOREST`, and
|
models. For some model, different classifiers can yield pertty different
|
||||||
`AttackType.K_NEAREST_NEIGHBORS`
|
results. We can put multiple classifers in a list:
|
||||||
which use the corresponding machine learning models.
|
|
||||||
For some model, different classifiers can yield pertty different results.
|
|
||||||
We can put multiple classifers in a list:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
attack_types = [
|
attack_types = [
|
||||||
|
@ -187,7 +182,6 @@ print(attacks_result.summary(by_slices = True))
|
||||||
# THRESHOLD_ATTACK achieved an advantage of 0.38
|
# THRESHOLD_ATTACK achieved an advantage of 0.38
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
#### Viewing and plotting the attack results
|
#### Viewing and plotting the attack results
|
||||||
|
|
||||||
We have seen an example of using `summary()` to view the attack results as text.
|
We have seen an example of using `summary()` to view the attack results as text.
|
||||||
|
@ -199,6 +193,7 @@ To get the attack that achieves the maximum attacker advantage or AUC, we can do
|
||||||
max_auc_attacker = attacks_result.get_result_with_max_auc()
|
max_auc_attacker = attacks_result.get_result_with_max_auc()
|
||||||
max_advantage_attacker = attacks_result.get_result_with_max_attacker_advantage()
|
max_advantage_attacker = attacks_result.get_result_with_max_attacker_advantage()
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, for individual attack, such as `max_auc_attacker`, we can check its type,
|
Then, for individual attack, such as `max_auc_attacker`, we can check its type,
|
||||||
attacker advantage and AUC by
|
attacker advantage and AUC by
|
||||||
|
|
||||||
|
@ -210,6 +205,7 @@ print("Attack type with max AUC: %s, AUC of %.2f, Attacker advantage of %.2f" %
|
||||||
# Example output:
|
# Example output:
|
||||||
# -> Attack type with max AUC: THRESHOLD_ATTACK, AUC of 0.75, Attacker advantage of 0.38
|
# -> Attack type with max AUC: THRESHOLD_ATTACK, AUC of 0.75, Attacker advantage of 0.38
|
||||||
```
|
```
|
||||||
|
|
||||||
We can also plot its ROC curve by
|
We can also plot its ROC curve by
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -217,6 +213,7 @@ import tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.plot
|
||||||
|
|
||||||
figure = plotting.plot_roc_curve(max_auc_attacker.roc_curve)
|
figure = plotting.plot_roc_curve(max_auc_attacker.roc_curve)
|
||||||
```
|
```
|
||||||
|
|
||||||
which would give a figure like the one below
|
which would give a figure like the one below
|
||||||
![roc_fig](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelab_roc_fig.png?raw=true)
|
![roc_fig](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelab_roc_fig.png?raw=true)
|
||||||
|
|
||||||
|
@ -241,17 +238,42 @@ print(attacks_result.calculate_pd_dataframe())
|
||||||
# 25 correctly_classfied False lr 0.370713 0.737148
|
# 25 correctly_classfied False lr 0.370713 0.737148
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Advanced Membership Inference Attacks
|
||||||
|
|
||||||
|
Threshold MIA uses the intuition that training samples usually have lower loss
|
||||||
|
than test samples, and it thus predict samples with loss lower than a threshold
|
||||||
|
as in-training / member. However, some data samples might be intrinsically
|
||||||
|
harder than others. For example, a hard sample might have pretty high loss even
|
||||||
|
when included in the training set, and an easy sample might get low loss even
|
||||||
|
when it's not. So using the same threshold for all samples might be suboptimal.
|
||||||
|
|
||||||
|
People have considered customizing the membership prediction criteria for
|
||||||
|
different examples by looking at how they behave when included in or excluded
|
||||||
|
from the training set. To do that, we can train a few shadow models with
|
||||||
|
training sets being different subsets of all samples. Then for each sample, we
|
||||||
|
will know what its loss looks like when it's a member or non-member. Now, we can
|
||||||
|
compare its loss from the target model to those from the shadow models. For
|
||||||
|
example, if the average loss is `x` when the sample is a member, and `y` when
|
||||||
|
it's not, we might adjust the target loss by subtracting `(x+y)/2`. We can
|
||||||
|
expect the adjusted losses of different samples to be more of the same scale
|
||||||
|
compared to the original target losses. This gives us potentially better
|
||||||
|
estimations for membership.
|
||||||
|
|
||||||
|
In `advanced_mia.py`, we provide the method described above, and another method
|
||||||
|
that uses a more advanced way, i.e. distribution fitting to estimate membership.
|
||||||
|
`advanced_mia_example.py` shows an example for doing the advanced membership
|
||||||
|
inference on a CIFAR-10 task.
|
||||||
|
|
||||||
### External guides / press mentions
|
### External guides / press mentions
|
||||||
|
|
||||||
* [Introductory blog post](https://franziska-boenisch.de/posts/2021/01/membership-inference/)
|
* [Introductory blog post](https://franziska-boenisch.de/posts/2021/01/membership-inference/)
|
||||||
to the theory and the library by Franziska Boenisch from the Fraunhofer AISEC
|
to the theory and the library by Franziska Boenisch from the Fraunhofer
|
||||||
institute.
|
AISEC institute.
|
||||||
* [Google AI Blog Post](https://ai.googleblog.com/2021/01/google-research-looking-back-at-2020.html#ResponsibleAI)
|
* [Google AI Blog Post](https://ai.googleblog.com/2021/01/google-research-looking-back-at-2020.html#ResponsibleAI)
|
||||||
* [TensorFlow Blog Post](https://blog.tensorflow.org/2020/06/introducing-new-privacy-testing-library.html)
|
* [TensorFlow Blog Post](https://blog.tensorflow.org/2020/06/introducing-new-privacy-testing-library.html)
|
||||||
* [VentureBeat article](https://venturebeat.com/2020/06/24/google-releases-experimental-tensorflow-module-that-tests-the-privacy-of-ai-models/)
|
* [VentureBeat article](https://venturebeat.com/2020/06/24/google-releases-experimental-tensorflow-module-that-tests-the-privacy-of-ai-models/)
|
||||||
* [Tech Xplore article](https://techxplore.com/news/2020-06-google-tensorflow-privacy-module.html)
|
* [Tech Xplore article](https://techxplore.com/news/2020-06-google-tensorflow-privacy-module.html)
|
||||||
|
|
||||||
|
|
||||||
## Contact / Feedback
|
## Contact / Feedback
|
||||||
|
|
||||||
Fill out this
|
Fill out this
|
||||||
|
|
|
@ -0,0 +1,254 @@
|
||||||
|
# Copyright 2022, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Functions for advanced membership inference attacks."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Sequence, Union
|
||||||
|
import numpy as np
|
||||||
|
import scipy.stats
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils import log_loss
|
||||||
|
|
||||||
|
|
||||||
|
def replace_nan_with_column_mean(a: np.ndarray):
|
||||||
|
"""Replaces each NaN with the mean of the corresponding column."""
|
||||||
|
mean = np.nanmean(a, axis=0) # get the column-wise mean
|
||||||
|
for i in range(a.shape[1]):
|
||||||
|
np.nan_to_num(a[:, i], copy=False, nan=mean[i])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_score_offset(stat_target: Union[np.ndarray, Sequence[float]],
|
||||||
|
stat_in: Sequence[np.ndarray],
|
||||||
|
stat_out: Sequence[np.ndarray],
|
||||||
|
option: str = 'both',
|
||||||
|
median_or_mean: str = 'median') -> np.ndarray:
|
||||||
|
"""Computes score of each sample as stat_target - some offset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stat_target: a list or numpy array where stat_target[i] is the statistics of
|
||||||
|
example i computed from the target model. stat_target[i] is an array of k
|
||||||
|
scalars for k being the number of augmentations for each sample.
|
||||||
|
stat_in: a list where stat_in[i] is the in-training statistics of example i.
|
||||||
|
stat_in[i] is a m by k numpy array where m is the number of shadow models
|
||||||
|
and k is the number of augmentations for each sample. m can be different
|
||||||
|
for different examples.
|
||||||
|
stat_out: a list where stat_out[i] is the out-training statistics of example
|
||||||
|
i. stat_out[i] is a m by k numpy array where m is the number of shadow
|
||||||
|
models and k is the number of augmentations for each sample.
|
||||||
|
option: using stat_in ("in"), stat_out ("out"), or both ("both").
|
||||||
|
median_or_mean: use median or mean across shadow models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The score of each sample as stat_target - some offset, where offset is
|
||||||
|
computed with stat_in, stat_out, or both depending on the option.
|
||||||
|
The relation between the score and the membership depends on that
|
||||||
|
between stat_target and membership.
|
||||||
|
"""
|
||||||
|
if option not in ['both', 'in', 'out']:
|
||||||
|
raise ValueError('option should be "both", "in", or "out".')
|
||||||
|
if median_or_mean not in ['median', 'mean']:
|
||||||
|
raise ValueError('median_or_mean should be either "median" or "mean".')
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
if any([s.ndim != 2 for s in stat_in]):
|
||||||
|
raise ValueError('Each element in stat_in should be a 2-d numpy array.')
|
||||||
|
if any([s.shape[1] != stat_in[0].shape[1] for s in stat_in]):
|
||||||
|
raise ValueError('Each element in stat_in should have the same size '
|
||||||
|
'in the second dimension.')
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
if any([s.ndim != 2 for s in stat_out]):
|
||||||
|
raise ValueError('Each element in stat_out should be a 2-d numpy array.')
|
||||||
|
if any([s.shape[1] != stat_out[0].shape[1] for s in stat_out]):
|
||||||
|
raise ValueError('Each element in stat_out should have the same size '
|
||||||
|
'in the second dimension.')
|
||||||
|
func_avg = functools.partial(
|
||||||
|
np.nanmedian if median_or_mean == 'median' else np.nanmean, axis=0)
|
||||||
|
|
||||||
|
if option == 'both': # use the average of the in-score and out-score
|
||||||
|
avg_in = np.array(list(map(func_avg, stat_in)))
|
||||||
|
avg_out = np.array(list(map(func_avg, stat_out)))
|
||||||
|
# use average in case of NaN
|
||||||
|
replace_nan_with_column_mean(avg_in)
|
||||||
|
replace_nan_with_column_mean(avg_out)
|
||||||
|
offset = (avg_in + avg_out) / 2
|
||||||
|
elif option == 'in': # use in-score only
|
||||||
|
offset = np.array(list(map(func_avg, stat_in)))
|
||||||
|
replace_nan_with_column_mean(offset)
|
||||||
|
else: # use out-score only
|
||||||
|
offset = np.array(list(map(func_avg, stat_out)))
|
||||||
|
replace_nan_with_column_mean(offset)
|
||||||
|
scores = (stat_target - offset).mean(axis=1)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def compute_score_lira(stat_target: Union[np.ndarray, Sequence[float]],
|
||||||
|
stat_in: Sequence[np.ndarray],
|
||||||
|
stat_out: Sequence[np.ndarray],
|
||||||
|
option: str = 'both',
|
||||||
|
fix_variance: bool = False,
|
||||||
|
median_or_mean: str = 'median') -> np.ndarray:
|
||||||
|
"""Computes score of each sample using Gaussian distribution fitting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stat_target: a list or numpy array where stat_target[i] is the statistics of
|
||||||
|
example i computed from the target model. stat_target[i] is an array of k
|
||||||
|
scalars for k being the number of augmentations for each sample.
|
||||||
|
stat_in: a list where stat_in[i] is the in-training statistics of example i.
|
||||||
|
stat_in[i] is a m by k numpy array where m is the number of shadow models
|
||||||
|
and k is the number of augmentations for each sample. m can be different
|
||||||
|
for different examples.
|
||||||
|
stat_out: a list where stat_out[i] is the out-training statistics of example
|
||||||
|
i. stat_out[i] is a m by k numpy array where m is the number of shadow
|
||||||
|
models and k is the number of augmentations for each sample.
|
||||||
|
option: using stat_in ("in"), stat_out ("out"), or both ("both").
|
||||||
|
fix_variance: whether to use the same variance for all examples.
|
||||||
|
median_or_mean: use median or mean across shadow models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
log(Pr(out)) - log(Pr(in)), log(Pr(out)), or -log(Pr(in)) depending on the
|
||||||
|
option. In-training sample is expected to have small value.
|
||||||
|
The idea is from https://arxiv.org/pdf/2112.03570.pdf.
|
||||||
|
"""
|
||||||
|
# median of statistics across shadow models
|
||||||
|
if option not in ['both', 'in', 'out']:
|
||||||
|
raise ValueError('option should be "both", "in", or "out".')
|
||||||
|
if median_or_mean not in ['median', 'mean']:
|
||||||
|
raise ValueError('median_or_mean should be either "median" or "mean".')
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
if any([s.ndim != 2 for s in stat_in]):
|
||||||
|
raise ValueError('Each element in stat_in should be a 2-d numpy array.')
|
||||||
|
if any([s.shape[1] != stat_in[0].shape[1] for s in stat_in]):
|
||||||
|
raise ValueError('Each element in stat_in should have the same size '
|
||||||
|
'in the second dimension.')
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
if any([s.ndim != 2 for s in stat_out]):
|
||||||
|
raise ValueError('Each element in stat_out should be a 2-d numpy array.')
|
||||||
|
if any([s.shape[1] != stat_out[0].shape[1] for s in stat_out]):
|
||||||
|
raise ValueError('Each element in stat_out should have the same size '
|
||||||
|
'in the second dimension.')
|
||||||
|
|
||||||
|
func_avg = functools.partial(
|
||||||
|
np.nanmedian if median_or_mean == 'median' else np.nanmean, axis=0)
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
avg_in = np.array(list(map(func_avg, stat_in))) # n by k array
|
||||||
|
replace_nan_with_column_mean(avg_in) # use column average in case of NaN
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
avg_out = np.array(list(map(func_avg, stat_out)))
|
||||||
|
replace_nan_with_column_mean(avg_out)
|
||||||
|
|
||||||
|
if fix_variance:
|
||||||
|
# standard deviation of statistics across shadow models and examples
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
std_in = np.nanstd(
|
||||||
|
np.concatenate([l - m[np.newaxis] for l, m in zip(stat_in, avg_in)]))
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
std_out = np.nanstd(
|
||||||
|
np.concatenate([l - m[np.newaxis] for l, m in zip(stat_out, avg_out)
|
||||||
|
]))
|
||||||
|
else:
|
||||||
|
# standard deviation of statistics across shadow models
|
||||||
|
func_std = functools.partial(np.nanstd, axis=0)
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
std_in = np.array(list(map(func_std, stat_in)))
|
||||||
|
replace_nan_with_column_mean(std_in)
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
std_out = np.array(list(map(func_std, stat_out)))
|
||||||
|
replace_nan_with_column_mean(std_out)
|
||||||
|
|
||||||
|
stat_target = np.array(stat_target)
|
||||||
|
if option in ['in', 'both']:
|
||||||
|
log_pr_in = scipy.stats.norm.logpdf(stat_target, avg_in, std_in + 1e-30)
|
||||||
|
if option in ['out', 'both']:
|
||||||
|
log_pr_out = scipy.stats.norm.logpdf(stat_target, avg_out, std_out + 1e-30)
|
||||||
|
|
||||||
|
if option == 'both':
|
||||||
|
scores = -(log_pr_in - log_pr_out).mean(axis=1)
|
||||||
|
elif option == 'in':
|
||||||
|
scores = -log_pr_in.mean(axis=1)
|
||||||
|
else:
|
||||||
|
scores = log_pr_out.mean(axis=1)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def convert_logit_to_prob(logit: np.ndarray) -> np.ndarray:
|
||||||
|
"""Converts logits to probability vectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logit: n by c array where n is the number of samples and c is the number of
|
||||||
|
classes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The probability vectors as n by c array
|
||||||
|
"""
|
||||||
|
prob = logit - np.max(logit, axis=1, keepdims=True)
|
||||||
|
prob = np.array(np.exp(prob), dtype=np.float64)
|
||||||
|
prob = prob / np.sum(prob, axis=1, keepdims=True)
|
||||||
|
return prob
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_statistic(pred: np.ndarray,
|
||||||
|
labels: np.ndarray,
|
||||||
|
is_logits: bool = True,
|
||||||
|
option: str = 'logit',
|
||||||
|
small_value: float = 1e-45):
|
||||||
|
"""Calculates the statistics of each sample.
|
||||||
|
|
||||||
|
The statistics is:
|
||||||
|
for option="conf with prob", p, the probability of the true class;
|
||||||
|
for option="xe", the cross-entropy loss;
|
||||||
|
for option="logit", log(p / (1 - p));
|
||||||
|
for option="conf with logit", max(logits);
|
||||||
|
for option="hinge", logit of the true class - max(logits of the other
|
||||||
|
classes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: the logits or probability vectors, depending on the value of is_logit.
|
||||||
|
An array of size n by c where n is the number of samples and c is the
|
||||||
|
number of classes
|
||||||
|
labels: true labels of samples (integer valued)
|
||||||
|
is_logits: whether pred is logits or probability vectors
|
||||||
|
option: confidence using probability, xe loss, logit of confidence,
|
||||||
|
confidence using logits, hinge loss
|
||||||
|
small_value: a small value to avoid numerical issue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the computed statistics as size n array
|
||||||
|
"""
|
||||||
|
if option not in [
|
||||||
|
'conf with prob', 'xe', 'logit', 'conf with logit', 'hinge'
|
||||||
|
]:
|
||||||
|
raise ValueError(
|
||||||
|
'option should be one of ["conf with prob", "xe", "logit", "conf with logit", "hinge"].'
|
||||||
|
)
|
||||||
|
if option in ['conf with logit', 'hinge']:
|
||||||
|
if not is_logits: # the input needs to be the logits
|
||||||
|
raise ValueError('To compute statistics with option "conf with logit" '
|
||||||
|
'or "hinge", the input must be logits instead of '
|
||||||
|
'probability vectors.')
|
||||||
|
elif is_logits:
|
||||||
|
pred = convert_logit_to_prob(pred)
|
||||||
|
|
||||||
|
n = labels.size # number of samples
|
||||||
|
if option in ['conf with prob', 'conf with logit']:
|
||||||
|
return pred[range(n), labels]
|
||||||
|
if option == 'xe':
|
||||||
|
return log_loss(labels, pred)
|
||||||
|
if option == 'logit':
|
||||||
|
p_true = pred[range(n), labels]
|
||||||
|
pred[range(n), labels] = 0
|
||||||
|
p_other = pred.sum(axis=1)
|
||||||
|
return np.log(p_true + small_value) - np.log(p_other + small_value)
|
||||||
|
if option == 'hinge':
|
||||||
|
l_true = pred[range(n), labels]
|
||||||
|
pred[range(n), labels] = -np.inf
|
||||||
|
return l_true - pred.max(axis=1)
|
||||||
|
raise ValueError
|
|
@ -0,0 +1,219 @@
|
||||||
|
# Copyright 2022, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""An example for using advanced_mia."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
_LR = flags.DEFINE_float('learning_rate', 0.02, 'Learning rate for training')
|
||||||
|
_BATCH = flags.DEFINE_integer('batch_size', 250, 'Batch size')
|
||||||
|
_EPOCHS = flags.DEFINE_integer('epochs', 20, 'Number of epochs')
|
||||||
|
_NUM_SHADOWS = flags.DEFINE_integer('num_shadows', 10,
|
||||||
|
'Number of shadow models.')
|
||||||
|
_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory.')
|
||||||
|
|
||||||
|
|
||||||
|
def small_cnn():
|
||||||
|
"""Setup a small CNN for image classification."""
|
||||||
|
model = tf.keras.models.Sequential()
|
||||||
|
# Add a layer to do random horizontal augmentation.
|
||||||
|
model.add(tf.keras.layers.RandomFlip('horizontal'))
|
||||||
|
model.add(tf.keras.layers.Input(shape=(32, 32, 3)))
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
|
||||||
|
model.add(tf.keras.layers.MaxPooling2D())
|
||||||
|
|
||||||
|
model.add(tf.keras.layers.Flatten())
|
||||||
|
model.add(tf.keras.layers.Dense(64, activation='relu'))
|
||||||
|
model.add(tf.keras.layers.Dense(10))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_cifar10():
|
||||||
|
"""Loads CIFAR10, with training and test combined."""
|
||||||
|
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
|
||||||
|
x = np.concatenate([x_train, x_test]).astype(np.float32) / 255
|
||||||
|
y = np.concatenate([y_train, y_test]).astype(np.int32).squeeze()
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
def plot_curve_with_area(x, y, xlabel, ylabel, ax, label, title=None):
|
||||||
|
ax.plot([0, 1], [0, 1], 'k-', lw=1.0)
|
||||||
|
ax.plot(x, y, lw=2, label=label)
|
||||||
|
ax.set(xlabel=xlabel, ylabel=ylabel)
|
||||||
|
ax.set(aspect=1, xscale='log', yscale='log')
|
||||||
|
ax.title.set_text(title)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stat_and_loss_aug(model, x, y, batch_size=4096):
|
||||||
|
"""A helper function to get the statistics and losses.
|
||||||
|
|
||||||
|
Here we get the statistics and losses for the original and
|
||||||
|
horizontally flipped image, as we are going to train the model with
|
||||||
|
random horizontal flip.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: model to make prediction
|
||||||
|
x: samples
|
||||||
|
y: true labels of samples (integer valued)
|
||||||
|
batch_size: the batch size for model.predict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the statistics and cross-entropy losses
|
||||||
|
"""
|
||||||
|
losses, stat = [], []
|
||||||
|
for data in [x, x[:, :, ::-1, :]]:
|
||||||
|
prob = amia.convert_logit_to_prob(
|
||||||
|
model.predict(data, batch_size=batch_size))
|
||||||
|
losses.append(utils.log_loss(y, prob))
|
||||||
|
stat.append(amia.calculate_statistic(prob, y, convert_to_prob=False))
|
||||||
|
return np.vstack(stat).transpose(1, 0), np.vstack(losses).transpose(1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def main(unused_argv):
|
||||||
|
del unused_argv # unused argument
|
||||||
|
seed = 123
|
||||||
|
np.random.seed(seed)
|
||||||
|
if _MODEL_DIR.value and not os.path.exists(_MODEL_DIR.value):
|
||||||
|
os.mkdir(_MODEL_DIR.value)
|
||||||
|
|
||||||
|
# Load data.
|
||||||
|
x, y = load_cifar10()
|
||||||
|
n = x.shape[0]
|
||||||
|
|
||||||
|
# Train the target and shadow models. We will use one of the model in `models`
|
||||||
|
# as target and the rest as shadow.
|
||||||
|
# Here we use the same architecture and optimizer. In practice, they might
|
||||||
|
# differ between the target and shadow models.
|
||||||
|
in_indices = [] # a list of in-training indices for all models
|
||||||
|
stat = [] # a list of statistics for all models
|
||||||
|
losses = [] # a list of losses for all models
|
||||||
|
for i in range(_NUM_SHADOWS.value + 1):
|
||||||
|
if _MODEL_DIR.value:
|
||||||
|
model_path = os.path.join(
|
||||||
|
_MODEL_DIR.value,
|
||||||
|
f'model{i}_lr{_LR.value}_b{_BATCH.value}_e{_EPOCHS.value}_sd{seed}.h5'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate a binary array indicating which example to include for training
|
||||||
|
in_indices.append(np.random.binomial(1, 0.5, n).astype(bool))
|
||||||
|
|
||||||
|
model = small_cnn()
|
||||||
|
if _MODEL_DIR.value and os.path.exists(model_path): # Load if exists
|
||||||
|
model(x[:1]) # use this to make the `load_weights` work
|
||||||
|
model.load_weights(model_path)
|
||||||
|
print(f'Loaded model #{i} with {in_indices[-1].sum()} examples.')
|
||||||
|
else: # Otherwise, train the model
|
||||||
|
model.compile(
|
||||||
|
optimizer=tf.keras.optimizers.SGD(_LR.value, momentum=0.9),
|
||||||
|
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||||
|
metrics=['accuracy'])
|
||||||
|
model.fit(
|
||||||
|
x[in_indices[-1]],
|
||||||
|
y[in_indices[-1]],
|
||||||
|
validation_data=(x[~in_indices[-1]], y[~in_indices[-1]]),
|
||||||
|
epochs=_EPOCHS.value,
|
||||||
|
batch_size=_BATCH.value,
|
||||||
|
verbose=2)
|
||||||
|
if _MODEL_DIR.value:
|
||||||
|
model.save_weights(model_path)
|
||||||
|
print(f'Trained model #{i} with {in_indices[-1].sum()} examples.')
|
||||||
|
|
||||||
|
# Get the statistics of the current model.
|
||||||
|
s, l = get_stat_and_loss_aug(model, x, y)
|
||||||
|
stat.append(s)
|
||||||
|
losses.append(l)
|
||||||
|
|
||||||
|
# Avoid OOM
|
||||||
|
tf.keras.backend.clear_session()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Now we do MIA for each model
|
||||||
|
for idx in range(_NUM_SHADOWS.value + 1):
|
||||||
|
print(f'Target model is #{idx}')
|
||||||
|
stat_target = stat[idx] # statistics of target model, shape (n, k)
|
||||||
|
in_indices_target = in_indices[idx] # ground-truth membership, shape (n,)
|
||||||
|
|
||||||
|
# `stat_shadow` contains statistics of the shadow models, with shape
|
||||||
|
# (num_shadows, n, k). `in_indices_shadow` contains membership of the shadow
|
||||||
|
# models, with shape (num_shadows, n). We will use them to get a list
|
||||||
|
# `stat_in` and a list `stat_out`, where stat_in[j] (resp. stat_out[j]) is a
|
||||||
|
# (m, k) array, for m being the number of shadow models trained with
|
||||||
|
# (resp. without) the j-th example, and k being the number of augmentations
|
||||||
|
# (2 in our case).
|
||||||
|
stat_shadow = np.array(stat[:idx] + stat[idx + 1:])
|
||||||
|
in_indices_shadow = np.array(in_indices[:idx] + in_indices[idx + 1:])
|
||||||
|
stat_in = [stat_shadow[:, j][in_indices_shadow[:, j]] for j in range(n)]
|
||||||
|
stat_out = [stat_shadow[:, j][~in_indices_shadow[:, j]] for j in range(n)]
|
||||||
|
|
||||||
|
# Compute the scores and use them for MIA
|
||||||
|
scores = amia.compute_score_lira(
|
||||||
|
stat_target, stat_in, stat_out, fix_variance=True)
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
loss_train=scores[in_indices_target],
|
||||||
|
loss_test=scores[~in_indices_target])
|
||||||
|
result_lira = mia.run_attacks(attack_input).single_attack_results[0]
|
||||||
|
print('Better MIA attack with Gaussian:',
|
||||||
|
f'auc = {result_lira.get_auc():.4f}',
|
||||||
|
f'adv = {result_lira.get_attacker_advantage():.4f}')
|
||||||
|
|
||||||
|
# We also try using `compute_score_offset` to compute the score. We take
|
||||||
|
# the negative of the score, because higher statistics corresponds to higher
|
||||||
|
# probability for in-training, which is the opposite of loss.
|
||||||
|
scores = -amia.compute_score_offset(stat_target, stat_in, stat_out)
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
loss_train=scores[in_indices_target],
|
||||||
|
loss_test=scores[~in_indices_target])
|
||||||
|
result_offset = mia.run_attacks(attack_input).single_attack_results[0]
|
||||||
|
print('Better MIA attack with offset:',
|
||||||
|
f'auc = {result_offset.get_auc():.4f}',
|
||||||
|
f'adv = {result_offset.get_attacker_advantage():.4f}')
|
||||||
|
|
||||||
|
# Compare with the baseline MIA using the loss of the target model
|
||||||
|
loss_target = losses[idx][:, 0]
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
loss_train=loss_target[in_indices_target],
|
||||||
|
loss_test=loss_target[~in_indices_target])
|
||||||
|
result_baseline = mia.run_attacks(attack_input).single_attack_results[0]
|
||||||
|
print('Baseline MIA attack:', f'auc = {result_baseline.get_auc():.4f}',
|
||||||
|
f'adv = {result_baseline.get_attacker_advantage():.4f}')
|
||||||
|
|
||||||
|
# Plot and save the AUC curves for the three methods.
|
||||||
|
_, ax = plt.subplots(1, 1, figsize=(5, 5))
|
||||||
|
for res, title in zip([result_baseline, result_lira, result_offset],
|
||||||
|
['baseline', 'LiRA', 'offset']):
|
||||||
|
label = f'{title} auc={res.get_auc():.4f}'
|
||||||
|
mia_plotting.plot_roc_curve(
|
||||||
|
res.roc_curve,
|
||||||
|
functools.partial(plot_curve_with_area, ax=ax, label=label))
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('advanced_mia_demo.png')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,198 @@
|
||||||
|
# Copyright 2022, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for advanced_mia."""
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreOffset(parameterized.TestCase):
|
||||||
|
"""Tests compute_score_offset."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.stat_target = np.array([[-0.1, 0.1, 0], [0, 0, 27],
|
||||||
|
[0, 0, 0]]) # 3 samples with 3 augmentations
|
||||||
|
self.stat_in = [
|
||||||
|
np.array([[1, 2, -3.]]), # 1 shadow
|
||||||
|
np.array([[-2., 4, 6], [0, 0, 0], [5, -7, -9]]), # 3 shadow
|
||||||
|
np.empty((0, 3))
|
||||||
|
] # no shadow
|
||||||
|
self.stat_out = [-s + 10 for s in self.stat_in]
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('both_mean', 'both', 'mean', np.array([-5., 4., -5.])),
|
||||||
|
('both_median', 'both', 'median', np.array([-5., 4., -5.])),
|
||||||
|
('in_median', 'in', 'median', np.array([0., 9., 0.])),
|
||||||
|
('out_median', 'out', 'median', np.array([-10., -1., -10.])),
|
||||||
|
('in_mean', 'in', 'mean', np.array([0, 28. / 3, 1. / 6])),
|
||||||
|
('out_mean', 'out', 'mean', np.array([-10, -4. / 3, -61. / 6])))
|
||||||
|
def test_compute_score_offset(self, option, median_or_mean, expected):
|
||||||
|
scores = amia.compute_score_offset(self.stat_target, self.stat_in,
|
||||||
|
self.stat_out, option, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
# If `option` is "in" (resp. out), test with `stat_out` (resp. `stat_out`)
|
||||||
|
# setting to empty list.
|
||||||
|
if option == 'in':
|
||||||
|
scores = amia.compute_score_offset(self.stat_target, self.stat_in, [],
|
||||||
|
option, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
elif option == 'out':
|
||||||
|
scores = amia.compute_score_offset(self.stat_target, [], self.stat_out,
|
||||||
|
option, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLiRA(parameterized.TestCase):
|
||||||
|
"""Tests compute_score_lira."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('in_median', 'in', False, 'median',
|
||||||
|
np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
|
||||||
|
('in_mean', 'in', False, 'mean',
|
||||||
|
np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
|
||||||
|
('out_median', 'out', False, 'median',
|
||||||
|
-np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
|
||||||
|
('out_mean', 'out', False, 'mean',
|
||||||
|
-np.array([1.41893853, 1.41893853, 3.72152363, 2.72537178])),
|
||||||
|
('in_median_fix', 'in', True, 'median',
|
||||||
|
np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
|
||||||
|
('in_mean_fix', 'in', True, 'mean',
|
||||||
|
np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
|
||||||
|
('out_median_fix', 'out', True, 'median',
|
||||||
|
-np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
|
||||||
|
('out_mean_fix', 'out', True, 'mean',
|
||||||
|
-np.array([2.69682468, 2.69682468, 4.15270703, 2.87983121])),
|
||||||
|
('both_median_fix', 'both', True, 'median', np.array([0, 0, 0, 0.])),
|
||||||
|
('both_mean_fix', 'both', True, 'mean', np.array([0, 0, 0, 0.])),
|
||||||
|
('both_median', 'both', False, 'median', np.array([0, 0, 0, 0.])),
|
||||||
|
('both_mean', 'both', False, 'mean', np.array([0, 0, 0, 0.])),
|
||||||
|
)
|
||||||
|
def test_with_one_augmentation(self, option, fix_variance, median_or_mean,
|
||||||
|
expected):
|
||||||
|
stat_target = np.array([[1.], [0.], [0.], [0.]])
|
||||||
|
stat_in = [
|
||||||
|
np.array([[-1], [1.]]),
|
||||||
|
np.array([[0], [2.]]),
|
||||||
|
np.array([[0], [20.]]),
|
||||||
|
np.empty((0, 1))
|
||||||
|
]
|
||||||
|
stat_out = [-s for s in stat_in]
|
||||||
|
|
||||||
|
scores = amia.compute_score_lira(stat_target, stat_in, stat_out, option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
# If `option` is "in" (resp. out), test with `stat_out` (resp. `stat_out`)
|
||||||
|
# setting to empty list.
|
||||||
|
if option == 'in':
|
||||||
|
scores = amia.compute_score_lira(stat_target, stat_in, [], option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
elif option == 'out':
|
||||||
|
scores = amia.compute_score_lira(stat_target, [], stat_out, option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('in_median', 'in', False, 'median', 2.57023108),
|
||||||
|
('in_mean', 'in', False, 'mean', 2.57023108),
|
||||||
|
('out_median', 'out', False, 'median', -2.57023108),
|
||||||
|
('out_mean', 'out', False, 'mean', -2.57023108),
|
||||||
|
('both_median', 'both', False, 'median', 0),
|
||||||
|
('both_mean', 'both', False, 'mean', 0))
|
||||||
|
def test_two_augmentations(self, option, fix_variance, median_or_mean,
|
||||||
|
expected):
|
||||||
|
stat_target = np.array([[1., 0.]])
|
||||||
|
stat_in = [np.array([[-1, 0], [1., 20]])]
|
||||||
|
stat_out = [-s for s in stat_in]
|
||||||
|
|
||||||
|
scores = amia.compute_score_lira(stat_target, stat_in, stat_out, option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
# If `option` is "in" (resp. out), test with `stat_out` (resp. `stat_out`)
|
||||||
|
# setting to empty list.
|
||||||
|
if option == 'in':
|
||||||
|
scores = amia.compute_score_lira(stat_target, stat_in, [], option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
elif option == 'out':
|
||||||
|
scores = amia.compute_score_lira(stat_target, [], stat_out, option,
|
||||||
|
fix_variance, median_or_mean)
|
||||||
|
np.testing.assert_allclose(scores, expected, atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogitProbConversion(absltest.TestCase):
|
||||||
|
"""Test convert_logit_to_prob."""
|
||||||
|
|
||||||
|
def test_convert_logit_to_prob(self):
|
||||||
|
"""Test convert_logit_to_prob."""
|
||||||
|
logit = np.array([[10, -1, 0.], [-10, 0, -11]])
|
||||||
|
prob = amia.convert_logit_to_prob(logit)
|
||||||
|
expected = np.array([[9.99937902e-01, 1.67006637e-05, 4.53971105e-05],
|
||||||
|
[4.53971105e-05, 9.99937902e-01, 1.67006637e-05]])
|
||||||
|
np.testing.assert_allclose(prob, expected, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCalculateStatistic(absltest.TestCase):
|
||||||
|
"""Test calculate_statistic."""
|
||||||
|
|
||||||
|
def test_calculate_statistic_logit(self):
|
||||||
|
"""Test calculate_statistic with input as logit."""
|
||||||
|
is_logits = True
|
||||||
|
logit = np.array([[1, 2, -3.], [-1, 1, 0]])
|
||||||
|
# expected probability vector
|
||||||
|
# array([[0.26762315, 0.72747516, 0.00490169],
|
||||||
|
# [0.09003057, 0.66524096, 0.24472847]])
|
||||||
|
labels = np.array([1, 2])
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with prob')
|
||||||
|
np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(logit, labels, is_logits, 'xe')
|
||||||
|
np.testing.assert_allclose(stat, np.array([0.31817543, 1.40760596]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(logit, labels, is_logits, 'logit')
|
||||||
|
np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with logit')
|
||||||
|
np.testing.assert_allclose(stat, np.array([2, 0.]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(logit, labels, is_logits, 'hinge')
|
||||||
|
np.testing.assert_allclose(stat, np.array([1, -1.]))
|
||||||
|
|
||||||
|
def test_calculate_statistic_prob(self):
|
||||||
|
"""Test calculate_statistic with input as probability vector."""
|
||||||
|
is_logits = False
|
||||||
|
prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
|
||||||
|
labels = np.array([1, 2])
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(prob, labels, is_logits, 'conf with prob')
|
||||||
|
np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(prob, labels, is_logits, 'xe')
|
||||||
|
np.testing.assert_allclose(stat, np.array([0.16251893, 0.91629073]))
|
||||||
|
|
||||||
|
stat = amia.calculate_statistic(prob, labels, is_logits, 'logit')
|
||||||
|
np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
|
||||||
|
|
||||||
|
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
|
||||||
|
is_logits, 'conf with logit')
|
||||||
|
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
|
||||||
|
is_logits, 'hinge')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
Loading…
Reference in a new issue