forked from 626_privacy/tensorflow_privacy
Allows one to run a test on probabilities alone.
PiperOrigin-RevId: 409095932
This commit is contained in:
parent
9757e1bc87
commit
7c4f5bab09
1 changed files with 16 additions and 13 deletions
|
@ -15,12 +15,13 @@
|
|||
# Lint as: python3
|
||||
"""Data structures representing attack inputs, configuration, outputs."""
|
||||
import collections
|
||||
import dataclasses
|
||||
import enum
|
||||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy import special
|
||||
|
@ -37,7 +38,7 @@ class SlicingFeature(enum.Enum):
|
|||
CORRECTLY_CLASSIFIED = 'correctly_classified'
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class SingleSliceSpec:
|
||||
"""Specifies a slice.
|
||||
|
||||
|
@ -64,7 +65,7 @@ class SingleSliceSpec:
|
|||
return '%s=%s' % (self.feature.name, self.value)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class SlicingSpec:
|
||||
"""Specification of a slicing procedure.
|
||||
|
||||
|
@ -165,7 +166,7 @@ def _log_value(probs, small_value=1e-30):
|
|||
return -np.log(np.maximum(probs, small_value))
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class AttackInputData:
|
||||
"""Input data for running an attack.
|
||||
|
||||
|
@ -334,9 +335,11 @@ class AttackInputData:
|
|||
'labels_train and labels_test should both be either set or unset')
|
||||
|
||||
if (self.labels_train is None and self.loss_train is None and
|
||||
self.logits_train is None and self.entropy_train is None):
|
||||
self.logits_train is None and self.entropy_train is None and
|
||||
self.probs_train is None):
|
||||
raise ValueError(
|
||||
'At least one of labels, logits, losses or entropy should be set')
|
||||
'At least one of labels, logits, losses, probabilities or entropy should be set'
|
||||
)
|
||||
|
||||
if self.labels_train is not None and not _is_integer_type_array(
|
||||
self.labels_train):
|
||||
|
@ -390,7 +393,7 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
|||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class RocCurve:
|
||||
"""Represents ROC curve of a membership inference classifier."""
|
||||
# Thresholds used to define points on ROC curve.
|
||||
|
@ -433,7 +436,7 @@ class RocCurve:
|
|||
DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class SingleAttackResult:
|
||||
"""Results from running a single attack."""
|
||||
|
||||
|
@ -488,7 +491,7 @@ class SingleAttackResult:
|
|||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class SingleMembershipProbabilityResult:
|
||||
"""Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).
|
||||
|
||||
|
@ -578,7 +581,7 @@ class SingleMembershipProbabilityResult:
|
|||
return summary
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class MembershipProbabilityResults:
|
||||
"""Membership probability results from multiple data slices."""
|
||||
|
||||
|
@ -593,7 +596,7 @@ class MembershipProbabilityResults:
|
|||
return '\n'.join(summary)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class PrivacyReportMetadata:
|
||||
"""Metadata about the evaluated model.
|
||||
|
||||
|
@ -622,7 +625,7 @@ class AttackResultsDFColumns(enum.Enum):
|
|||
return '%s' % self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class AttackResults:
|
||||
"""Results from running multiple attacks."""
|
||||
single_attack_results: Iterable[SingleAttackResult]
|
||||
|
@ -759,7 +762,7 @@ class AttackResults:
|
|||
return pickle.load(inp)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class AttackResultsCollection:
|
||||
"""A collection of AttackResults."""
|
||||
attack_results_list: Iterable[AttackResults]
|
||||
|
|
Loading…
Reference in a new issue