diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index d9137b8..da21ebd 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -15,12 +15,13 @@ # Lint as: python3 """Data structures representing attack inputs, configuration, outputs.""" import collections +import dataclasses import enum import glob import os import pickle from typing import Any, Iterable, Union -from dataclasses import dataclass + import numpy as np import pandas as pd from scipy import special @@ -37,7 +38,7 @@ class SlicingFeature(enum.Enum): CORRECTLY_CLASSIFIED = 'correctly_classified' -@dataclass +@dataclasses.dataclass class SingleSliceSpec: """Specifies a slice. @@ -64,7 +65,7 @@ class SingleSliceSpec: return '%s=%s' % (self.feature.name, self.value) -@dataclass +@dataclasses.dataclass class SlicingSpec: """Specification of a slicing procedure. @@ -165,7 +166,7 @@ def _log_value(probs, small_value=1e-30): return -np.log(np.maximum(probs, small_value)) -@dataclass +@dataclasses.dataclass class AttackInputData: """Input data for running an attack. @@ -334,9 +335,11 @@ class AttackInputData: 'labels_train and labels_test should both be either set or unset') if (self.labels_train is None and self.loss_train is None and - self.logits_train is None and self.entropy_train is None): + self.logits_train is None and self.entropy_train is None and + self.probs_train is None): raise ValueError( - 'At least one of labels, logits, losses or entropy should be set') + 'At least one of labels, logits, losses, probabilities or entropy should be set' + ) if self.labels_train is not None and not _is_integer_type_array( self.labels_train): @@ -390,7 +393,7 @@ def _append_array_shape(arr: np.array, arr_name: str, result): result.append(' %s with shape: %s,' % (arr_name, arr.shape)) -@dataclass +@dataclasses.dataclass class RocCurve: """Represents ROC curve of a membership inference classifier.""" # Thresholds used to define points on ROC curve. @@ -433,7 +436,7 @@ class RocCurve: DataSize = collections.namedtuple('DataSize', 'ntrain ntest') -@dataclass +@dataclasses.dataclass class SingleAttackResult: """Results from running a single attack.""" @@ -488,7 +491,7 @@ class SingleAttackResult: ]) -@dataclass +@dataclasses.dataclass class SingleMembershipProbabilityResult: """Results from computing membership probabilities (denoted as privacy risk score in https://arxiv.org/abs/2003.10595). @@ -578,7 +581,7 @@ class SingleMembershipProbabilityResult: return summary -@dataclass +@dataclasses.dataclass class MembershipProbabilityResults: """Membership probability results from multiple data slices.""" @@ -593,7 +596,7 @@ class MembershipProbabilityResults: return '\n'.join(summary) -@dataclass +@dataclasses.dataclass class PrivacyReportMetadata: """Metadata about the evaluated model. @@ -622,7 +625,7 @@ class AttackResultsDFColumns(enum.Enum): return '%s' % self.value -@dataclass +@dataclasses.dataclass class AttackResults: """Results from running multiple attacks.""" single_attack_results: Iterable[SingleAttackResult] @@ -759,7 +762,7 @@ class AttackResults: return pickle.load(inp) -@dataclass +@dataclasses.dataclass class AttackResultsCollection: """A collection of AttackResults.""" attack_results_list: Iterable[AttackResults]