Allows one to run a test on probabilities alone.
PiperOrigin-RevId: 409095932
This commit is contained in:
parent
9757e1bc87
commit
7c4f5bab09
1 changed files with 16 additions and 13 deletions
|
@ -15,12 +15,13 @@
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""Data structures representing attack inputs, configuration, outputs."""
|
"""Data structures representing attack inputs, configuration, outputs."""
|
||||||
import collections
|
import collections
|
||||||
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, Union
|
from typing import Any, Iterable, Union
|
||||||
from dataclasses import dataclass
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from scipy import special
|
from scipy import special
|
||||||
|
@ -37,7 +38,7 @@ class SlicingFeature(enum.Enum):
|
||||||
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class SingleSliceSpec:
|
class SingleSliceSpec:
|
||||||
"""Specifies a slice.
|
"""Specifies a slice.
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ class SingleSliceSpec:
|
||||||
return '%s=%s' % (self.feature.name, self.value)
|
return '%s=%s' % (self.feature.name, self.value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class SlicingSpec:
|
class SlicingSpec:
|
||||||
"""Specification of a slicing procedure.
|
"""Specification of a slicing procedure.
|
||||||
|
|
||||||
|
@ -165,7 +166,7 @@ def _log_value(probs, small_value=1e-30):
|
||||||
return -np.log(np.maximum(probs, small_value))
|
return -np.log(np.maximum(probs, small_value))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class AttackInputData:
|
class AttackInputData:
|
||||||
"""Input data for running an attack.
|
"""Input data for running an attack.
|
||||||
|
|
||||||
|
@ -334,9 +335,11 @@ class AttackInputData:
|
||||||
'labels_train and labels_test should both be either set or unset')
|
'labels_train and labels_test should both be either set or unset')
|
||||||
|
|
||||||
if (self.labels_train is None and self.loss_train is None and
|
if (self.labels_train is None and self.loss_train is None and
|
||||||
self.logits_train is None and self.entropy_train is None):
|
self.logits_train is None and self.entropy_train is None and
|
||||||
|
self.probs_train is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'At least one of labels, logits, losses or entropy should be set')
|
'At least one of labels, logits, losses, probabilities or entropy should be set'
|
||||||
|
)
|
||||||
|
|
||||||
if self.labels_train is not None and not _is_integer_type_array(
|
if self.labels_train is not None and not _is_integer_type_array(
|
||||||
self.labels_train):
|
self.labels_train):
|
||||||
|
@ -390,7 +393,7 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
||||||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class RocCurve:
|
class RocCurve:
|
||||||
"""Represents ROC curve of a membership inference classifier."""
|
"""Represents ROC curve of a membership inference classifier."""
|
||||||
# Thresholds used to define points on ROC curve.
|
# Thresholds used to define points on ROC curve.
|
||||||
|
@ -433,7 +436,7 @@ class RocCurve:
|
||||||
DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
|
DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class SingleAttackResult:
|
class SingleAttackResult:
|
||||||
"""Results from running a single attack."""
|
"""Results from running a single attack."""
|
||||||
|
|
||||||
|
@ -488,7 +491,7 @@ class SingleAttackResult:
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class SingleMembershipProbabilityResult:
|
class SingleMembershipProbabilityResult:
|
||||||
"""Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).
|
"""Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).
|
||||||
|
|
||||||
|
@ -578,7 +581,7 @@ class SingleMembershipProbabilityResult:
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class MembershipProbabilityResults:
|
class MembershipProbabilityResults:
|
||||||
"""Membership probability results from multiple data slices."""
|
"""Membership probability results from multiple data slices."""
|
||||||
|
|
||||||
|
@ -593,7 +596,7 @@ class MembershipProbabilityResults:
|
||||||
return '\n'.join(summary)
|
return '\n'.join(summary)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class PrivacyReportMetadata:
|
class PrivacyReportMetadata:
|
||||||
"""Metadata about the evaluated model.
|
"""Metadata about the evaluated model.
|
||||||
|
|
||||||
|
@ -622,7 +625,7 @@ class AttackResultsDFColumns(enum.Enum):
|
||||||
return '%s' % self.value
|
return '%s' % self.value
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class AttackResults:
|
class AttackResults:
|
||||||
"""Results from running multiple attacks."""
|
"""Results from running multiple attacks."""
|
||||||
single_attack_results: Iterable[SingleAttackResult]
|
single_attack_results: Iterable[SingleAttackResult]
|
||||||
|
@ -759,7 +762,7 @@ class AttackResults:
|
||||||
return pickle.load(inp)
|
return pickle.load(inp)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class AttackResultsCollection:
|
class AttackResultsCollection:
|
||||||
"""A collection of AttackResults."""
|
"""A collection of AttackResults."""
|
||||||
attack_results_list: Iterable[AttackResults]
|
attack_results_list: Iterable[AttackResults]
|
||||||
|
|
Loading…
Reference in a new issue