forked from 626_privacy/tensorflow_privacy
Overrides default __str__ methods.
PiperOrigin-RevId: 327423772
This commit is contained in:
parent
6dccd9b537
commit
52c1f8fdfe
1 changed files with 55 additions and 1 deletions
|
@ -82,6 +82,25 @@ class SlicingSpec:
|
||||||
# examples will be generated.
|
# examples will be generated.
|
||||||
by_classification_correctness: bool = False
|
by_classification_correctness: bool = False
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Only keeps the True values."""
|
||||||
|
result = ['SlicingSpec(']
|
||||||
|
if self.entire_dataset:
|
||||||
|
result.append(' Entire dataset,')
|
||||||
|
if self.by_class:
|
||||||
|
if isinstance(self.by_class, Iterable):
|
||||||
|
result.append(' Into classes %s,' % self.by_class)
|
||||||
|
elif isinstance(self.by_class, int):
|
||||||
|
result.append(' Up to class %d,' % self.by_class)
|
||||||
|
else:
|
||||||
|
result.append(' By classes,')
|
||||||
|
if self.by_percentiles:
|
||||||
|
result.append(' By percentiles,')
|
||||||
|
if self.by_classification_correctness:
|
||||||
|
result.append(' By classification correctness,')
|
||||||
|
result.append(')')
|
||||||
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
class AttackType(enum.Enum):
|
class AttackType(enum.Enum):
|
||||||
"""An enum define attack types."""
|
"""An enum define attack types."""
|
||||||
|
@ -96,8 +115,8 @@ class AttackType(enum.Enum):
|
||||||
"""Returns whether this type of attack requires training a model."""
|
"""Returns whether this type of attack requires training a model."""
|
||||||
return self != AttackType.THRESHOLD_ATTACK
|
return self != AttackType.THRESHOLD_ATTACK
|
||||||
|
|
||||||
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
||||||
return '%s' % self.name
|
return '%s' % self.name
|
||||||
|
|
||||||
|
|
||||||
|
@ -220,6 +239,23 @@ class AttackInputData:
|
||||||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Return the shapes of variables that are not None."""
|
||||||
|
result = ['AttackInputData(']
|
||||||
|
_append_array_shape(self.loss_train, 'loss_train', result)
|
||||||
|
_append_array_shape(self.loss_test, 'loss_test', result)
|
||||||
|
_append_array_shape(self.logits_train, 'logits_train', result)
|
||||||
|
_append_array_shape(self.logits_test, 'logits_test', result)
|
||||||
|
_append_array_shape(self.labels_train, 'labels_train', result)
|
||||||
|
_append_array_shape(self.labels_test, 'labels_test', result)
|
||||||
|
result.append(')')
|
||||||
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_array_shape(arr: np.array, arr_name: str, result):
|
||||||
|
if arr is not None:
|
||||||
|
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RocCurve:
|
class RocCurve:
|
||||||
|
@ -251,6 +287,14 @@ class RocCurve:
|
||||||
"""
|
"""
|
||||||
return max(np.abs(self.tpr - self.fpr))
|
return max(np.abs(self.tpr - self.fpr))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Returns AUC and advantage metrics."""
|
||||||
|
return '\n'.join([
|
||||||
|
'RocCurve(',
|
||||||
|
' AUC: %f.02' % self.get_auc(),
|
||||||
|
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SingleAttackResult:
|
class SingleAttackResult:
|
||||||
|
@ -272,6 +316,16 @@ class SingleAttackResult:
|
||||||
def get_auc(self):
|
def get_auc(self):
|
||||||
return self.roc_curve.get_auc()
|
return self.roc_curve.get_auc()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Returns SliceSpec, AttackType, AUC and advantage metrics."""
|
||||||
|
return '\n'.join([
|
||||||
|
'SingleAttackResult(',
|
||||||
|
' SliceSpec: %s' % str(self.slice_spec),
|
||||||
|
' AttackType: %s' % str(self.attack_type),
|
||||||
|
' AUC: %f.02' % self.get_auc(),
|
||||||
|
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttackResults:
|
class AttackResults:
|
||||||
|
|
Loading…
Reference in a new issue