forked from 626_privacy/tensorflow_privacy
Add the ability to use sample weights in the membership attack models where the underlying scikit-learn estimators support them. Only the logistic regression and random forest estimators support sample weights; the remaining attackers ignore the argument.
PiperOrigin-RevId: 478542133
parent feddd28a63
commit 3f6d0acdef

15 changed files with 552 additions and 78 deletions
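For orientation, here is a minimal sketch of how a caller can hand sample weights to an attack once this change is in. The toy scores are illustrative, not part of this commit; only the logistic regression and random forest attackers actually consume the weights.

```python
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import (
    AttackInputData, AttackType, SlicingSpec)

rng = np.random.RandomState(0)
# Toy per-example scores standing in for real model losses.
attack_input = AttackInputData(
    loss_train=rng.rand(100),
    loss_test=rng.rand(100) + 0.2,
    sample_weight_train=rng.rand(100),
    sample_weight_test=rng.rand(100))

results = mia.run_attacks(
    attack_input,
    SlicingSpec(entire_dataset=True),
    attack_types=(AttackType.LOGISTIC_REGRESSION, AttackType.RANDOM_FOREST))
print(results.summary(by_slices=False))
```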
advanced_mia.py
@@ -14,7 +14,7 @@
 """Functions for advanced membership inference attacks."""
 
 import functools
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union
 import numpy as np
 import scipy.stats
 from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
@@ -197,6 +197,7 @@ def convert_logit_to_prob(logit: np.ndarray) -> np.ndarray:
 
 def calculate_statistic(pred: np.ndarray,
                         labels: np.ndarray,
+                        sample_weight: Optional[np.ndarray] = None,
                         is_logits: bool = True,
                         option: str = 'logit',
                         small_value: float = 1e-45):
@@ -215,6 +216,10 @@ def calculate_statistic(pred: np.ndarray,
       An array of size n by c where n is the number of samples and c is the
       number of classes
     labels: true labels of samples (integer valued)
+    sample_weight: a vector of weights of shape (num_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
     is_logits: whether pred is logits or probability vectors
     option: confidence using probability, xe loss, logit of confidence,
       confidence using logits, hinge loss
@@ -241,7 +246,7 @@ def calculate_statistic(pred: np.ndarray,
   if option in ['conf with prob', 'conf with logit']:
     return pred[range(n), labels]
   if option == 'xe':
-    return log_loss(labels, pred)
+    return log_loss(labels, pred, sample_weight=sample_weight)
   if option == 'logit':
     p_true = pred[range(n), labels]
     pred[range(n), labels] = 0
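As the last hunk shows, only the 'xe' statistic actually routes the new argument into `log_loss`; the other options compute per-example statistics that do not depend on weights. A hedged usage sketch, using the `amia` alias seen elsewhere in this commit:

```python
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia

logit = np.array([[1, 2, -3.], [-1, 1, 0]])
labels = np.array([1, 2])
weights = np.array([1.0, 0.5])

# Weighted cross-entropy statistic: per-example xe scaled by the weights.
stat = amia.calculate_statistic(
    logit, labels, sample_weight=weights, is_logits=True, option='xe')
print(stat)  # approx [0.31817543, 0.70380298]
```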
advanced_mia_example.py
@@ -16,6 +16,8 @@
 import functools
 import gc
 import os
+from typing import Optional
+
 from absl import app
 from absl import flags
 import matplotlib.pyplot as plt
@@ -69,7 +71,11 @@ def plot_curve_with_area(x, y, xlabel, ylabel, ax, label, title=None):
     ax.title.set_text(title)
 
 
-def get_stat_and_loss_aug(model, x, y, batch_size=4096):
+def get_stat_and_loss_aug(model,
+                          x,
+                          y,
+                          sample_weight: Optional[np.ndarray] = None,
+                          batch_size=4096):
   """A helper function to get the statistics and losses.
 
   Here we get the statistics and losses for the original and
@@ -80,6 +86,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096):
     model: model to make prediction
     x: samples
     y: true labels of samples (integer valued)
+    sample_weight: a vector of weights of shape (n_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
     batch_size: the batch size for model.predict
 
   Returns:
@@ -89,8 +99,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096):
   for data in [x, x[:, :, ::-1, :]]:
     prob = amia.convert_logit_to_prob(
        model.predict(data, batch_size=batch_size))
-    losses.append(utils.log_loss(y, prob))
-    stat.append(amia.calculate_statistic(prob, y, convert_to_prob=False))
+    losses.append(utils.log_loss(y, prob, sample_weight=sample_weight))
+    stat.append(
+        amia.calculate_statistic(
+            prob, y, sample_weight=sample_weight, convert_to_prob=False))
   return np.vstack(stat).transpose(1, 0), np.vstack(losses).transpose(1, 0)
 
 
@@ -103,6 +115,8 @@ def main(unused_argv):
 
   # Load data.
   x, y = load_cifar10()
+  # Sample weights are set to `None` by default, but can be changed here.
+  sample_weight = None
   n = x.shape[0]
 
   # Train the target and shadow models. We will use one of the models in `models`
@@ -144,7 +158,7 @@ def main(unused_argv):
     print(f'Trained model #{i} with {in_indices[-1].sum()} examples.')
 
     # Get the statistics of the current model.
-    s, l = get_stat_and_loss_aug(model, x, y)
+    s, l = get_stat_and_loss_aug(model, x, y, sample_weight)
     stat.append(s)
     losses.append(l)
 
@@ -175,7 +189,9 @@ def main(unused_argv):
       stat_target, stat_in, stat_out, fix_variance=True)
   attack_input = AttackInputData(
       loss_train=scores[in_indices_target],
-      loss_test=scores[~in_indices_target])
+      loss_test=scores[~in_indices_target],
+      sample_weight_train=sample_weight,
+      sample_weight_test=sample_weight)
   result_lira = mia.run_attacks(attack_input).single_attack_results[0]
   print('Advanced MIA attack with Gaussian:',
         f'auc = {result_lira.get_auc():.4f}',
@@ -187,7 +203,9 @@ def main(unused_argv):
   scores = -amia.compute_score_offset(stat_target, stat_in, stat_out)
   attack_input = AttackInputData(
       loss_train=scores[in_indices_target],
-      loss_test=scores[~in_indices_target])
+      loss_test=scores[~in_indices_target],
+      sample_weight_train=sample_weight,
+      sample_weight_test=sample_weight)
   result_offset = mia.run_attacks(attack_input).single_attack_results[0]
   print('Advanced MIA attack with offset:',
         f'auc = {result_offset.get_auc():.4f}',
@@ -197,7 +215,9 @@ def main(unused_argv):
   loss_target = losses[idx][:, 0]
   attack_input = AttackInputData(
       loss_train=loss_target[in_indices_target],
-      loss_test=loss_target[~in_indices_target])
+      loss_test=loss_target[~in_indices_target],
+      sample_weight_train=sample_weight,
+      sample_weight_test=sample_weight)
   result_baseline = mia.run_attacks(attack_input).single_attack_results[0]
   print('Baseline MIA attack:', f'auc = {result_baseline.get_auc():.4f}',
         f'adv = {result_baseline.get_attacker_advantage():.4f}')
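To actually exercise the new path in this example, the `None` placeholder above would be replaced with a real weight vector. A hedged sketch, continuing inside `main()` (the uniform random weights are illustrative only):

```python
# Inside main(), instead of `sample_weight = None`:
sample_weight = np.random.RandomState(0).rand(x.shape[0])

# The same array then flows through both the per-model statistics and the
# final AttackInputData, as in the hunks above:
s, l = get_stat_and_loss_aug(model, x, y, sample_weight)
```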
advanced_mia_test.py
@@ -158,19 +158,21 @@ class TestCalculateStatistic(absltest.TestCase):
     #         [0.09003057, 0.66524096, 0.24472847]])
     labels = np.array([1, 2])
 
-    stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with prob')
+    stat = amia.calculate_statistic(logit, labels, None, is_logits,
+                                    'conf with prob')
     np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
 
-    stat = amia.calculate_statistic(logit, labels, is_logits, 'xe')
+    stat = amia.calculate_statistic(logit, labels, None, is_logits, 'xe')
     np.testing.assert_allclose(stat, np.array([0.31817543, 1.40760596]))
 
-    stat = amia.calculate_statistic(logit, labels, is_logits, 'logit')
+    stat = amia.calculate_statistic(logit, labels, None, is_logits, 'logit')
     np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
 
-    stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with logit')
+    stat = amia.calculate_statistic(logit, labels, None, is_logits,
+                                    'conf with logit')
     np.testing.assert_allclose(stat, np.array([2, 0.]))
 
-    stat = amia.calculate_statistic(logit, labels, is_logits, 'hinge')
+    stat = amia.calculate_statistic(logit, labels, None, is_logits, 'hinge')
     np.testing.assert_allclose(stat, np.array([1, -1.]))
 
   def test_calculate_statistic_prob(self):
@@ -179,19 +181,74 @@ class TestCalculateStatistic(absltest.TestCase):
     prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
     labels = np.array([1, 2])
 
-    stat = amia.calculate_statistic(prob, labels, is_logits, 'conf with prob')
+    stat = amia.calculate_statistic(prob, labels, None, is_logits,
+                                    'conf with prob')
     np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
 
-    stat = amia.calculate_statistic(prob, labels, is_logits, 'xe')
+    stat = amia.calculate_statistic(prob, labels, None, is_logits, 'xe')
     np.testing.assert_allclose(stat, np.array([0.16251893, 0.91629073]))
 
-    stat = amia.calculate_statistic(prob, labels, is_logits, 'logit')
+    stat = amia.calculate_statistic(prob, labels, None, is_logits, 'logit')
     np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
 
     np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
-                             is_logits, 'conf with logit')
+                             None, is_logits, 'conf with logit')
     np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
-                             is_logits, 'hinge')
+                             None, is_logits, 'hinge')
+
+  def test_calculate_statistic_logit_with_sample_weights(self):
+    """Test calculate_statistic with input as logit."""
+    is_logits = True
+    logit = np.array([[1, 2, -3.], [-1, 1, 0]])
+    # expected probability vector
+    # array([[0.26762315, 0.72747516, 0.00490169],
+    #        [0.09003057, 0.66524096, 0.24472847]])
+    labels = np.array([1, 2])
+    sample_weight = np.array([1.0, 0.5])
+
+    stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
+                                    'conf with prob')
+    np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
+
+    stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
+                                    'xe')
+    np.testing.assert_allclose(stat, np.array([0.31817543, 0.70380298]))
+
+    stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
+                                    'logit')
+    np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
+
+    stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
+                                    'conf with logit')
+    np.testing.assert_allclose(stat, np.array([2, 0.]))
+
+    stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
+                                    'hinge')
+    np.testing.assert_allclose(stat, np.array([1, -1.]))
+
+  def test_calculate_statistic_prob_with_sample_weights(self):
+    """Test calculate_statistic with input as probability vector."""
+    is_logits = False
+    prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
+    labels = np.array([1, 2])
+    sample_weight = np.array([1.0, 0.5])
+
+    stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
+                                    'conf with prob')
+    np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
+
+    stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
+                                    'xe')
+    np.testing.assert_allclose(stat, np.array([0.16251893, 0.458145365]))
+
+    stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
+                                    'logit')
+    np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
+
+    np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
+                             None, is_logits, 'conf with logit')
+    np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
+                             None, is_logits, 'hinge')
 
 
 if __name__ == '__main__':
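The weighted 'xe' expectations in the new tests are simply the unweighted per-example losses scaled elementwise by the weights, which can be checked directly:

```python
import numpy as np

# Unweighted xe values from test_calculate_statistic_logit / _prob above:
logit_xe = np.array([0.31817543, 1.40760596])
prob_xe = np.array([0.16251893, 0.91629073])
weights = np.array([1.0, 0.5])

print(logit_xe * weights)  # -> [0.31817543, 0.70380298]
print(prob_xe * weights)   # -> [0.16251893, 0.458145365]
```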
data_structures.py
@@ -20,7 +20,7 @@ import glob
 import logging
 import os
 import pickle
-from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
+from typing import Any, Iterable, MutableSequence, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -203,6 +203,10 @@ class AttackInputData:
   labels_train: Optional[np.ndarray] = None
   labels_test: Optional[np.ndarray] = None
 
+  # Sample weights, if provided.
+  sample_weight_train: Optional[np.ndarray] = None
+  sample_weight_test: Optional[np.ndarray] = None
+
   # Explicitly specified loss. If provided, this is used instead of deriving
   # loss from logits and labels
   loss_train: Optional[np.ndarray] = None
@@ -219,8 +223,7 @@ class AttackInputData:
   # string representation, or a callable.
   # If a callable is provided, it should take in two arguments, the 1st is
   # labels, the 2nd is logits or probs.
-  loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
-                       utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
+  loss_function: utils.LossFunctionCallable = utils.LossFunction.CROSS_ENTROPY
   # Whether `loss_function` will be called with logits or probs. If not set
   # (None), will decide by availability of logits and probs and logits is
   # preferred when both are available.
@@ -309,7 +312,8 @@ class AttackInputData:
       self.loss_function_using_logits = (self.logits_train is not None)
     return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
                           self.probs_train, self.loss_function,
-                          self.loss_function_using_logits, self.multilabel_data)
+                          self.loss_function_using_logits, self.multilabel_data,
+                          self.sample_weight_train)
 
   def get_loss_test(self):
     """Calculates (if needed) cross-entropy losses for the test set.
@@ -321,7 +325,8 @@ class AttackInputData:
       self.loss_function_using_logits = bool(self.logits_test)
     return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
                           self.probs_test, self.loss_function,
-                          self.loss_function_using_logits, self.multilabel_data)
+                          self.loss_function_using_logits, self.multilabel_data,
+                          self.sample_weight_test)
 
   def get_entropy_train(self):
     """Calculates prediction entropy for the training set."""
@@ -367,6 +372,11 @@ class AttackInputData:
     """Returns the number of examples of the test set."""
     return self.get_test_shape()[0]
 
+  def has_nonnull_sample_weights(self):
+    """Whether both the train and test input data have sample weights."""
+    return (self.sample_weight_train is not None and
+            self.sample_weight_test is not None)
+
   def is_multihot_labels(self, arr, arr_name) -> bool:
     """Check if the 2D array is multihot, with values in [0, 1].
 
@@ -556,6 +566,8 @@ class AttackInputData:
     _append_array_shape(self.probs_test, 'probs_test', result)
     _append_array_shape(self.labels_train, 'labels_train', result)
     _append_array_shape(self.labels_test, 'labels_test', result)
+    _append_array_shape(self.sample_weight_train, 'sample_weight_train', result)
+    _append_array_shape(self.sample_weight_test, 'sample_weight_test', result)
     result.append(')')
     return '\n'.join(result)
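A short sketch of the resulting `AttackInputData` surface after this change (the field values are illustrative):

```python
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

data = AttackInputData(
    logits_train=np.array([[2., 0.3], [0.3, 2.]]),
    logits_test=np.array([[0.3, 2.], [2., 0.3]]),
    labels_train=np.array([0, 1]),
    labels_test=np.array([0, 1]),
    sample_weight_train=np.array([1.0, 0.5]),
    sample_weight_test=np.array([0.5, 1.0]))

print(data.has_nonnull_sample_weights())  # True
# get_loss_train()/get_loss_test() now forward the weights into utils.get_loss,
# so the derived cross-entropy losses are scaled per example.
print(data.get_loss_train())
```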
data_structures_test.py
@@ -63,6 +63,20 @@ class AttackInputDataTest(parameterized.TestCase):
     np.testing.assert_allclose(
         attack_input.get_loss_test(), [0.29860897, 0.95618669], atol=1e-7)
 
+  def test_get_xe_loss_from_logits_with_sample_weights(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([[-0.3, 1.5, 0.2], [2, 3, 0.5]]),
+        logits_test=np.array([[2, 0.3, 0.2], [0.3, -0.5, 0.2]]),
+        labels_train=np.array([1, 0]),
+        labels_test=np.array([0, 2]),
+        sample_weight_train=np.array([1.0, 0.5]),
+        sample_weight_test=np.array([0.5, 1.0]))
+
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(), [0.36313551, 0.685769515], atol=1e-7)
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(), [0.149304485, 0.95618669], atol=1e-7)
+
   def test_get_xe_loss_from_probs(self):
     attack_input = AttackInputData(
         probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
@@ -88,6 +102,30 @@ class AttackInputDataTest(parameterized.TestCase):
     np.testing.assert_allclose(
         attack_input.get_loss_test(), expected_loss0[::-1], rtol=1e-2)
 
+  def test_get_binary_xe_loss_from_logits_with_sample_weights(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([-10, -5, 0., 5, 10]),
+        logits_test=np.array([-10, -5, 0., 5, 10]),
+        labels_train=np.zeros((5,)),
+        labels_test=np.ones((5,)),
+        sample_weight_train=np.array([1.0, 0.0, 0.5, 0.2, 1.0]),
+        sample_weight_test=np.array([0.0, 0.1, 0.5, 0.2, 1.0]),
+        loss_function_using_logits=True)
+    expected_train_loss = np.array(
+        [4.539890e-05, 0.0, 0.3465736, 1.001343, 10.0])
+    expected_test_loss = np.array(
+        [4.539890e-05, 0.001343070, 0.3465736, 0.5006715, 0.0])
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(),
+        expected_train_loss,
+        rtol=1e-2,
+        err_msg='Failure in binary xe training loss calculation.')
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(),
+        expected_test_loss[::-1],
+        rtol=1e-2,
+        err_msg='Failure in binary xe test loss calculation.')
+
   def test_get_binary_xe_loss_from_probs(self):
     attack_input = AttackInputData(
         probs_train=np.array([0.2, 0.7, 0.1, 0.99, 0.002, 0.008]),
@@ -137,7 +175,8 @@ class AttackInputDataTest(parameterized.TestCase):
   def test_get_customized_loss(self, loss_function_using_logits, expected_train,
                                expected_test):
 
-    def fake_loss(x, y):
+    def fake_loss(x, y, sample_weight=None):
+      del sample_weight  # Unused.
       return 2 * x + y
 
     attack_input = AttackInputData(
@@ -332,6 +371,23 @@ class AttackInputDataTest(parameterized.TestCase):
         attack_input.get_loss_test(), [[0.22314354, 0.35667493, 2.30258499]],
         atol=1e-6)
 
+  def test_multilabel_get_bce_loss_from_probs_with_sample_weights(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
+        probs_test=np.array([[0.8, 0.7, 0.9]]),
+        labels_train=np.array([[0, 1, 1], [1, 1, 0]]),
+        labels_test=np.array([[1, 1, 0]]),
+        sample_weight_train=np.array([1.0, 0.5]),
+        sample_weight_test=np.array([0.5]))
+
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(), [[0.22314343, 1.20397247, 0.3566748],
+                                        [0.111571715, 0.25541273, 1.151292045]],
+        atol=1e-6)
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(), [[0.11157177, 0.17833747, 1.1512925]],
+        atol=1e-6)
+
   def test_multilabel_get_bce_loss_from_logits(self):
     attack_input = AttackInputData(
         logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
@@ -349,6 +405,25 @@ class AttackInputDataTest(parameterized.TestCase):
         [[0.68815966, 0.20141327], [0.47407697, 0.04858734]],
         atol=1e-6)
 
+  def test_multilabel_get_bce_loss_from_logits_with_sample_weights(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([[-1.0, -2.0], [0.01, 1.5], [0.5, -3]]),
+        logits_test=np.array([[0.01, 1.5], [0.5, -3]]),
+        labels_train=np.array([[0, 0], [0, 1], [1, 1]]),
+        labels_test=np.array([[1, 1], [1, 0]]),
+        sample_weight_train=np.array([1.0, 0.0, 0.5]),
+        sample_weight_test=np.array([0.0, 0.5]))
+
+    np.testing.assert_allclose(
+        attack_input.get_loss_train(),
+        [[0.31326167, 0.126928], [0.0, 0.0], [0.23703848, 1.52429357]],
+        atol=1e-6,
+        err_msg='Failure in multilabel bce training loss calculation.')
+    np.testing.assert_allclose(
+        attack_input.get_loss_test(), [[0.0, 0.0], [0.23703848, 0.02429367]],
+        atol=1e-6,
+        err_msg='Failure in multilabel bce test loss calculation.')
+
   def test_multilabel_get_loss_explicitly_provided(self):
     attack_input = AttackInputData(
         loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
@@ -359,6 +434,23 @@ class AttackInputDataTest(parameterized.TestCase):
     np.testing.assert_equal(attack_input.get_loss_test().tolist(),
                             np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]))
 
+  def test_multilabel_get_loss_explicitly_provided_with_sample_weights(self):
+    attack_input = AttackInputData(
+        loss_train=np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
+        loss_test=np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]),
+        sample_weight_train=np.array([1.0, 0.5]),
+        sample_weight_test=np.array([0.0, 0.5]))
+
+    # Since loss is provided, sample weights have no effect.
+    np.testing.assert_equal(
+        attack_input.get_loss_train().tolist(),
+        np.array([[1.0, 3.0, 6.0], [6.0, 8.0, 9.0]]),
+        err_msg='Failure in multilabel get training loss calculation.')
+    np.testing.assert_equal(
+        attack_input.get_loss_test().tolist(),
+        np.array([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]]),
+        err_msg='Failure in multilabel get test loss calculation.')
+
   def test_validate_with_force_multilabel_false(self):
     attack_input = AttackInputData(
         probs_train=np.array([[0.2, 0.3, 0.7], [0.8, 0.6, 0.9]]),
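The weighted binary-xe expectations above follow the same pattern as the multiclass case: unweighted per-example losses scaled elementwise by the weights. A quick check for the train side (labels all 0):

```python
import numpy as np

logits = np.array([-10, -5, 0., 5, 10])
unweighted = np.log(1 + np.exp(logits))        # binary xe against label 0
weights = np.array([1.0, 0.0, 0.5, 0.2, 1.0])
print(unweighted * weights)
# -> approx [4.539890e-05, 0.0, 0.3465736, 1.001343, 10.0]
```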
dataset_slicing.py
@@ -42,6 +42,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
   result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
   result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
   result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
+  # Copy over sample weights if provided.
+  result.sample_weight_train = data.sample_weight_train
+  result.sample_weight_test = data.sample_weight_test
 
   # Slice test data.
   result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
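Note the asymmetry in this hunk: logits, labels, losses, and entropy are sliced by index, while the sample weights are copied over whole. A tiny sketch of the consequence (arrays are illustrative):

```python
import numpy as np

idx_train = np.array([True, False, True, False])   # keeps 2 of 4 examples
logits_train = np.arange(4)
sample_weight_train = np.array([1.0, 0.5, 0.2, 0.9])

sliced_logits = logits_train[idx_train]            # length 2, sliced
copied_weights = sample_weight_train               # length 4, copied as-is
print(len(sliced_logits), len(copied_weights))     # 2 4
```

This is why the slicing test below checks the sliced output's weights against the original weight arrays rather than against sliced versions of them.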
dataset_slicing_test.py
@@ -115,6 +115,8 @@ class GetSliceTest(absltest.TestCase):
     loss_test = np.array([0.5, 3.5, 7, 4.5])
     entropy_train = np.array([0.4, 8, 0.6, 10])
     entropy_test = np.array([15, 10.5, 4.5, 0.3])
+    sample_weight_train = np.array([1.0, 0.5])
+    sample_weight_test = np.array([0.5, 1.0])
 
     self.input_data = AttackInputData(
         logits_train=logits_train,
@@ -126,7 +128,9 @@ class GetSliceTest(absltest.TestCase):
         loss_train=loss_train,
         loss_test=loss_test,
         entropy_train=entropy_train,
-        entropy_test=entropy_test)
+        entropy_test=entropy_test,
+        sample_weight_train=sample_weight_train,
+        sample_weight_test=sample_weight_test)
 
   def test_slice_entire_dataset(self):
     entire_dataset_slice = SingleSliceSpec()
@@ -168,6 +172,12 @@ class GetSliceTest(absltest.TestCase):
     self.assertTrue((output.entropy_train == [0.4, 0.6]).all())
     self.assertTrue((output.entropy_test == [15]).all())
 
+    # Check sample weights
+    self.assertLen(output.sample_weight_train, 2)
+    np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5])
+    self.assertLen(output.sample_weight_test, 2)
+    np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0])
+
   def test_slice_by_percentile(self):
     percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
     output = get_slice(self.input_data, percentile_slice)
membership_inference_attack_test.py
@@ -72,6 +72,19 @@ def get_multilabel_test_input(n_train, n_test):
       labels_test=get_multihot_labels_for_test(n_test, num_classes))
 
 
+def get_multilabel_test_input_with_sample_weights(n_train, n_test):
+  """Get example multilabel inputs for attacks."""
+  rng = np.random.RandomState(4)
+  num_classes = max(n_train // 20, 5)  # use at least 5 classes.
+  return AttackInputData(
+      logits_train=rng.randn(n_train, num_classes) + 0.2,
+      logits_test=rng.randn(n_test, num_classes) + 0.2,
+      labels_train=get_multihot_labels_for_test(n_train, num_classes),
+      labels_test=get_multihot_labels_for_test(n_test, num_classes),
+      sample_weight_train=rng.randn(n_train, 1),
+      sample_weight_test=rng.randn(n_test, 1))
+
+
 def get_test_input_logits_only(n_train, n_test):
   """Get example input logits for attacks."""
   rng = np.random.RandomState(4)
@@ -80,6 +93,16 @@ def get_test_input_logits_only(n_train, n_test):
       logits_test=rng.randn(n_test, 5) + 0.2)
 
 
+def get_test_input_logits_only_with_sample_weights(n_train, n_test):
+  """Get example input logits for attacks."""
+  rng = np.random.RandomState(4)
+  return AttackInputData(
+      logits_train=rng.randn(n_train, 5) + 0.2,
+      logits_test=rng.randn(n_test, 5) + 0.2,
+      sample_weight_train=rng.randn(n_train, 1),
+      sample_weight_test=rng.randn(n_test, 1))
+
+
 class RunAttacksTest(parameterized.TestCase):
 
   def test_run_attacks_size(self):
@@ -113,6 +136,17 @@ class RunAttacksTest(parameterized.TestCase):
 
     self.assertLen(result.single_attack_results, 2)
 
+  def test_run_attacks_parallel_backend_with_sample_weights(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input_with_sample_weights(100, 100),
+        SlicingSpec(), (
+            AttackType.LOGISTIC_REGRESSION,
+            AttackType.RANDOM_FOREST,
+        ),
+        backend='threading')
+
+    self.assertLen(result.single_attack_results, 2)
+
   def test_trained_attacks_logits_only_size(self):
     result = mia.run_attacks(
         get_test_input_logits_only(100, 100), SlicingSpec(),
@@ -120,6 +154,13 @@ class RunAttacksTest(parameterized.TestCase):
 
     self.assertLen(result.single_attack_results, 1)
 
+  def test_trained_attacks_logits_only_with_sample_weights_size(self):
+    result = mia.run_attacks(
+        get_test_input_logits_only_with_sample_weights(100, 100), SlicingSpec(),
+        (AttackType.LOGISTIC_REGRESSION,))
+
+    self.assertLen(result.single_attack_results, 1)
+
   def test_run_attack_trained_sets_attack_type(self):
     result = mia._run_attack(
         get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
@@ -271,6 +312,15 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
 
     self.assertLen(result.single_attack_results, 1)
 
+  def test_run_attacks_parallel_backend_with_sample_weights(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input_with_sample_weights(100, 100),
+        SlicingSpec(),
+        (AttackType.LOGISTIC_REGRESSION, AttackType.RANDOM_FOREST),
+        backend='threading')
+
+    self.assertLen(result.single_attack_results, 2)
+
   def test_run_attack_trained_sets_attack_type(self):
     result = mia._run_attack(
         get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
models.py
@@ -39,6 +39,8 @@ class AttackerData:
   features_all: Optional[np.ndarray] = None
   # Indicator for whether the example is in-training (0) or out-of-training (1).
   labels_all: Optional[np.ndarray] = None
+  # Sample weights of in-training and out-of-training examples, if provided.
+  sample_weights_all: Optional[np.ndarray] = None
 
   # Indices for `features_all` and `labels_all` that are going to be used for
   # training the attackers.
@@ -75,6 +77,12 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData,
   ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
   features_all = np.concatenate((attack_input_train, attack_input_test))
   labels_all = np.concatenate((np.zeros(ntrain), np.ones(ntest)))
+  if attack_input_data.has_nonnull_sample_weights():
+    sample_weights_all = np.concatenate((attack_input_data.sample_weight_train,
+                                         attack_input_data.sample_weight_test),
+                                        axis=0)
+  else:
+    sample_weights_all = None
 
   fold_indices = np.arange(ntrain + ntest)
   left_out_indices = np.asarray([], dtype=np.int32)
@@ -99,6 +107,7 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData,
   return AttackerData(
       features_all=features_all,
       labels_all=labels_all,
+      sample_weights_all=sample_weights_all,
       fold_indices=fold_indices,
       left_out_indices=left_out_indices,
       data_size=data_structures.DataSize(ntrain=ntrain, ntest=ntest))
@@ -158,7 +167,7 @@ class TrainedAttacker(object):
           n_jobs=self.n_jobs)
     logging.info('Using %s backend for training.', backend)
 
-  def train_model(self, input_features, is_training_labels):
+  def train_model(self, input_features, is_training_labels, sample_weight=None):
     """Train an attacker model.
 
     This is trained on examples from train and test datasets.
@@ -168,6 +177,10 @@ class TrainedAttacker(object):
         number of features.
       is_training_labels : a vector of booleans of shape (n_samples, )
         representing whether the sample is in the training set or not.
+      sample_weight: a vector of weights of shape (n_samples, ) that are
+        assigned to individual samples. If not provided, then each sample is
+        given unit weight. Only the LogisticRegressionAttacker and the
+        RandomForestAttacker support sample weights.
     """
     raise NotImplementedError()
 
@@ -193,7 +206,7 @@ class LogisticRegressionAttacker(TrainedAttacker):
   def __init__(self, backend: Optional[str] = None):
     super().__init__(backend=backend)
 
-  def train_model(self, input_features, is_training_labels):
+  def train_model(self, input_features, is_training_labels, sample_weight=None):
     with self.ctx_mgr:
       lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
       param_grid = {
@@ -201,7 +214,7 @@ class LogisticRegressionAttacker(TrainedAttacker):
       }
       model = model_selection.GridSearchCV(
          lr, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
-      model.fit(input_features, is_training_labels)
+      model.fit(input_features, is_training_labels, sample_weight=sample_weight)
     self.model = model
 
 
@@ -211,7 +224,8 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
   def __init__(self, backend: Optional[str] = None):
     super().__init__(backend=backend)
 
-  def train_model(self, input_features, is_training_labels):
+  def train_model(self, input_features, is_training_labels, sample_weight=None):
+    del sample_weight  # MLP attacker does not use sample weights.
     with self.ctx_mgr:
       mlp_model = neural_network.MLPClassifier()
       param_grid = {
@@ -231,7 +245,7 @@ class RandomForestAttacker(TrainedAttacker):
   def __init__(self, backend: Optional[str] = None):
     super().__init__(backend=backend)
 
-  def train_model(self, input_features, is_training_labels):
+  def train_model(self, input_features, is_training_labels, sample_weight=None):
     """Setup a random forest pipeline with cross-validation."""
     with self.ctx_mgr:
       rf_model = ensemble.RandomForestClassifier(n_jobs=self.n_jobs)
@@ -241,11 +255,11 @@ class RandomForestAttacker(TrainedAttacker):
           'max_features': ['auto', 'sqrt'],
           'max_depth': [5, 10, 20, None],
          'min_samples_split': [2, 5, 10],
-          'min_samples_leaf': [1, 2, 4]
+          'min_samples_leaf': [1, 2, 4],
      }
      model = model_selection.GridSearchCV(
          rf_model, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
-      model.fit(input_features, is_training_labels)
+      model.fit(input_features, is_training_labels, sample_weight=sample_weight)
     self.model = model
 
 
@@ -255,7 +269,8 @@ class KNearestNeighborsAttacker(TrainedAttacker):
   def __init__(self, backend: Optional[str] = None):
     super().__init__(backend=backend)
 
-  def train_model(self, input_features, is_training_labels):
+  def train_model(self, input_features, is_training_labels, sample_weight=None):
+    del sample_weight  # K-NN attacker does not use sample weights.
     with self.ctx_mgr:
       knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)
       param_grid = {
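The scikit-learn mechanics behind these hunks: `GridSearchCV.fit` forwards extra fit parameters to the wrapped estimator, so a `sample_weight` keyword reaches `LogisticRegression.fit` and `RandomForestClassifier.fit` on every CV split. `MLPClassifier` and `KNeighborsClassifier` have no `sample_weight` parameter, which is why those attackers `del` the argument. A standalone sketch with synthetic data:

```python
import numpy as np
from sklearn import linear_model, model_selection

rng = np.random.RandomState(0)
features = rng.randn(200, 3)          # attacker features (e.g. losses/logits)
is_member = rng.randint(0, 2, 200)    # in-training (0) vs out-of-training (1)
weights = rng.rand(200)

lr = linear_model.LogisticRegression(solver='lbfgs')
model = model_selection.GridSearchCV(
    lr, param_grid={'C': [0.01, 0.1, 1.0, 10.0]}, cv=3)
# sample_weight is routed through GridSearchCV to LogisticRegression.fit.
model.fit(features, is_member, sample_weight=weights)
print(model.best_params_)
```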
models_test.py
@@ -68,6 +68,25 @@ class TrainedAttackerTest(parameterized.TestCase):
         attack_input.is_multilabel_data(),
         msg='Expected multilabel check to pass.')
 
+  def test_multilabel_create_attacker_data_logits_labels_sample_weights(self):
+    attack_input = AttackInputData(
+        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
+        logits_test=np.array([[10, 11], [14, 15]]),
+        labels_train=np.array([[0, 1], [1, 1], [1, 0]]),
+        labels_test=np.array([[1, 0], [1, 1]]),
+        sample_weight_train=np.array([1.0, 0.5, 0.0]),
+        sample_weight_test=np.array([1.0, 0.5]))
+    attacker_data = models.create_attacker_data(attack_input, balance=False)
+    self.assertLen(attacker_data.features_all, 5)
+    self.assertLen(attacker_data.fold_indices, 5)
+    self.assertEmpty(attacker_data.left_out_indices)
+    self.assertTrue(
+        attack_input.is_multilabel_data(),
+        msg='Expected multilabel check to pass.')
+    self.assertTrue(
+        attack_input.has_nonnull_sample_weights(),
+        msg='Expected to have sample weights.')
+
   def test_unbalanced_create_attacker_data_loss_and_logits(self):
     attack_input = AttackInputData(
         logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
@ -14,7 +14,7 @@
|
||||||
"""A hook and a function in tf estimator for membership inference attack."""
|
"""A hook and a function in tf estimator for membership inference attack."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Iterable
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -26,7 +26,10 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import utils_tensorboard
|
||||||
|
|
||||||
|
|
||||||
def calculate_losses(estimator, input_fn, labels):
|
def calculate_losses(estimator,
|
||||||
|
input_fn,
|
||||||
|
labels,
|
||||||
|
sample_weight: Optional[np.ndarray] = None):
|
||||||
"""Get predictions and losses for samples.
|
"""Get predictions and losses for samples.
|
||||||
|
|
||||||
The assumptions are 1) the loss is cross-entropy loss, and 2) user have
|
The assumptions are 1) the loss is cross-entropy loss, and 2) user have
|
||||||
|
@ -38,13 +41,17 @@ def calculate_losses(estimator, input_fn, labels):
|
||||||
estimator: model to make prediction
|
estimator: model to make prediction
|
||||||
input_fn: input function to be used in estimator.predict
|
input_fn: input function to be used in estimator.predict
|
||||||
labels: array of size (n_samples, ), true labels of samples (integer valued)
|
labels: array of size (n_samples, ), true labels of samples (integer valued)
|
||||||
|
sample_weight: a vector of weights of shape (n_samples, ) that are
|
||||||
|
assigned to individual samples. If not provided, then each sample is
|
||||||
|
given unit weight. Only the LogisticRegressionAttacker and the
|
||||||
|
RandomForestAttacker support sample weights.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
preds: probability vector of each sample
|
preds: probability vector of each sample
|
||||||
loss: cross entropy loss of each sample
|
loss: cross entropy loss of each sample
|
||||||
"""
|
"""
|
||||||
pred = np.array(list(estimator.predict(input_fn=input_fn)))
|
pred = np.array(list(estimator.predict(input_fn=input_fn)))
|
||||||
loss = utils.log_loss(labels, pred)
|
loss = utils.log_loss(labels, pred, sample_weight=sample_weight)
|
||||||
return pred, loss
|
return pred, loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,8 +72,10 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
estimator: model to be tested
|
estimator: model to be tested
|
||||||
in_train: (in_training samples, in_training labels)
|
in_train: (in_training samples, in_training labels,
|
||||||
out_train: (out_training samples, out_training labels)
|
in_train_sample_weights=None)
|
||||||
|
out_train: (out_training samples, out_training labels,
|
||||||
|
out_train_sample_weights=None)
|
||||||
input_fn_constructor: a function that receives sample, label and construct
|
input_fn_constructor: a function that receives sample, label and construct
|
||||||
the input_fn for model prediction
|
the input_fn for model prediction
|
||||||
slicing_spec: slicing specification of the attack
|
slicing_spec: slicing specification of the attack
|
||||||
|
@ -75,8 +84,24 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook):
|
||||||
tensorboard_merge_classifiers: if true, plot different classifiers with
|
tensorboard_merge_classifiers: if true, plot different classifiers with
|
||||||
the same slicing_spec and metric in the same figure
|
the same slicing_spec and metric in the same figure
|
||||||
"""
|
"""
|
||||||
in_train_data, self._in_train_labels = in_train
|
if len(in_train) == 2:
|
||||||
out_train_data, self._out_train_labels = out_train
|
in_train_data, self._in_train_labels = in_train
|
||||||
|
self._in_train_sample_weights = None
|
||||||
|
elif len(in_train) == 3:
|
||||||
|
(in_train_data, self._in_train_labels,
|
||||||
|
self._in_train_sample_weights) = in_train
|
||||||
|
else:
|
||||||
|
raise ValueError('`in_train` should be length 2 or 3, received '
|
||||||
|
f'{len(in_train)}')
|
||||||
|
if len(out_train) == 2:
|
||||||
|
out_train_data, self._out_train_labels = out_train
|
||||||
|
self._out_train_sample_weights = None
|
||||||
|
elif len(out_train) == 3:
|
||||||
|
(out_train_data, self._out_train_labels,
|
||||||
|
self._out_train_sample_weights) = in_train
|
||||||
|
else:
|
||||||
|
raise ValueError('`out_train` should be length 2 or 3, received '
|
||||||
|
f'{len(out_train)}')
|
||||||
|
|
||||||
# Define the input functions for both in and out-training samples.
|
# Define the input functions for both in and out-training samples.
|
||||||
self._in_train_input_fn = input_fn_constructor(in_train_data,
|
self._in_train_input_fn = input_fn_constructor(in_train_data,
|
||||||
|
@ -105,8 +130,10 @@ class MembershipInferenceTrainingHook(tf_estimator.SessionRunHook):
|
||||||
def end(self, session):
|
def end(self, session):
|
||||||
results = run_attack_helper(self._estimator, self._in_train_input_fn,
|
results = run_attack_helper(self._estimator, self._in_train_input_fn,
|
||||||
self._out_train_input_fn, self._in_train_labels,
|
self._out_train_input_fn, self._in_train_labels,
|
||||||
self._out_train_labels, self._slicing_spec,
|
self._out_train_labels,
|
||||||
self._attack_types)
|
self._in_train_sample_weights,
|
||||||
|
self._out_train_sample_weights,
|
||||||
|
self._slicing_spec, self._attack_types)
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
|
|
||||||
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
|
||||||
|
@ -137,8 +164,10 @@ def run_attack_on_tf_estimator_model(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
estimator: model to be tested
|
estimator: model to be tested
|
||||||
in_train: (in_training samples, in_training labels)
|
in_train: (in_training samples, in_training labels,
|
||||||
out_train: (out_training samples, out_training labels)
|
in_training_sample_weights=None)
|
||||||
|
out_train: (out_training samples, out_training labels,
|
||||||
|
out_training_sample_weights=None)
|
||||||
input_fn_constructor: a function that receives sample, label and construct
|
input_fn_constructor: a function that receives sample, label and construct
|
||||||
the input_fn for model prediction
|
the input_fn for model prediction
|
||||||
slicing_spec: slicing specification of the attack
|
slicing_spec: slicing specification of the attack
|
||||||
|
@ -147,8 +176,20 @@ def run_attack_on_tf_estimator_model(
|
||||||
Returns:
|
Returns:
|
||||||
Results of the attack
|
Results of the attack
|
||||||
"""
|
"""
|
||||||
in_train_data, in_train_labels = in_train
|
|
||||||
out_train_data, out_train_labels = out_train
|
def unpack(data):
|
||||||
|
sample_weights = None
|
||||||
|
if len(data) == 2:
|
||||||
|
inputs, labels = data
|
||||||
|
elif len(in_train) == 3:
|
||||||
|
inputs, labels, sample_weights = in_train
|
||||||
|
else:
|
||||||
|
raise ValueError('`data` should be length 2 or 3, received '
|
||||||
|
f'{len(data)}')
|
||||||
|
return inputs, labels, sample_weights
|
||||||
|
|
||||||
|
in_train_data, in_train_labels, in_train_sample_weights = unpack(in_train)
|
||||||
|
out_train_data, out_train_labels, out_train_sample_weights = unpack(out_train)
|
||||||
|
|
||||||
# Define the input functions for both in and out-training samples.
|
# Define the input functions for both in and out-training samples.
|
||||||
in_train_input_fn = input_fn_constructor(in_train_data, in_train_labels)
|
in_train_input_fn = input_fn_constructor(in_train_data, in_train_labels)
|
||||||
|
@ -156,8 +197,9 @@ def run_attack_on_tf_estimator_model(
|
||||||
|
|
||||||
# Call the helper to run the attack.
|
# Call the helper to run the attack.
|
||||||
results = run_attack_helper(estimator, in_train_input_fn, out_train_input_fn,
|
results = run_attack_helper(estimator, in_train_input_fn, out_train_input_fn,
|
||||||
in_train_labels, out_train_labels, slicing_spec,
|
in_train_labels, out_train_labels,
|
||||||
attack_types)
|
in_train_sample_weights, out_train_sample_weights,
|
||||||
|
slicing_spec, attack_types)
|
||||||
logging.info('End of training attack:')
|
logging.info('End of training attack:')
|
||||||
logging.info(results)
|
logging.info(results)
|
||||||
return results
|
return results
|
||||||
|
@ -168,6 +210,8 @@ def run_attack_helper(estimator,
|
||||||
out_train_input_fn,
|
out_train_input_fn,
|
||||||
in_train_labels,
|
in_train_labels,
|
||||||
out_train_labels,
|
out_train_labels,
|
||||||
|
in_train_sample_weight: Optional[np.ndarray] = None,
|
||||||
|
out_train_sample_weight: Optional[np.ndarray] = None,
|
||||||
slicing_spec: data_structures.SlicingSpec = None,
|
slicing_spec: data_structures.SlicingSpec = None,
|
||||||
attack_types: Iterable[data_structures.AttackType] = (
|
attack_types: Iterable[data_structures.AttackType] = (
|
||||||
data_structures.AttackType.THRESHOLD_ATTACK,)):
|
data_structures.AttackType.THRESHOLD_ATTACK,)):
|
||||||
|
@@ -179,6 +223,14 @@ def run_attack_helper(estimator,
     out_train_input_fn: input_fn for out of training data
     in_train_labels: in training labels
     out_train_labels: out of training labels
+    in_train_sample_weight: a vector of weights of shape (n_train, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
+    out_train_sample_weight: a vector of weights of shape (n_test, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
     slicing_spec: slicing specification of the attack
     attack_types: a list of attacks, each of type AttackType
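For context on the "only two attackers" caveat: the weights are ultimately forwarded to Scikit-Learn's fit(), and only some estimators accept them. A rough sketch of the pattern (simplified; not the library's actual attacker classes):

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    features = np.random.rand(200, 6)           # attack features, e.g. losses
    is_member = np.random.randint(2, size=200)  # in-training vs. out-of-training
    weights = np.random.rand(200)

    # LogisticRegression.fit and RandomForestClassifier.fit take sample_weight;
    # attackers built on estimators without that parameter cannot use weights.
    attacker = LogisticRegression(solver='lbfgs')
    attacker.fit(features, is_member, sample_weight=weights)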
@@ -186,18 +238,25 @@ def run_attack_helper(estimator,
     Results of the attack
   """
   # Compute predictions and losses
-  in_train_pred, in_train_loss = calculate_losses(estimator, in_train_input_fn,
-                                                  in_train_labels)
-  out_train_pred, out_train_loss = calculate_losses(estimator,
-                                                    out_train_input_fn,
-                                                    out_train_labels)
+  in_train_pred, in_train_loss = calculate_losses(
+      estimator,
+      in_train_input_fn,
+      in_train_labels,
+      sample_weight=in_train_sample_weight)
+  out_train_pred, out_train_loss = calculate_losses(
+      estimator,
+      out_train_input_fn,
+      out_train_labels,
+      sample_weight=out_train_sample_weight)
   attack_input = data_structures.AttackInputData(
       logits_train=in_train_pred,
       logits_test=out_train_pred,
       labels_train=in_train_labels,
       labels_test=out_train_labels,
       loss_train=in_train_loss,
-      loss_test=out_train_loss)
+      loss_test=out_train_loss,
+      sample_weight_train=in_train_sample_weight,
+      sample_weight_test=out_train_sample_weight)
   results = mia.run_attacks(
       attack_input, slicing_spec=slicing_spec, attack_types=attack_types)
   return results
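When assembling the attack input by hand instead of going through run_attack_helper, the same two fields can be filled directly. A sketch with placeholder arrays (the module path matches the library's layout; the data is made up):

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures

    attack_input = data_structures.AttackInputData(
        logits_train=np.random.rand(100, 5),
        logits_test=np.random.rand(50, 5),
        labels_train=np.random.randint(5, size=100),
        labels_test=np.random.randint(5, size=50),
        sample_weight_train=np.random.rand(100),
        sample_weight_test=np.random.rand(50))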
@@ -96,6 +96,8 @@ def main(unused_argv):

   # Load training and test data.
   x_train, y_train, x_test, y_test = load_cifar10()
+  # Sample weights are set to `None` by default, but can be changed here.
+  sample_weight_train, sample_weight_test = None, None

   # Instantiate the tf.Estimator.
   classifier = tf_estimator.Estimator(
@@ -142,7 +144,8 @@ def main(unused_argv):

   print('End of training attack')
   attack_results = run_attack_on_tf_estimator_model(
-      classifier, (x_train, y_train), (x_test, y_test),
+      classifier, (x_train, y_train, sample_weight_train),
+      (x_test, y_test, sample_weight_test),
       input_fn_constructor,
       slicing_spec=SlicingSpec(entire_dataset=True, by_class=True),
       attack_types=[
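To actually exercise the weighted path in this example, the None defaults above can be replaced with real vectors; the upweighting rule below is purely illustrative:

    import numpy as np

    # E.g. double the influence of class-0 examples on the attack models.
    sample_weight_train = np.where(y_train.ravel() == 0, 2.0, 1.0)
    sample_weight_test = np.where(y_test.ravel() == 0, 2.0, 1.0)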
@@ -36,6 +36,8 @@ class UtilsTest(absltest.TestCase):
     self.test_data = np.random.rand(self.ntest, self.ndim)
     self.train_labels = np.random.randint(self.nclass, size=self.ntrain)
     self.test_labels = np.random.randint(self.nclass, size=self.ntest)
+    self.sample_weight_train = np.random.rand(self.ntrain)
+    self.sample_weight_test = np.random.rand(self.ntest)

     # Define a simple model function
     def model_fn(features, labels, mode):
@@ -70,6 +72,14 @@ class UtilsTest(absltest.TestCase):
     self.assertEqual(pred.shape, (self.ntrain, self.nclass))
     self.assertEqual(loss.shape, (self.ntrain,))

+    pred, loss = tf_estimator_evaluation.calculate_losses(
+        self.classifier,
+        self.input_fn_train,
+        self.train_labels,
+        sample_weight=self.sample_weight_train)
+    self.assertEqual(pred.shape, (self.ntrain, self.nclass))
+    self.assertEqual(loss.shape, (self.ntrain,))
+
     pred, loss = tf_estimator_evaluation.calculate_losses(
         self.classifier, self.input_fn_test, self.test_labels)
     self.assertEqual(pred.shape, (self.ntest, self.nclass))
@@ -83,6 +93,27 @@ class UtilsTest(absltest.TestCase):
         self.input_fn_test,
         self.train_labels,
         self.test_labels,
+        self.sample_weight_train,
+        self.sample_weight_test,
+        attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
+    self.assertIsInstance(results, data_structures.AttackResults)
+    att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
+        results)
+    self.assertLen(att_types, 2)
+    self.assertLen(att_slices, 2)
+    self.assertLen(att_metrics, 2)
+    self.assertLen(att_values, 3)  # Attacker Advantage, AUC, PPV
+
+  def test_run_attack_helper_with_sample_weights(self):
+    """Test the attack."""
+    results = tf_estimator_evaluation.run_attack_helper(
+        self.classifier,
+        self.input_fn_train,
+        self.input_fn_test,
+        self.train_labels,
+        self.test_labels,
+        in_train_sample_weight=self.sample_weight_train,
+        out_train_sample_weight=self.sample_weight_test,
         attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
     self.assertIsInstance(results, data_structures.AttackResults)
     att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
@@ -112,6 +143,27 @@ class UtilsTest(absltest.TestCase):
     self.assertLen(att_metrics, 2)
     self.assertLen(att_values, 3)  # Attacker Advantage, AUC, PPV

+  def test_run_attack_on_tf_estimator_model_with_sample_weights(self):
+    """Test the attack on the final models."""
+
+    def input_fn_constructor(x, y):
+      return tf_compat_v1_estimator.inputs.numpy_input_fn(
+          x={'x': x}, y=y, shuffle=False)
+
+    results = tf_estimator_evaluation.run_attack_on_tf_estimator_model(
+        self.classifier,
+        (self.train_data, self.train_labels, self.sample_weight_train),
+        (self.test_data, self.test_labels),
+        input_fn_constructor,
+        attack_types=[data_structures.AttackType.THRESHOLD_ATTACK])
+    self.assertIsInstance(results, data_structures.AttackResults)
+    att_types, att_slices, att_metrics, att_values = data_structures.get_flattened_attack_metrics(
+        results)
+    self.assertLen(att_types, 2)
+    self.assertLen(att_slices, 2)
+    self.assertLen(att_metrics, 2)
+    self.assertLen(att_values, 3)  # Attacker Advantage, AUC, PPV
+

 if __name__ == '__main__':
   absltest.main()
@@ -21,8 +21,20 @@ import numpy as np
 from scipy import special


+class LossFunction(enum.Enum):
+  """An enum that defines loss function."""
+  CROSS_ENTROPY = 'cross_entropy'
+  SQUARED = 'squared'
+
+
+LossFunctionCallable = Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]],
+                                np.ndarray]
+LossFunctionType = Union[LossFunctionCallable, LossFunction, str]
+
+
 def log_loss(labels: np.ndarray,
              pred: np.ndarray,
+             sample_weight: Optional[np.ndarray] = None,
              from_logits=False,
              small_value=1e-8) -> np.ndarray:
   """Computes the per-example cross entropy loss.
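The new LossFunctionCallable alias pins down the contract a custom loss has to satisfy: labels, predictions, and an optional per-example weight vector in; per-example losses out. A conforming toy implementation (absolute error is an arbitrary choice for illustration):

    from typing import Optional
    import numpy as np

    def absolute_loss(labels: np.ndarray,
                      pred: np.ndarray,
                      sample_weight: Optional[np.ndarray] = None) -> np.ndarray:
      """Per-example |label - pred|, scaled by the weights when given."""
      weight = 1.0 if sample_weight is None else sample_weight
      return weight * np.abs(labels - pred)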
@@ -35,6 +47,10 @@ def log_loss(labels: np.ndarray,
       num_classes) and pred[i] is the logits or probability vector of the i-th
       sample. For binary logistic loss, the shape should be (num_samples,) and
       pred[i] is the probability of the positive class.
+    sample_weight: a vector of weights of shape (num_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
     from_logits: whether `pred` is logits or probability vector.
     small_value: a scalar. np.log can become -inf if the probability is too
       close to 0, so the probability is clipped below by small_value.
@@ -46,6 +62,15 @@ def log_loss(labels: np.ndarray,
     raise ValueError('labels and pred should have the same number of examples,',
                      f'but got {labels.shape[0]} and {pred.shape[0]}.')
   classes = np.unique(labels)
+  if sample_weight is None:
+    # If sample weights are not provided, set them to 1.0.
+    sample_weight = 1.0
+  else:
+    if np.shape(sample_weight)[0] != np.shape(labels)[0]:
+      # Number of elements should be the same.
+      raise ValueError(
+          'Expected sample weights to have the same length as the labels, '
+          f'received {np.shape(sample_weight)[0]} and {np.shape(labels)[0]}.')

   # Binary logistic loss
   if pred.size == pred.shape[0]:
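The effect of the new validation, in miniature (values are arbitrary; the weighting itself lands in the next hunk):

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests.utils import log_loss

    labels = np.array([0, 1, 0])
    probs = np.array([0.2, 0.9, 0.5])  # binary case: probability of class 1

    print(log_loss(labels, probs, np.array([1.0, 1.0, 2.0])))  # three weights: ok
    try:
      log_loss(labels, probs, np.array([1.0, 1.0]))  # two weights, three labels
    except ValueError as err:
      print(err)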
@@ -59,22 +84,29 @@ def log_loss(labels: np.ndarray,
     indices_class0 = (labels == 0)
     prob_correct = np.copy(pred)
     prob_correct[indices_class0] = 1 - prob_correct[indices_class0]
-    return -np.log(np.maximum(prob_correct, small_value))
+    return -np.log(np.maximum(prob_correct, small_value)) * sample_weight

   # Multi-class categorical cross entropy loss
   if classes.min() < 0 or classes.max() >= pred.shape[1]:
     raise ValueError('labels should be in the range [0, num_classes-1].')
   if from_logits:
     pred = special.softmax(pred, axis=-1)
-  return -np.log(np.maximum(pred[range(labels.size), labels], small_value))
+  return (-np.log(np.maximum(pred[range(labels.size), labels], small_value)) *
+          sample_weight)


-def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
+def squared_loss(y_true: np.ndarray,
+                 y_pred: np.ndarray,
+                 sample_weight: Optional[np.ndarray] = None) -> np.ndarray:
   """Computes the per-example squared loss.

   Args:
     y_true: numpy array of shape (num_samples,) representing the true labels.
     y_pred: numpy array of shape (num_samples,) representing the predictions.
+    sample_weight: a vector of weights of shape (num_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.

   Returns:
     the squared loss of each sample.
@@ -93,11 +125,15 @@ def squared_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
     raise ValueError('Squared loss expects the labels and predictions to have '
                      'shape (num_examples, ), but after np.squeeze, the shapes '
                      'are %s and %s.' % (y_true.shape, y_pred.shape))
-  return (y_true - y_pred)**2
+  if sample_weight is None:
+    # If sample weights are not provided, set them to 1.0.
+    sample_weight = 1.0
+  return sample_weight * (y_true - y_pred)**2


 def multilabel_bce_loss(labels: np.ndarray,
                         pred: np.ndarray,
+                        sample_weight: Optional[np.ndarray] = None,
                         from_logits=False,
                         small_value=1e-8) -> np.ndarray:
   """Computes the per-multi-label-example cross entropy loss.
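A quick check of the weighted squared loss (this mirrors the unit test added later in the commit; unit weights reproduce the old behavior and a zero weight mutes an example):

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests import utils

    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([4.0, 3.0, 2.0])
    print(utils.squared_loss(y_true, y_pred))                             # [9. 1. 1.]
    print(utils.squared_loss(y_true, y_pred, np.array([1.0, 0.0, 2.0])))  # [9. 0. 2.]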
@@ -108,6 +144,10 @@ def multilabel_bce_loss(labels: np.ndarray,
       of the vector is one of {0, 1}.
     pred: numpy array of shape (num_samples, num_classes). pred[i] is the
       logits or probability vector of the i-th sample.
+    sample_weight: a vector of weights of shape (num_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.
     from_logits: whether `pred` is logits or probability vector.
     small_value: a scalar. np.log can become -inf if the probability is too
       close to 0, so the probability is clipped below by small_value.
@@ -132,19 +172,21 @@ def multilabel_bce_loss(labels: np.ndarray,
   if not from_logits and ((pred < 0.0) | (pred > 1.0)).any():
     raise ValueError(('Prediction probabilities are not in [0, 1] and '
                       '`from_logits` is set to False.'))
+  if sample_weight is None:
+    # If sample weights are not provided, set them to 1.0.
+    sample_weight = 1.0
+  if isinstance(sample_weight, list):
+    sample_weight = np.asarray(sample_weight)
+  if isinstance(sample_weight, np.ndarray) and (sample_weight.ndim == 1):
+    # NOMUTANTS--np.reshape(X, (-1, 1)) == np.reshape(X, (-N, 1)), N >=1.
+    sample_weight = np.reshape(sample_weight, (-1, 1))
+
   # Multi-class multi-label binary cross entropy loss
   if from_logits:
     pred = special.expit(pred)
   bce = labels * np.log(pred + small_value)
   bce += (1 - labels) * np.log(1 - pred + small_value)
-  return -bce
-
-
-class LossFunction(enum.Enum):
-  """An enum that defines loss function."""
-  CROSS_ENTROPY = 'cross_entropy'
-  SQUARED = 'squared'
+  return -bce * sample_weight


 def string_to_loss_function(string: str):
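The reshape above is what makes a length-n weight vector broadcast across every class column of the (n, num_classes) BCE matrix instead of raising a shape error. A minimal demonstration:

    import numpy as np

    bce = np.ones((2, 3))     # per-example, per-class losses
    w = np.array([1.0, 0.5])  # one weight per example

    # w * bce with shape (2,) would raise; (2, 1) scales each row as intended.
    print(np.reshape(w, (-1, 1)) * bce)
    # [[1.  1.  1. ]
    #  [0.5 0.5 0.5]]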
@@ -157,12 +199,15 @@ def string_to_loss_function(string: str):
   raise ValueError(f'{string} is not a valid loss function name.')


-def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
-             logits: Optional[np.ndarray], probs: Optional[np.ndarray],
-             loss_function: Union[Callable[[np.ndarray, np.ndarray],
-                                           np.ndarray], LossFunction, str],
-             loss_function_using_logits: Optional[bool],
-             multilabel_data: Optional[bool]) -> Optional[np.ndarray]:
+def get_loss(
+    loss: Optional[np.ndarray],
+    labels: Optional[np.ndarray],
+    logits: Optional[np.ndarray],
+    probs: Optional[np.ndarray],
+    loss_function: LossFunctionCallable,
+    loss_function_using_logits: Optional[bool],
+    multilabel_data: Optional[bool],
+    sample_weight: Optional[np.ndarray] = None) -> Optional[np.ndarray]:
   """Calculates (if needed) losses.

   Args:
@@ -176,6 +221,10 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
     loss_function_using_logits: if `loss_function` expects `logits` or
       `probs`.
     multilabel_data: if the data is from a multilabel classification problem.
+    sample_weight: a vector of weights of shape (num_samples, ) that are
+      assigned to individual samples. If not provided, then each sample is
+      given unit weight. Only the LogisticRegressionAttacker and the
+      RandomForestAttacker support sample weights.

   Returns:
     Loss (or None if neither the loss nor the labels are present).
@@ -195,12 +244,13 @@ def get_loss(loss: Optional[np.ndarray], labels: Optional[np.ndarray],
       loss_function = string_to_loss_function(loss_function)
     if loss_function == LossFunction.CROSS_ENTROPY:
       if multilabel_data:
-        loss = multilabel_bce_loss(labels, predictions,
+        loss = multilabel_bce_loss(labels, predictions, sample_weight,
                                    loss_function_using_logits)
       else:
-        loss = log_loss(labels, predictions, loss_function_using_logits)
+        loss = log_loss(labels, predictions, sample_weight,
+                        loss_function_using_logits)
     elif loss_function == LossFunction.SQUARED:
-      loss = squared_loss(labels, predictions)
+      loss = squared_loss(labels, predictions, sample_weight)
     else:
-      loss = loss_function(labels, predictions)
+      loss = loss_function(labels, predictions, sample_weight)
   return loss
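Putting the dispatch together, a sketch of how a caller might reach the weighted cross-entropy branch; the exact preconditions under which get_loss computes (rather than passes through) a loss are not shown in this hunk, so treat the call below as an assumption-laden example:

    import numpy as np
    from tensorflow_privacy.privacy.privacy_tests import utils

    labels = np.array([0, 1, 1])
    logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.5, 0.5]])
    weights = np.array([1.0, 1.0, 0.5])

    loss = utils.get_loss(
        loss=None, labels=labels, logits=logits, probs=None,
        loss_function='cross_entropy', loss_function_using_logits=True,
        multilabel_data=False, sample_weight=weights)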
@@ -89,6 +89,15 @@ class TestLogLoss(parameterized.TestCase):
     loss = utils.log_loss(labels, logits, from_logits=True)
     np.testing.assert_allclose(expected_loss, loss, atol=1e-7)

+  def test_log_loss_from_logits_with_sample_weights(self):
+    logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]])
+    labels = np.array([0, 3, 1])
+    sample_weight = np.array([1.0, 0.0, 0.5])
+    expected_loss = np.array([1.4401897, 0.0, 0.05572139])
+
+    loss = utils.log_loss(labels, logits, sample_weight, from_logits=True)
+    np.testing.assert_allclose(expected_loss, loss, atol=1e-7)
+
   @parameterized.named_parameters(
       ('label0', 0,
        np.array([
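The expected values in this new test can be reproduced by hand from the definition of weighted softmax cross entropy (note the second entry is zeroed by its weight):

    import numpy as np
    from scipy import special

    logits = np.array([[1, 2, 0, -1], [1, 2, 0, -1], [-1, 3, 0, 0]], dtype=float)
    labels = np.array([0, 3, 1])
    weights = np.array([1.0, 0.0, 0.5])

    probs = special.softmax(logits, axis=-1)
    print(-np.log(probs[range(3), labels]) * weights)
    # ~ [1.4401897  0.         0.05572139]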
@@ -161,6 +170,14 @@ class TestSquaredLoss(parameterized.TestCase):
     y_pred = np.array([4, 3, 2])
     self.assertRaises(ValueError, utils.squared_loss, y_true, y_pred)

+  def test_squared_loss_with_sample_weights(self):
+    y_true = np.array([1, 2, 3, 4.])
+    y_pred = np.array([4, 3, 2, 1.])
+    sample_weight = np.array([1.0, 0.0, 0.5, 1.0])
+    expected_loss = np.array([9, 0, 0.5, 9.])
+    loss = utils.squared_loss(y_true, y_pred, sample_weight)
+    np.testing.assert_allclose(loss, expected_loss, atol=1e-7)
+

 class TestMultilabelBCELoss(parameterized.TestCase):

@@ -186,6 +203,16 @@ class TestMultilabelBCELoss(parameterized.TestCase):
     loss = utils.multilabel_bce_loss(label, pred, from_logits=from_logits)
     np.testing.assert_allclose(loss, expected_losses, atol=1e-6)

+  def test_multilabel_bce_loss_with_sample_weights(self):
+    label = np.array([[0, 1, 0], [1, 1, 0]])
+    pred = np.array([[-1.2, -0.3, 2.1], [0.0, 0.5, 1.5]])
+    sample_weight = np.array([1.0, 0.5])
+    expected_loss = np.array([[0.26328245, 0.85435522, 2.21551943],
+                              [0.34657358, 0.23703848, 0.85070661]])
+    loss = utils.multilabel_bce_loss(
+        label, pred, sample_weight=sample_weight, from_logits=True)
+    np.testing.assert_allclose(loss, expected_loss, atol=1e-6)
+
   @parameterized.named_parameters(
       ('from_logits_true_and_incorrect_values_example1',
        np.array([[0, 1, 1], [1, 1, 0]
@@ -216,7 +243,7 @@
   )
   def test_multilabel_bce_loss_raises(self, label, pred, from_logits, regex):
     self.assertRaisesRegex(ValueError, regex, utils.multilabel_bce_loss, label,
-                           pred, from_logits)
+                           pred, None, from_logits)


 class TestGetLoss(parameterized.TestCase):
