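Allows one to run a test on probabilities alone: the AttackInputData validation now treats probs_train as a sufficient input alongside labels, logits, losses and entropy.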

PiperOrigin-RevId: 409095932
This commit is contained in:
David Marn 2021-11-11 02:16:31 -08:00 committed by A. Unique TensorFlower
parent 9757e1bc87
commit 7c4f5bab09

View file

@ -15,12 +15,13 @@
# Lint as: python3 # Lint as: python3
"""Data structures representing attack inputs, configuration, outputs.""" """Data structures representing attack inputs, configuration, outputs."""
import collections import collections
import dataclasses
import enum import enum
import glob import glob
import os import os
import pickle import pickle
from typing import Any, Iterable, Union from typing import Any, Iterable, Union
from dataclasses import dataclass
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from scipy import special from scipy import special
@ -37,7 +38,7 @@ class SlicingFeature(enum.Enum):
CORRECTLY_CLASSIFIED = 'correctly_classified' CORRECTLY_CLASSIFIED = 'correctly_classified'
@dataclass @dataclasses.dataclass
class SingleSliceSpec: class SingleSliceSpec:
"""Specifies a slice. """Specifies a slice.
@ -64,7 +65,7 @@ class SingleSliceSpec:
return '%s=%s' % (self.feature.name, self.value) return '%s=%s' % (self.feature.name, self.value)
@dataclass @dataclasses.dataclass
class SlicingSpec: class SlicingSpec:
"""Specification of a slicing procedure. """Specification of a slicing procedure.
@ -165,7 +166,7 @@ def _log_value(probs, small_value=1e-30):
return -np.log(np.maximum(probs, small_value)) return -np.log(np.maximum(probs, small_value))
@dataclass @dataclasses.dataclass
class AttackInputData: class AttackInputData:
"""Input data for running an attack. """Input data for running an attack.
@ -334,9 +335,11 @@ class AttackInputData:
'labels_train and labels_test should both be either set or unset') 'labels_train and labels_test should both be either set or unset')
if (self.labels_train is None and self.loss_train is None and if (self.labels_train is None and self.loss_train is None and
self.logits_train is None and self.entropy_train is None): self.logits_train is None and self.entropy_train is None and
self.probs_train is None):
raise ValueError( raise ValueError(
'At least one of labels, logits, losses or entropy should be set') 'At least one of labels, logits, losses, probabilities or entropy should be set'
)
if self.labels_train is not None and not _is_integer_type_array( if self.labels_train is not None and not _is_integer_type_array(
self.labels_train): self.labels_train):
@ -390,7 +393,7 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
result.append(' %s with shape: %s,' % (arr_name, arr.shape)) result.append(' %s with shape: %s,' % (arr_name, arr.shape))
@dataclass @dataclasses.dataclass
class RocCurve: class RocCurve:
"""Represents ROC curve of a membership inference classifier.""" """Represents ROC curve of a membership inference classifier."""
# Thresholds used to define points on ROC curve. # Thresholds used to define points on ROC curve.
@ -433,7 +436,7 @@ class RocCurve:
DataSize = collections.namedtuple('DataSize', 'ntrain ntest') DataSize = collections.namedtuple('DataSize', 'ntrain ntest')
@dataclass @dataclasses.dataclass
class SingleAttackResult: class SingleAttackResult:
"""Results from running a single attack.""" """Results from running a single attack."""
@ -488,7 +491,7 @@ class SingleAttackResult:
]) ])
@dataclass @dataclasses.dataclass
class SingleMembershipProbabilityResult: class SingleMembershipProbabilityResult:
"""Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595). """Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595).
@ -578,7 +581,7 @@ class SingleMembershipProbabilityResult:
return summary return summary
@dataclass @dataclasses.dataclass
class MembershipProbabilityResults: class MembershipProbabilityResults:
"""Membership probability results from multiple data slices.""" """Membership probability results from multiple data slices."""
@ -593,7 +596,7 @@ class MembershipProbabilityResults:
return '\n'.join(summary) return '\n'.join(summary)
@dataclass @dataclasses.dataclass
class PrivacyReportMetadata: class PrivacyReportMetadata:
"""Metadata about the evaluated model. """Metadata about the evaluated model.
@ -622,7 +625,7 @@ class AttackResultsDFColumns(enum.Enum):
return '%s' % self.value return '%s' % self.value
@dataclass @dataclasses.dataclass
class AttackResults: class AttackResults:
"""Results from running multiple attacks.""" """Results from running multiple attacks."""
single_attack_results: Iterable[SingleAttackResult] single_attack_results: Iterable[SingleAttackResult]
@ -759,7 +762,7 @@ class AttackResults:
return pickle.load(inp) return pickle.load(inp)
@dataclass @dataclasses.dataclass
class AttackResultsCollection: class AttackResultsCollection:
"""A collection of AttackResults.""" """A collection of AttackResults."""
attack_results_list: Iterable[AttackResults] attack_results_list: Iterable[AttackResults]