Add number of examples in the attack result.
PiperOrigin-RevId: 348812773
This commit is contained in:
parent
6460c3feb8
commit
c8a26ce7be
8 changed files with 139 additions and 51 deletions
|
@ -57,8 +57,8 @@ Then, we can view the attack results by:
|
|||
print(attacks_result.summary())
|
||||
# Example output:
|
||||
# -> Best-performing attacks over all slices
|
||||
# THRESHOLD_ATTACK achieved an AUC of 0.60 on slice Entire dataset
|
||||
# THRESHOLD_ATTACK achieved an advantage of 0.22 on slice Entire dataset
|
||||
# THRESHOLD_ATTACK (with 50000 training and 10000 test examples) achieved an AUC of 0.59 on slice Entire dataset
|
||||
# THRESHOLD_ATTACK (with 50000 training and 10000 test examples) achieved an advantage of 0.20 on slice Entire dataset
|
||||
```
|
||||
|
||||
### Advanced usage / Other codelabs
|
||||
|
|
|
@ -14,12 +14,12 @@
|
|||
|
||||
# Lint as: python3
|
||||
"""Data structures representing attack inputs, configuration, outputs."""
|
||||
import collections
|
||||
import enum
|
||||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -417,6 +417,10 @@ class RocCurve:
|
|||
])
|
||||
|
||||
|
||||
# (no. of training examples, no. of test examples) for the test.
|
||||
DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleAttackResult:
|
||||
"""Results from running a single attack."""
|
||||
|
@ -424,6 +428,8 @@ class SingleAttackResult:
|
|||
# Data slice this result was calculated for.
|
||||
slice_spec: SingleSliceSpec
|
||||
|
||||
# (no. of training examples, no. of test examples) for the test.
|
||||
data_size: DataSize
|
||||
attack_type: AttackType
|
||||
|
||||
# NOTE: roc_curve could theoretically be derived from membership scores.
|
||||
|
@ -462,6 +468,8 @@ class SingleAttackResult:
|
|||
return '\n'.join([
|
||||
'SingleAttackResult(',
|
||||
' SliceSpec: %s' % str(self.slice_spec),
|
||||
' DataSize: (ntrain=%d, ntest=%d)' % (self.data_size.ntrain,
|
||||
self.data_size.ntest),
|
||||
' AttackType: %s' % str(self.attack_type),
|
||||
' AUC: %.2f' % self.get_auc(),
|
||||
' Attacker advantage: %.2f' % self.get_attacker_advantage(), ')'
|
||||
|
@ -593,6 +601,8 @@ class AttackResultsDFColumns(enum.Enum):
|
|||
"""Columns for the Pandas DataFrame that stores AttackResults metrics."""
|
||||
SLICE_FEATURE = 'slice feature'
|
||||
SLICE_VALUE = 'slice value'
|
||||
DATA_SIZE_TRAIN = 'train size'
|
||||
DATA_SIZE_TEST = 'test size'
|
||||
ATTACK_TYPE = 'attack type'
|
||||
|
||||
def __str__(self):
|
||||
|
@ -611,6 +621,8 @@ class AttackResults:
|
|||
"""Returns all metrics as a Pandas DataFrame."""
|
||||
slice_features = []
|
||||
slice_values = []
|
||||
data_size_train = []
|
||||
data_size_test = []
|
||||
attack_types = []
|
||||
advantages = []
|
||||
aucs = []
|
||||
|
@ -623,6 +635,8 @@ class AttackResults:
|
|||
slice_feature, slice_value = slice_spec.feature.value, slice_spec.value
|
||||
slice_features.append(str(slice_feature))
|
||||
slice_values.append(str(slice_value))
|
||||
data_size_train.append(attack_result.data_size.ntrain)
|
||||
data_size_test.append(attack_result.data_size.ntest)
|
||||
attack_types.append(str(attack_result.attack_type))
|
||||
advantages.append(float(attack_result.get_attacker_advantage()))
|
||||
aucs.append(float(attack_result.get_auc()))
|
||||
|
@ -630,6 +644,8 @@ class AttackResults:
|
|||
df = pd.DataFrame({
|
||||
str(AttackResultsDFColumns.SLICE_FEATURE): slice_features,
|
||||
str(AttackResultsDFColumns.SLICE_VALUE): slice_values,
|
||||
str(AttackResultsDFColumns.DATA_SIZE_TRAIN): data_size_train,
|
||||
str(AttackResultsDFColumns.DATA_SIZE_TEST): data_size_test,
|
||||
str(AttackResultsDFColumns.ATTACK_TYPE): attack_types,
|
||||
str(PrivacyMetric.ATTACKER_ADVANTAGE): advantages,
|
||||
str(PrivacyMetric.AUC): aucs
|
||||
|
@ -653,29 +669,42 @@ class AttackResults:
|
|||
max_auc_result_all = self.get_result_with_max_attacker_advantage()
|
||||
summary.append('Best-performing attacks over all slices')
|
||||
summary.append(
|
||||
' %s achieved an AUC of %.2f on slice %s' %
|
||||
(max_auc_result_all.attack_type, max_auc_result_all.get_auc(),
|
||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f on slice %s'
|
||||
% (max_auc_result_all.attack_type,
|
||||
max_auc_result_all.data_size.ntrain,
|
||||
max_auc_result_all.data_size.ntest,
|
||||
max_auc_result_all.get_auc(),
|
||||
max_auc_result_all.slice_spec))
|
||||
|
||||
max_advantage_result_all = self.get_result_with_max_attacker_advantage()
|
||||
summary.append(' %s achieved an advantage of %.2f on slice %s' %
|
||||
(max_advantage_result_all.attack_type,
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f on slice %s'
|
||||
% (max_advantage_result_all.attack_type,
|
||||
max_advantage_result_all.data_size.ntrain,
|
||||
max_advantage_result_all.data_size.ntest,
|
||||
max_advantage_result_all.get_attacker_advantage(),
|
||||
max_advantage_result_all.slice_spec))
|
||||
|
||||
slice_dict = self._group_results_by_slice()
|
||||
|
||||
if len(slice_dict.keys()) > 1 and by_slices:
|
||||
if by_slices and len(slice_dict.keys()) > 1:
|
||||
for slice_str in slice_dict:
|
||||
results = slice_dict[slice_str]
|
||||
summary.append('\nBest-performing attacks over slice: \"%s\"' %
|
||||
slice_str)
|
||||
max_auc_result = results.get_result_with_max_auc()
|
||||
summary.append(' %s achieved an AUC of %.2f' %
|
||||
(max_auc_result.attack_type, max_auc_result.get_auc()))
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an AUC of %.2f'
|
||||
% (max_auc_result.attack_type,
|
||||
max_auc_result.data_size.ntrain,
|
||||
max_auc_result.data_size.ntest,
|
||||
max_auc_result.get_auc()))
|
||||
max_advantage_result = results.get_result_with_max_attacker_advantage()
|
||||
summary.append(' %s achieved an advantage of %.2f' %
|
||||
(max_advantage_result.attack_type,
|
||||
summary.append(
|
||||
' %s (with %d training and %d test examples) achieved an advantage of %.2f'
|
||||
% (max_advantage_result.attack_type,
|
||||
max_advantage_result.data_size.ntrain,
|
||||
max_auc_result.data_size.ntest,
|
||||
max_advantage_result.get_attacker_advantage()))
|
||||
|
||||
return '\n'.join(summary)
|
||||
|
|
|
@ -25,6 +25,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
|
@ -200,7 +201,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
result = SingleAttackResult(
|
||||
roc_curve=roc,
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
self.assertEqual(result.get_auc(), 0.5)
|
||||
|
||||
|
@ -214,7 +216,8 @@ class SingleAttackResultTest(absltest.TestCase):
|
|||
result = SingleAttackResult(
|
||||
roc_curve=roc,
|
||||
slice_spec=SingleSliceSpec(None),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK)
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
self.assertEqual(result.get_attacker_advantage(), 0.0)
|
||||
|
||||
|
@ -249,7 +252,8 @@ class AttackResultsCollectionTest(absltest.TestCase):
|
|||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
thresholds=np.array([0, 1, 2])),
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
self.results_epoch_10 = AttackResults(
|
||||
single_attack_results=[self.some_attack_result],
|
||||
|
@ -308,7 +312,8 @@ class AttackResultsTest(absltest.TestCase):
|
|||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
thresholds=np.array([0, 1, 2])),
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
# ROC curve of a random classifier
|
||||
self.random_classifier_result = SingleAttackResult(
|
||||
|
@ -317,7 +322,8 @@ class AttackResultsTest(absltest.TestCase):
|
|||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
thresholds=np.array([0, 1, 2])),
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
def test_get_result_with_max_auc_first(self):
|
||||
results = AttackResults(
|
||||
|
@ -349,16 +355,20 @@ class AttackResultsTest(absltest.TestCase):
|
|||
self.assertEqual(
|
||||
results.summary(by_slices=True),
|
||||
'Best-performing attacks over all slices\n' +
|
||||
' THRESHOLD_ATTACK achieved an AUC of 1.00 ' +
|
||||
'on slice CORRECTLY_CLASSIFIED=True\n' +
|
||||
' THRESHOLD_ATTACK achieved an advantage of 1.00 ' +
|
||||
'on slice CORRECTLY_CLASSIFIED=True\n\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True\n\n' +
|
||||
'Best-performing attacks over slice: "CORRECTLY_CLASSIFIED=True"\n' +
|
||||
' THRESHOLD_ATTACK achieved an AUC of 1.00\n' +
|
||||
' THRESHOLD_ATTACK achieved an advantage of 1.00\n\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' AUC of 1.00\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' advantage of 1.00\n\n' +
|
||||
'Best-performing attacks over slice: "Entire dataset"\n' +
|
||||
' THRESHOLD_ATTACK achieved an AUC of 0.50\n' +
|
||||
' THRESHOLD_ATTACK achieved an advantage of 0.00')
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' AUC of 0.50\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' advantage of 0.00')
|
||||
|
||||
def test_summary_without_slices(self):
|
||||
results = AttackResults(
|
||||
|
@ -366,10 +376,10 @@ class AttackResultsTest(absltest.TestCase):
|
|||
self.assertEqual(
|
||||
results.summary(by_slices=False),
|
||||
'Best-performing attacks over all slices\n' +
|
||||
' THRESHOLD_ATTACK achieved an AUC of 1.00 ' +
|
||||
'on slice CORRECTLY_CLASSIFIED=True\n' +
|
||||
' THRESHOLD_ATTACK achieved an advantage of 1.00 ' +
|
||||
'on slice CORRECTLY_CLASSIFIED=True')
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' AUC of 1.00 on slice CORRECTLY_CLASSIFIED=True\n' +
|
||||
' THRESHOLD_ATTACK (with 1 training and 1 test examples) achieved an' +
|
||||
' advantage of 1.00 on slice CORRECTLY_CLASSIFIED=True')
|
||||
|
||||
def test_save_load(self):
|
||||
results = AttackResults(
|
||||
|
@ -391,6 +401,8 @@ class AttackResultsTest(absltest.TestCase):
|
|||
df_expected = pd.DataFrame({
|
||||
'slice feature': ['correctly_classified', 'Entire dataset'],
|
||||
'slice value': ['True', ''],
|
||||
'train size': [1, 1],
|
||||
'test size': [1, 1],
|
||||
'attack type': ['THRESHOLD_ATTACK', 'THRESHOLD_ATTACK'],
|
||||
'Attacker advantage': [1.0, 0.0],
|
||||
'AUC': [1.0, 0.5]
|
||||
|
|
|
@ -27,6 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import models
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import MembershipProbabilityResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
|
@ -86,14 +87,15 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
data_size=prepared_attacker_data.data_size,
|
||||
attack_type=attack_type,
|
||||
roc_curve=roc_curve)
|
||||
|
||||
|
||||
def _run_threshold_attack(attack_input: AttackInputData):
|
||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(attack_input.get_train_size()),
|
||||
np.ones(attack_input.get_test_size()))),
|
||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||
np.concatenate(
|
||||
(attack_input.get_loss_train(), attack_input.get_loss_test())))
|
||||
|
||||
|
@ -101,6 +103,7 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
data_size=DataSize(ntrain=ntrain, ntest=ntest),
|
||||
attack_type=AttackType.THRESHOLD_ATTACK,
|
||||
membership_scores_train=-attack_input.get_loss_train(),
|
||||
membership_scores_test=-attack_input.get_loss_test(),
|
||||
|
@ -108,9 +111,9 @@ def _run_threshold_attack(attack_input: AttackInputData):
|
|||
|
||||
|
||||
def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
||||
ntrain, ntest = attack_input.get_train_size(), attack_input.get_test_size()
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
np.concatenate((np.zeros(attack_input.get_train_size()),
|
||||
np.ones(attack_input.get_test_size()))),
|
||||
np.concatenate((np.zeros(ntrain), np.ones(ntest))),
|
||||
np.concatenate(
|
||||
(attack_input.get_entropy_train(), attack_input.get_entropy_test())))
|
||||
|
||||
|
@ -118,6 +121,7 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
|||
|
||||
return SingleAttackResult(
|
||||
slice_spec=_get_slice_spec(attack_input),
|
||||
data_size=DataSize(ntrain=ntrain, ntest=ntest),
|
||||
attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK,
|
||||
membership_scores_train=-attack_input.get_entropy_train(),
|
||||
membership_scores_test=-attack_input.get_entropy_test(),
|
||||
|
@ -126,8 +130,27 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
|
|||
|
||||
def _run_attack(attack_input: AttackInputData,
|
||||
attack_type: AttackType,
|
||||
balance_attacker_training: bool = True):
|
||||
balance_attacker_training: bool = True,
|
||||
min_num_samples: int = 1):
|
||||
"""Runs membership inference attacks for specified input and type.
|
||||
|
||||
Args:
|
||||
attack_input: input data for running an attack
|
||||
attack_type: the attack to run
|
||||
balance_attacker_training: Whether the training and test sets for the
|
||||
membership inference attacker should have a balanced (roughly equal)
|
||||
number of samples from the training and test sets used to develop
|
||||
the model under attack.
|
||||
min_num_samples: minimum number of examples in either training or test data.
|
||||
|
||||
Returns:
|
||||
the attack result.
|
||||
"""
|
||||
attack_input.validate()
|
||||
if min(attack_input.get_train_size(),
|
||||
attack_input.get_test_size()) < min_num_samples:
|
||||
return None
|
||||
|
||||
if attack_type.is_trained_attack:
|
||||
return _run_trained_attack(attack_input, attack_type,
|
||||
balance_attacker_training)
|
||||
|
@ -141,7 +164,8 @@ def run_attacks(attack_input: AttackInputData,
|
|||
attack_types: Iterable[AttackType] = (
|
||||
AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True) -> AttackResults:
|
||||
balance_attacker_training: bool = True,
|
||||
min_num_samples: int = 1) -> AttackResults:
|
||||
"""Runs membership inference attacks on a classification model.
|
||||
|
||||
It runs attacks specified by attack_types on each attack_input slice which is
|
||||
|
@ -156,6 +180,7 @@ def run_attacks(attack_input: AttackInputData,
|
|||
membership inference attacker should have a balanced (roughly equal)
|
||||
number of samples from the training and test sets used to develop
|
||||
the model under attack.
|
||||
min_num_samples: minimum number of examples in either training or test data.
|
||||
|
||||
Returns:
|
||||
the attack result.
|
||||
|
@ -172,9 +197,11 @@ def run_attacks(attack_input: AttackInputData,
|
|||
for single_slice_spec in input_slice_specs:
|
||||
attack_input_slice = get_slice(attack_input, single_slice_spec)
|
||||
for attack_type in attack_types:
|
||||
attack_results.append(
|
||||
_run_attack(attack_input_slice, attack_type,
|
||||
balance_attacker_training))
|
||||
attack_result = _run_attack(attack_input_slice, attack_type,
|
||||
balance_attacker_training,
|
||||
min_num_samples)
|
||||
if attack_result is not None:
|
||||
attack_results.append(attack_result)
|
||||
|
||||
privacy_report_metadata = _compute_missing_privacy_report_metadata(
|
||||
privacy_report_metadata, attack_input)
|
||||
|
|
|
@ -20,6 +20,7 @@ import numpy as np
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
@ -123,6 +124,15 @@ class RunAttacksTest(absltest.TestCase):
|
|||
np.testing.assert_almost_equal(
|
||||
result.test_membership_probs, [0.5, 0.33, 0.33, 0, 0], decimal=2)
|
||||
|
||||
def test_run_attack_data_size(self):
|
||||
result = mia.run_attacks(
|
||||
get_test_input(100, 80), SlicingSpec(by_class=True),
|
||||
(AttackType.THRESHOLD_ATTACK,))
|
||||
self.assertEqual(result.single_attack_results[0].data_size,
|
||||
DataSize(ntrain=100, ntest=80))
|
||||
self.assertEqual(result.single_attack_results[3].data_size,
|
||||
DataSize(ntrain=20, ntest=16))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -24,6 +24,7 @@ from sklearn import neighbors
|
|||
from sklearn import neural_network
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -41,6 +42,8 @@ class AttackerData:
|
|||
# element-wise boolean array denoting if the example was part of training.
|
||||
is_training_labels_test: np.ndarray = None
|
||||
|
||||
data_size: DataSize = None
|
||||
|
||||
|
||||
def create_attacker_data(attack_input_data: AttackInputData,
|
||||
test_fraction: float = 0.25,
|
||||
|
@ -72,11 +75,11 @@ def create_attacker_data(attack_input_data: AttackInputData,
|
|||
min_size)
|
||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
||||
min_size)
|
||||
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
|
||||
labels_all = np.concatenate(
|
||||
((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
|
||||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
|
@ -84,7 +87,8 @@ def create_attacker_data(attack_input_data: AttackInputData,
|
|||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
is_training_labels_test,
|
||||
DataSize(ntrain=ntrain, ntest=ntest))
|
||||
|
||||
|
||||
def _sample_multidimensional_array(array, size):
|
||||
|
|
|
@ -22,6 +22,7 @@ from tensorflow_privacy.privacy.membership_inference_attack import privacy_repor
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
|
@ -41,7 +42,8 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 0.5, 1.0]),
|
||||
fpr=np.array([0.0, 0.5, 1.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
thresholds=np.array([0, 1, 2])),
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
# Classifier that achieves an AUC of 1.0.
|
||||
self.perfect_classifier_result = SingleAttackResult(
|
||||
|
@ -50,7 +52,8 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
roc_curve=RocCurve(
|
||||
tpr=np.array([0.0, 1.0, 1.0]),
|
||||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
thresholds=np.array([0, 1, 2])),
|
||||
data_size=DataSize(ntrain=1, ntest=1))
|
||||
|
||||
self.results_epoch_0 = AttackResults(
|
||||
single_attack_results=[self.imperfect_classifier_result],
|
||||
|
|
|
@ -31,6 +31,7 @@ import tensorflow as tf
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
|
@ -293,12 +294,12 @@ def create_seq2seq_attacker_data(
|
|||
min_size)
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
|
||||
|
||||
# Reshape for classifying one-dimensional features
|
||||
features_all = features_all.reshape(-1, 1)
|
||||
|
||||
labels_all = np.concatenate(
|
||||
((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
|
||||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
|
@ -313,7 +314,8 @@ def create_seq2seq_attacker_data(
|
|||
privacy_report_metadata.accuracy_test = accuracy_test
|
||||
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
is_training_labels_test,
|
||||
DataSize(ntrain=ntrain, ntest=ntest))
|
||||
|
||||
|
||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||
|
@ -362,7 +364,8 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
|||
SingleAttackResult(
|
||||
slice_spec=SingleSliceSpec(),
|
||||
attack_type=AttackType.LOGISTIC_REGRESSION,
|
||||
roc_curve=roc_curve)
|
||||
roc_curve=roc_curve,
|
||||
data_size=prepared_attacker_data.data_size)
|
||||
]
|
||||
|
||||
return AttackResults(
|
||||
|
|
Loading…
Reference in a new issue