forked from 626_privacy/tensorflow_privacy
Overrides default __str__ methods.
PiperOrigin-RevId: 327423772
This commit is contained in:
parent
6dccd9b537
commit
52c1f8fdfe
1 changed files with 55 additions and 1 deletions
|
@ -82,6 +82,25 @@ class SlicingSpec:
|
|||
# examples will be generated.
|
||||
by_classification_correctness: bool = False
|
||||
|
||||
def __str__(self):
|
||||
"""Only keeps the True values."""
|
||||
result = ['SlicingSpec(']
|
||||
if self.entire_dataset:
|
||||
result.append(' Entire dataset,')
|
||||
if self.by_class:
|
||||
if isinstance(self.by_class, Iterable):
|
||||
result.append(' Into classes %s,' % self.by_class)
|
||||
elif isinstance(self.by_class, int):
|
||||
result.append(' Up to class %d,' % self.by_class)
|
||||
else:
|
||||
result.append(' By classes,')
|
||||
if self.by_percentiles:
|
||||
result.append(' By percentiles,')
|
||||
if self.by_classification_correctness:
|
||||
result.append(' By classification correctness,')
|
||||
result.append(')')
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
class AttackType(enum.Enum):
|
||||
"""An enum define attack types."""
|
||||
|
@ -96,8 +115,8 @@ class AttackType(enum.Enum):
|
|||
"""Returns whether this type of attack requires training a model."""
|
||||
return self != AttackType.THRESHOLD_ATTACK
|
||||
|
||||
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
|
||||
def __str__(self):
|
||||
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
|
||||
return '%s' % self.name
|
||||
|
||||
|
||||
|
@ -220,6 +239,23 @@ class AttackInputData:
|
|||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||
|
||||
def __str__(self):
|
||||
"""Return the shapes of variables that are not None."""
|
||||
result = ['AttackInputData(']
|
||||
_append_array_shape(self.loss_train, 'loss_train', result)
|
||||
_append_array_shape(self.loss_test, 'loss_test', result)
|
||||
_append_array_shape(self.logits_train, 'logits_train', result)
|
||||
_append_array_shape(self.logits_test, 'logits_test', result)
|
||||
_append_array_shape(self.labels_train, 'labels_train', result)
|
||||
_append_array_shape(self.labels_test, 'labels_test', result)
|
||||
result.append(')')
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
def _append_array_shape(arr: np.array, arr_name: str, result):
|
||||
if arr is not None:
|
||||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocCurve:
|
||||
|
@ -251,6 +287,14 @@ class RocCurve:
|
|||
"""
|
||||
return max(np.abs(self.tpr - self.fpr))
|
||||
|
||||
def __str__(self):
|
||||
"""Returns AUC and advantage metrics."""
|
||||
return '\n'.join([
|
||||
'RocCurve(',
|
||||
' AUC: %f.02' % self.get_auc(),
|
||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleAttackResult:
|
||||
|
@ -272,6 +316,16 @@ class SingleAttackResult:
|
|||
def get_auc(self):
|
||||
return self.roc_curve.get_auc()
|
||||
|
||||
def __str__(self):
|
||||
"""Returns SliceSpec, AttackType, AUC and advantage metrics."""
|
||||
return '\n'.join([
|
||||
'SingleAttackResult(',
|
||||
' SliceSpec: %s' % str(self.slice_spec),
|
||||
' AttackType: %s' % str(self.attack_type),
|
||||
' AUC: %f.02' % self.get_auc(),
|
||||
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
|
||||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackResults:
|
||||
|
|
Loading…
Reference in a new issue